You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/03/11 12:49:34 UTC
[airflow] branch main updated: Add oss_task_handler into alibaba-provider and enable remote logging to OSS (#21785)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7bd8b2d Add oss_task_handler into alibaba-provider and enable remote logging to OSS (#21785)
7bd8b2d is described below
commit 7bd8b2d7f3bca39a919cf0aeef91da1c476d792d
Author: Eric Gao <er...@gmail.com>
AuthorDate: Fri Mar 11 20:48:40 2022 +0800
Add oss_task_handler into alibaba-provider and enable remote logging to OSS (#21785)
---
airflow/config_templates/airflow_local_settings.py | 11 ++
airflow/providers/alibaba/cloud/hooks/oss.py | 96 ++++++++++-
.../providers/alibaba/cloud/log/__init__.py | 20 ---
.../alibaba/cloud/log/oss_task_handler.py | 186 +++++++++++++++++++++
airflow/providers/alibaba/provider.yaml | 3 +
airflow/utils/db.py | 3 +-
docs/apache-airflow-providers-alibaba/index.rst | 1 +
.../logging/index.rst | 25 +++
.../logging/oss-task-handler.rst | 42 +++++
tests/providers/alibaba/cloud/hooks/test_oss.py | 52 +++++-
.../cloud/{utils/oss_mock.py => log/__init__.py} | 20 ---
.../alibaba/cloud/log/test_oss_task_handler.py | 135 +++++++++++++++
tests/providers/alibaba/cloud/utils/oss_mock.py | 1 +
13 files changed, 548 insertions(+), 47 deletions(-)
diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py
index b6a9bdb..14fa529 100644
--- a/airflow/config_templates/airflow_local_settings.py
+++ b/airflow/config_templates/airflow_local_settings.py
@@ -248,6 +248,17 @@ if REMOTE_LOGGING:
}
DEFAULT_LOGGING_CONFIG['handlers'].update(STACKDRIVER_REMOTE_HANDLERS)
+ elif REMOTE_BASE_LOG_FOLDER.startswith('oss://'):
+ OSS_REMOTE_HANDLERS = {
+ 'task': {
+ 'class': 'airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler',
+ 'formatter': 'airflow',
+ 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
+ 'oss_log_folder': REMOTE_BASE_LOG_FOLDER,
+ 'filename_template': FILENAME_TEMPLATE,
+ },
+ }
+ DEFAULT_LOGGING_CONFIG['handlers'].update(OSS_REMOTE_HANDLERS)
elif ELASTICSEARCH_HOST:
ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get('elasticsearch', 'LOG_ID_TEMPLATE')
ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get('elasticsearch', 'END_OF_LOG_MARK')
diff --git a/airflow/providers/alibaba/cloud/hooks/oss.py b/airflow/providers/alibaba/cloud/hooks/oss.py
index e50b89c..23bbde1 100644
--- a/airflow/providers/alibaba/cloud/hooks/oss.py
+++ b/airflow/providers/alibaba/cloud/hooks/oss.py
@@ -89,10 +89,13 @@ class OSSHook(BaseHook):
conn_type = 'oss'
hook_name = 'OSS'
- def __init__(self, region, oss_conn_id='oss_default', *args, **kwargs) -> None:
+ def __init__(self, region: Optional[str] = None, oss_conn_id='oss_default', *args, **kwargs) -> None:
self.oss_conn_id = oss_conn_id
self.oss_conn = self.get_connection(oss_conn_id)
- self.region = region
+ if region is None:
+ self.region = self.get_default_region()
+ else:
+ self.region = region
super().__init__(*args, **kwargs)
def get_conn(self) -> "Connection":
@@ -117,6 +120,7 @@ class OSSHook(BaseHook):
return bucket_name, key
+ @provide_bucket_name
@unify_bucket_name_and_key
def object_exists(self, key: str, bucket_name: Optional[str] = None) -> bool:
"""
@@ -143,8 +147,10 @@ class OSSHook(BaseHook):
:rtype: oss2.api.Bucket
"""
auth = self.get_credential()
+ assert self.region is not None
return oss2.Bucket(auth, 'http://oss-' + self.region + '.aliyuncs.com', bucket_name)
+ @provide_bucket_name
@unify_bucket_name_and_key
def load_string(self, key: str, content: str, bucket_name: Optional[str] = None) -> None:
"""
@@ -159,6 +165,7 @@ class OSSHook(BaseHook):
except Exception as e:
raise AirflowException(f"Errors: {e}")
+ @provide_bucket_name
@unify_bucket_name_and_key
def upload_local_file(
self,
@@ -178,6 +185,7 @@ class OSSHook(BaseHook):
except Exception as e:
raise AirflowException(f"Errors when upload file: {e}")
+ @provide_bucket_name
@unify_bucket_name_and_key
def download_file(
self,
@@ -201,6 +209,7 @@ class OSSHook(BaseHook):
return None
return local_file
+ @provide_bucket_name
@unify_bucket_name_and_key
def delete_object(
self,
@@ -219,6 +228,7 @@ class OSSHook(BaseHook):
self.log.error(e)
raise AirflowException(f"Errors when deleting: {key}")
+ @provide_bucket_name
@unify_bucket_name_and_key
def delete_objects(
self,
@@ -269,6 +279,73 @@ class OSSHook(BaseHook):
self.log.error(e)
raise AirflowException(f"Errors when create bucket: {bucket_name}")
+    @provide_bucket_name
+    @unify_bucket_name_and_key
+    def append_string(self, bucket_name: Optional[str], content: str, key: str, pos: int) -> None:
+        """Append ``content`` to the appendable object ``key`` at offset ``pos``.
+
+        :param bucket_name: the name of the bucket
+        :param content: content to be appended
+        :param key: oss bucket key
+        :param pos: offset to append at; OSS requires it to equal the object's current length
+        :raises AirflowException: if the append fails
+        """
+        self.log.info("Write oss bucket. key: %s, pos: %s", key, pos)
+        try:
+            self.get_bucket(bucket_name).append_object(key, pos, content)
+        except Exception as e:
+            self.log.error(e)
+            raise AirflowException(f"Errors when append string for object: {key}") from e
+
+    @provide_bucket_name
+    @unify_bucket_name_and_key
+    def read_key(self, bucket_name: Optional[str], key: str) -> str:
+        """Read the object stored at ``key`` and decode it as UTF-8.
+
+        :param bucket_name: the name of the bucket
+        :param key: oss bucket key
+        :return: the decoded object content
+        """
+        self.log.info("Read oss key: %s", key)
+        try:
+            return self.get_bucket(bucket_name).get_object(key).read().decode("utf-8")
+        except Exception as e:
+            self.log.error(e)
+            raise AirflowException(f"Errors when read bucket object: {key}") from e
+
+    @provide_bucket_name
+    @unify_bucket_name_and_key
+    def head_key(self, bucket_name: Optional[str], key: str) -> oss2.models.HeadObjectResult:
+        """Get the metadata of the specified remote object.
+
+        :param bucket_name: the name of the bucket
+        :param key: oss bucket key
+        :return: the object's head result (e.g. its ``content_length``)
+        """
+        self.log.info("Head Object oss key: %s", key)
+        try:
+            return self.get_bucket(bucket_name).head_object(key)
+        except Exception as e:
+            self.log.error(e)
+            raise AirflowException(f"Errors when head bucket object: {key}") from e
+
+    @provide_bucket_name
+    @unify_bucket_name_and_key
+    def key_exist(self, bucket_name: Optional[str], key: str) -> bool:
+        """Find out whether the specified key exists in the oss remote storage.
+
+        :param bucket_name: the name of the bucket
+        :param key: oss bucket key
+        :return: True if the object exists, False otherwise
+        :raises AirflowException: if the existence lookup itself fails
+        """
+        self.log.info('Looking up oss bucket %s for bucket key %s ...', bucket_name, key)
+        try:
+            return self.get_bucket(bucket_name).object_exists(key)
+        except Exception as e:
+            self.log.error(e)
+            raise AirflowException(f"Errors when check bucket object existence: {key}") from e
+
def get_credential(self) -> oss2.auth.Auth:
extra_config = self.oss_conn.extra_dejson
auth_type = extra_config.get('auth_type', None)
@@ -285,3 +362,18 @@ class OSSHook(BaseHook):
return oss2.Auth(oss_access_key_id, oss_access_key_secret)
else:
raise Exception("Unsupported auth_type: " + auth_type)
+
+    def get_default_region(self) -> str:
+        """Return the ``region`` from the connection extras (only auth_type 'AK' is supported)."""
+        extra_config = self.oss_conn.extra_dejson
+        auth_type = extra_config.get('auth_type', None)
+        if not auth_type:
+            raise Exception("No auth_type specified in extra_config. ")
+        if auth_type != 'AK':
+            raise Exception("Unsupported auth_type: " + auth_type)
+        # A region is mandatory: get_bucket() builds the OSS endpoint URL from it.
+        default_region = extra_config.get('region', None)
+        if not default_region:
+            raise Exception("No region is specified for connection: " + self.oss_conn_id)
+
+        return default_region
diff --git a/tests/providers/alibaba/cloud/utils/oss_mock.py b/airflow/providers/alibaba/cloud/log/__init__.py
similarity index 60%
copy from tests/providers/alibaba/cloud/utils/oss_mock.py
copy to airflow/providers/alibaba/cloud/log/__init__.py
index 7bbef54..13a8339 100644
--- a/tests/providers/alibaba/cloud/utils/oss_mock.py
+++ b/airflow/providers/alibaba/cloud/log/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,22 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import json
-
-from airflow.models import Connection
-
-OSS_PROJECT_ID_HOOK_UNIT_TEST = 'example-project'
-
-
-def mock_oss_hook_default_project_id(self, oss_conn_id='mock_oss_default', region='mock_region'):
- self.oss_conn_id = oss_conn_id
- self.oss_conn = Connection(
- extra=json.dumps(
- {
- 'auth_type': 'AK',
- 'access_key_id': 'mock_access_key_id',
- 'access_key_secret': 'mock_access_key_secret',
- }
- )
- )
- self.region = region
diff --git a/airflow/providers/alibaba/cloud/log/oss_task_handler.py b/airflow/providers/alibaba/cloud/log/oss_task_handler.py
new file mode 100644
index 0000000..d26bfbf
--- /dev/null
+++ b/airflow/providers/alibaba/cloud/log/oss_task_handler.py
@@ -0,0 +1,186 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import sys
+
+if sys.version_info >= (3, 8):
+ from functools import cached_property
+else:
+ from cached_property import cached_property
+
+from airflow.configuration import conf
+from airflow.providers.alibaba.cloud.hooks.oss import OSSHook
+from airflow.utils.log.file_task_handler import FileTaskHandler
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class OSSTaskHandler(FileTaskHandler, LoggingMixin):
+    """
+    OSSTaskHandler is a python log handler that handles and reads
+    task instance logs. It extends airflow FileTaskHandler and
+    uploads to and reads from OSS remote storage.
+    """
+
+    def __init__(self, base_log_folder, oss_log_folder, filename_template):
+        self.log.info("Using oss_task_handler for remote logging...")
+        super().__init__(base_log_folder, filename_template)
+        (self.bucket_name, self.base_folder) = OSSHook.parse_oss_url(oss_log_folder)
+        self.log_relative_path = ''
+        self.closed = False
+        self.upload_on_close = True
+
+    @cached_property
+    def hook(self):
+        """Create an OSSHook for [logging] remote_log_conn_id; None (after logging) on failure."""
+        remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
+        self.log.info("remote_conn_id: %s", remote_conn_id)
+        try:
+            return OSSHook(oss_conn_id=remote_conn_id)
+        except Exception as e:
+            self.log.error(e, exc_info=True)
+            self.log.error(
+                'Could not create an OSSHook with connection id "%s". '
+                'Please make sure that airflow[alibaba] is installed and '
+                'the OSS connection exists.',
+                remote_conn_id,
+            )
+
+    def set_context(self, ti):
+        super().set_context(ti)
+        # Local location and remote location are needed to open and
+        # upload local log file to OSS remote storage.
+        self.log_relative_path = self._render_filename(ti, ti.try_number)
+        self.upload_on_close = not ti.raw
+
+        # Clear the file first so that duplicate data is not uploaded
+        # when re-using the same path (e.g. with rescheduled sensors)
+        if self.upload_on_close:
+            with open(self.handler.baseFilename, 'w'):
+                pass
+
+    def close(self):
+        """Close and upload local log file to remote storage OSS."""
+        # When application exit, system shuts down all handlers by
+        # calling close method. Here we check if logger is already
+        # closed to prevent uploading the log to remote storage multiple
+        # times when `logging.shutdown` is called.
+        if self.closed:
+            return
+
+        super().close()
+
+        if not self.upload_on_close:
+            return
+
+        local_loc = os.path.join(self.local_base, self.log_relative_path)
+        remote_loc = self.log_relative_path
+        if os.path.exists(local_loc):
+            # Upload the whole file (set_context() truncates the handler's file at start)
+            with open(local_loc) as logfile:
+                log = logfile.read()
+            self.oss_write(log, remote_loc)
+
+        # Mark closed so we don't double write if close is called twice
+        self.closed = True
+
+    def _read(self, ti, try_number, metadata=None):
+        """
+        Read logs of given task instance and try_number from OSS remote storage.
+        If failed, read the log from task instance host machine.
+
+        :param ti: task instance object
+        :param try_number: task instance try_number to read logs from
+        :param metadata: log metadata,
+            can be used for streaming log reading and auto-tailing.
+        """
+        # Explicitly getting log relative path is necessary as the given
+        # task instance might be different than task instance passed in
+        # in set_context method.
+        remote_loc = self._render_filename(ti, try_number)
+
+        if self.oss_log_exists(remote_loc):
+            # If OSS remote file exists, we do not fetch logs from task instance
+            # local machine even if there are errors reading remote logs, as
+            # returned remote_log will contain error messages.
+            remote_log = self.oss_read(remote_loc, return_error=True)
+            log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n'
+            return log, {'end_of_log': True}
+        else:
+            # Forward metadata so the local fallback can stream / auto-tail the log.
+            return super()._read(ti, try_number, metadata)
+
+    def oss_log_exists(self, remote_log_location):
+        """
+        Check if remote_log_location exists in remote storage
+
+        :param remote_log_location: log's location relative to the OSS base folder
+        :return: True if location exists else False
+        """
+        oss_remote_log_location = self.base_folder + '/' + remote_log_location
+        try:
+            return self.hook.key_exist(self.bucket_name, oss_remote_log_location)
+        except Exception:
+            # Best effort: any lookup failure (incl. a hook that is None) counts as "not found".
+            return False
+
+    def oss_read(self, remote_log_location, return_error=False):
+        """
+        Returns the log found at the remote_log_location. Returns '' if no
+        logs are found or there is an error.
+
+        :param remote_log_location: the log's location relative to the OSS base folder
+        :param return_error: if True, returns a string error message if an
+            error occurs. Otherwise returns '' when an error occurs.
+        """
+        # Built outside the try block so the error message below can always use it.
+        oss_remote_log_location = self.base_folder + '/' + remote_log_location
+        try:
+            self.log.info("read remote log: %s", oss_remote_log_location)
+            return self.hook.read_key(self.bucket_name, oss_remote_log_location)
+        except Exception:
+            msg = f'Could not read logs from {oss_remote_log_location}'
+            self.log.exception(msg)
+            # Honour the documented contract: '' unless the caller asked for the error.
+            return msg if return_error else ''
+
+    def oss_write(self, log, remote_log_location, append=True):
+        """
+        Append the log to the remote_log_location. Errors are logged, not raised.
+
+        :param log: the log to write to the remote_log_location
+        :param remote_log_location: the log's location relative to the OSS base folder
+        :param append: if True, write at the end of any existing remote log. If False,
+            write at offset 0, which OSS only accepts when no remote log exists yet.
+        """
+        oss_remote_log_location = self.base_folder + '/' + remote_log_location
+        pos = 0
+        try:
+            # oss_log_exists() adds the base folder itself, so pass the relative path.
+            if append and self.oss_log_exists(remote_log_location):
+                head = self.hook.head_key(self.bucket_name, oss_remote_log_location)
+                pos = head.content_length
+            self.log.info("log write pos is: %s", str(pos))
+            self.log.info("writing remote log: %s", oss_remote_log_location)
+            self.hook.append_string(self.bucket_name, log, oss_remote_log_location, pos)
+        except Exception:
+            self.log.exception(
+                'Could not write logs to %s, log write pos is: %s, Append is %s',
+                oss_remote_log_location,
+                str(pos),
+                str(append),
+            )
diff --git a/airflow/providers/alibaba/provider.yaml b/airflow/providers/alibaba/provider.yaml
index 9f0e8f0..de394c8 100644
--- a/airflow/providers/alibaba/provider.yaml
+++ b/airflow/providers/alibaba/provider.yaml
@@ -57,3 +57,6 @@ hook-class-names: # deprecated - to be removed after providers add dependency o
connection-types:
- hook-class-name: airflow.providers.alibaba.cloud.hooks.oss.OSSHook
connection-type: oss
+
+logging:
+ - airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 6dfd19c..ac95115 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -436,7 +436,8 @@ def create_default_connections(session: Session = NEW_SESSION):
extra='''{
"auth_type": "AK",
"access_key_id": "<ACCESS_KEY_ID>",
- "access_key_secret": "<ACCESS_KEY_SECRET>"}
+ "access_key_secret": "<ACCESS_KEY_SECRET>",
+ "region": "<YOUR_OSS_REGION>"}
''',
),
session,
diff --git a/docs/apache-airflow-providers-alibaba/index.rst b/docs/apache-airflow-providers-alibaba/index.rst
index a0e3798..979ade1 100644
--- a/docs/apache-airflow-providers-alibaba/index.rst
+++ b/docs/apache-airflow-providers-alibaba/index.rst
@@ -27,6 +27,7 @@ Content
Connection types <connections/alibaba>
Operators <operators/index>
+ Logging for Tasks <logging/index>
.. toctree::
:maxdepth: 1
diff --git a/docs/apache-airflow-providers-alibaba/logging/index.rst b/docs/apache-airflow-providers-alibaba/logging/index.rst
new file mode 100644
index 0000000..9681e2d
--- /dev/null
+++ b/docs/apache-airflow-providers-alibaba/logging/index.rst
@@ -0,0 +1,25 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Logging for Tasks
+=================
+
+.. toctree::
+ :maxdepth: 1
+ :glob:
+
+ *
diff --git a/docs/apache-airflow-providers-alibaba/logging/oss-task-handler.rst b/docs/apache-airflow-providers-alibaba/logging/oss-task-handler.rst
new file mode 100644
index 0000000..9a7c872
--- /dev/null
+++ b/docs/apache-airflow-providers-alibaba/logging/oss-task-handler.rst
@@ -0,0 +1,42 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _write-logs-alibaba-oss:
+
+Writing logs to Alibaba OSS
+---------------------------
+
+Remote logging to Alibaba OSS uses an existing Airflow connection to read or write logs. If you
+don't have a connection properly set up, this process will fail.
+
+
+Enabling remote logging
+'''''''''''''''''''''''
+
+To enable this feature, ``airflow.cfg`` must be configured as follows:
+
+.. code-block:: ini
+
+ [logging]
+ # Airflow can store logs remotely in Alibaba OSS. Users must supply a remote
+ # location URL (starting with 'oss://') and an Airflow connection
+ # id that provides access to the storage location.
+ remote_logging = True
+ remote_base_log_folder = oss://my-bucket/path/to/logs
+ remote_log_conn_id = oss_default
+
+In the above example, Airflow will try to use ``OSSHook(oss_conn_id='oss_default')``. Because no region is passed, the connection's extra must also define ``region``.
diff --git a/tests/providers/alibaba/cloud/hooks/test_oss.py b/tests/providers/alibaba/cloud/hooks/test_oss.py
index fe60893..f5659bf 100644
--- a/tests/providers/alibaba/cloud/hooks/test_oss.py
+++ b/tests/providers/alibaba/cloud/hooks/test_oss.py
@@ -41,7 +41,6 @@ class TestOSSHook(unittest.TestCase):
def test_parse_oss_url(self):
parsed = self.hook.parse_oss_url(f"oss://{MOCK_BUCKET_NAME}/this/is/not/a-real-key.txt")
- print(parsed)
assert parsed == (MOCK_BUCKET_NAME, "this/is/not/a-real-key.txt"), "Incorrect parsing of the oss url"
def test_parse_oss_object_directory(self):
@@ -96,19 +95,19 @@ class TestOSSHook(unittest.TestCase):
def test_download_file(self, mock_service):
self.hook.download_file(MOCK_KEY, MOCK_FILE_PATH, MOCK_BUCKET_NAME)
mock_service.assert_called_once_with(MOCK_BUCKET_NAME)
- mock_service.return_value.get_object_to_file(MOCK_KEY, MOCK_FILE_PATH)
+ mock_service.return_value.get_object_to_file.assert_called_once_with(MOCK_KEY, MOCK_FILE_PATH)
@mock.patch(OSS_STRING.format('OSSHook.get_bucket'))
def test_delete_object(self, mock_service):
self.hook.delete_object(MOCK_KEY, MOCK_BUCKET_NAME)
mock_service.assert_called_once_with(MOCK_BUCKET_NAME)
- mock_service.return_value.delete_object(MOCK_KEY)
+ mock_service.return_value.delete_object.assert_called_once_with(MOCK_KEY)
@mock.patch(OSS_STRING.format('OSSHook.get_bucket'))
def test_delete_objects(self, mock_service):
self.hook.delete_objects(MOCK_KEYS, MOCK_BUCKET_NAME)
mock_service.assert_called_once_with(MOCK_BUCKET_NAME)
- mock_service.return_value.batch_delete_objects(MOCK_KEYS)
+ mock_service.return_value.batch_delete_objects.assert_called_once_with(MOCK_KEYS)
@mock.patch(OSS_STRING.format('OSSHook.get_bucket'))
def test_delete_bucket(self, mock_service):
@@ -121,3 +120,48 @@ class TestOSSHook(unittest.TestCase):
self.hook.create_bucket(MOCK_BUCKET_NAME)
mock_service.assert_called_once_with(MOCK_BUCKET_NAME)
mock_service.return_value.create_bucket.assert_called_once_with()
+
+    @mock.patch(OSS_STRING.format('OSSHook.get_bucket'))
+    def test_append_string(self, mock_service):
+        self.hook.append_string(MOCK_BUCKET_NAME, MOCK_CONTENT, MOCK_KEY, 0)
+        mock_service.assert_called_once_with(MOCK_BUCKET_NAME)
+        mock_service.return_value.append_object.assert_called_once_with(MOCK_KEY, 0, MOCK_CONTENT)
+
+    @mock.patch(OSS_STRING.format('OSSHook.get_bucket'))
+    def test_read_key(self, mock_service):
+        # Given
+        mock_service.return_value.get_object.return_value.read.return_value.decode.return_value = MOCK_CONTENT
+
+        # When
+        res = self.hook.read_key(MOCK_BUCKET_NAME, MOCK_KEY)
+
+        # Then
+        assert res == MOCK_CONTENT
+        mock_service.assert_called_once_with(MOCK_BUCKET_NAME)
+        mock_service.return_value.get_object.assert_called_once_with(MOCK_KEY)
+        mock_service.return_value.get_object.return_value.read.assert_called_once_with()
+        mock_service.return_value.get_object.return_value.read.return_value.decode.assert_called_once_with(
+            'utf-8'
+        )
+
+    @mock.patch(OSS_STRING.format('OSSHook.get_bucket'))
+    def test_head_key(self, mock_service):
+        self.hook.head_key(MOCK_BUCKET_NAME, MOCK_KEY)
+        mock_service.assert_called_once_with(MOCK_BUCKET_NAME)
+        mock_service.return_value.head_object.assert_called_once_with(MOCK_KEY)
+
+    @mock.patch(OSS_STRING.format('OSSHook.get_bucket'))
+    def test_key_exists(self, mock_service):
+        # Given
+        mock_service.return_value.object_exists.return_value = True
+
+        # When
+        res = self.hook.key_exist(MOCK_BUCKET_NAME, MOCK_KEY)
+
+        # Then
+        assert res is True
+        mock_service.assert_called_once_with(MOCK_BUCKET_NAME)
+        mock_service.return_value.object_exists.assert_called_once_with(MOCK_KEY)
+
+    def test_get_default_region(self):
+        assert self.hook.get_default_region() == 'mock_region'
diff --git a/tests/providers/alibaba/cloud/utils/oss_mock.py b/tests/providers/alibaba/cloud/log/__init__.py
similarity index 60%
copy from tests/providers/alibaba/cloud/utils/oss_mock.py
copy to tests/providers/alibaba/cloud/log/__init__.py
index 7bbef54..13a8339 100644
--- a/tests/providers/alibaba/cloud/utils/oss_mock.py
+++ b/tests/providers/alibaba/cloud/log/__init__.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,22 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import json
-
-from airflow.models import Connection
-
-OSS_PROJECT_ID_HOOK_UNIT_TEST = 'example-project'
-
-
-def mock_oss_hook_default_project_id(self, oss_conn_id='mock_oss_default', region='mock_region'):
- self.oss_conn_id = oss_conn_id
- self.oss_conn = Connection(
- extra=json.dumps(
- {
- 'auth_type': 'AK',
- 'access_key_id': 'mock_access_key_id',
- 'access_key_secret': 'mock_access_key_secret',
- }
- )
- )
- self.region = region
diff --git a/tests/providers/alibaba/cloud/log/test_oss_task_handler.py b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
new file mode 100644
index 0000000..24eb73b
--- /dev/null
+++ b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import unittest
+from unittest import mock
+from unittest.mock import PropertyMock
+
+from airflow.providers.alibaba.cloud.log.oss_task_handler import OSSTaskHandler
+
+OSS_TASK_HANDLER_STRING = 'airflow.providers.alibaba.cloud.log.oss_task_handler.{}'
+MOCK_OSS_CONN_ID = 'mock_id'
+MOCK_BUCKET_NAME = 'mock_bucket_name'
+MOCK_KEY = 'mock_key'
+MOCK_KEYS = ['mock_key1', 'mock_key2', 'mock_key3']
+MOCK_CONTENT = 'mock_content'
+MOCK_FILE_PATH = 'mock_file_path'
+
+
+class TestOSSTaskHandler(unittest.TestCase):
+    def setUp(self):
+        self.base_log_folder = 'local/airflow/logs/1.log'
+        self.oss_log_folder = f'oss://{MOCK_BUCKET_NAME}/airflow/logs'
+        self.filename_template = '{try_number}.log'
+        self.oss_task_handler = OSSTaskHandler(
+            self.base_log_folder, self.oss_log_folder, self.filename_template
+        )
+
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('conf.get'))
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSHook'))
+    def test_hook(self, mock_service, mock_conf_get):
+        # Given
+        mock_conf_get.return_value = 'oss_default'
+
+        # When
+        assert self.oss_task_handler.hook is mock_service.return_value
+
+        # Then
+        mock_conf_get.assert_called_once_with('logging', 'REMOTE_LOG_CONN_ID')
+        mock_service.assert_called_once_with(oss_conn_id='oss_default')
+
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.hook'), new_callable=PropertyMock)
+    def test_oss_log_exists(self, mock_service):
+        self.oss_task_handler.oss_log_exists('1.log')
+        mock_service.assert_called_once_with()
+        mock_service.return_value.key_exist.assert_called_once_with(MOCK_BUCKET_NAME, 'airflow/logs/1.log')
+
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.hook'), new_callable=PropertyMock)
+    def test_oss_read(self, mock_service):
+        assert self.oss_task_handler.oss_read('1.log') is mock_service.return_value.read_key.return_value
+        mock_service.assert_called_once_with()
+        mock_service.return_value.read_key.assert_called_once_with(MOCK_BUCKET_NAME, 'airflow/logs/1.log')
+
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.oss_log_exists'))
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.hook'), new_callable=PropertyMock)
+    def test_oss_write_into_remote_existing_file_via_append(self, mock_service, mock_oss_log_exists):
+        # Given
+        mock_oss_log_exists.return_value = True
+        mock_service.return_value.head_key.return_value.content_length = 1
+
+        # When
+        self.oss_task_handler.oss_write(MOCK_CONTENT, '1.log', append=True)
+
+        # Then
+        assert mock_service.call_count == 2
+        mock_service.return_value.head_key.assert_called_once_with(MOCK_BUCKET_NAME, 'airflow/logs/1.log')
+        mock_oss_log_exists.assert_called_once_with('1.log')
+        mock_service.return_value.append_string.assert_called_once_with(
+            MOCK_BUCKET_NAME, MOCK_CONTENT, 'airflow/logs/1.log', 1
+        )
+
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.oss_log_exists'))
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.hook'), new_callable=PropertyMock)
+    def test_oss_write_into_remote_non_existing_file_via_append(self, mock_service, mock_oss_log_exists):
+        # Given
+        mock_oss_log_exists.return_value = False
+
+        # When
+        self.oss_task_handler.oss_write(MOCK_CONTENT, '1.log', append=True)
+
+        # Then
+        assert mock_service.call_count == 1
+        mock_service.return_value.head_key.assert_not_called()
+        mock_oss_log_exists.assert_called_once_with('1.log')
+        mock_service.return_value.append_string.assert_called_once_with(
+            MOCK_BUCKET_NAME, MOCK_CONTENT, 'airflow/logs/1.log', 0
+        )
+
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.oss_log_exists'))
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.hook'), new_callable=PropertyMock)
+    def test_oss_write_into_remote_existing_file_not_via_append(self, mock_service, mock_oss_log_exists):
+        # Given
+        mock_oss_log_exists.return_value = True
+
+        # When
+        self.oss_task_handler.oss_write(MOCK_CONTENT, '1.log', append=False)
+
+        # Then
+        assert mock_service.call_count == 1
+        mock_service.return_value.head_key.assert_not_called()
+        mock_oss_log_exists.assert_not_called()
+        mock_service.return_value.append_string.assert_called_once_with(
+            MOCK_BUCKET_NAME, MOCK_CONTENT, 'airflow/logs/1.log', 0
+        )
+
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.oss_log_exists'))
+    @mock.patch(OSS_TASK_HANDLER_STRING.format('OSSTaskHandler.hook'), new_callable=PropertyMock)
+    def test_oss_write_into_remote_non_existing_file_not_via_append(self, mock_service, mock_oss_log_exists):
+        # Given
+        mock_oss_log_exists.return_value = False
+
+        # When
+        self.oss_task_handler.oss_write(MOCK_CONTENT, '1.log', append=False)
+
+        # Then
+        assert mock_service.call_count == 1
+        mock_service.return_value.head_key.assert_not_called()
+        mock_oss_log_exists.assert_not_called()
+        mock_service.return_value.append_string.assert_called_once_with(
+            MOCK_BUCKET_NAME, MOCK_CONTENT, 'airflow/logs/1.log', 0
+        )
diff --git a/tests/providers/alibaba/cloud/utils/oss_mock.py b/tests/providers/alibaba/cloud/utils/oss_mock.py
index 7bbef54..a4e1346 100644
--- a/tests/providers/alibaba/cloud/utils/oss_mock.py
+++ b/tests/providers/alibaba/cloud/utils/oss_mock.py
@@ -30,6 +30,7 @@ def mock_oss_hook_default_project_id(self, oss_conn_id='mock_oss_default', regio
'auth_type': 'AK',
'access_key_id': 'mock_access_key_id',
'access_key_secret': 'mock_access_key_secret',
+ 'region': 'mock_region',
}
)
)