You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/01/23 12:52:29 UTC
[airflow] branch master updated: Upgrade azure blob to v12 (#12188)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 94b1531 Upgrade azure blob to v12 (#12188)
94b1531 is described below
commit 94b1531230231c57610d720e59563ccd98e7ecb2
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Sat Jan 23 13:52:13 2021 +0100
Upgrade azure blob to v12 (#12188)
---
BREEZE.rst | 4 +-
airflow/providers/microsoft/azure/hooks/wasb.py | 262 ++++++++++++++----
breeze | 2 +-
scripts/docker/install_airflow.sh | 6 +
scripts/in_container/_in_container_utils.sh | 3 +-
setup.py | 2 +-
tests/providers/microsoft/azure/hooks/test_wasb.py | 308 +++++++++++++--------
.../microsoft/azure/log/test_wasb_task_handler.py | 19 +-
8 files changed, 416 insertions(+), 190 deletions(-)
diff --git a/BREEZE.rst b/BREEZE.rst
index 298be56..3dd95df 100644
--- a/BREEZE.rst
+++ b/BREEZE.rst
@@ -1705,7 +1705,7 @@ This is the current syntax for `./breeze <./breeze>`_:
wheel,sdist,both
- Default:
+ Default: wheel
-v, --verbose
Show verbose information about executed docker, kind, kubectl, helm commands. Useful for
@@ -2132,7 +2132,7 @@ This is the current syntax for `./breeze <./breeze>`_:
wheel,sdist,both
- Default:
+ Default: wheel
-S, --version-suffix-for-pypi SUFFIX
Adds optional suffix to the version in the generated backport package. It can be used
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py
index af1f850..7758ef0 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -23,18 +23,14 @@ It communicate via the Window Azure Storage Blob protocol. Make sure that a
Airflow connection of type `wasb` exists. Authorization can be done by supplying a
login (=Storage account name) and password (=KEY), or login and SAS token in the extra
field (see connection `wasb_default` for an example).
+
"""
-try:
- from azure.storage.blob import BlockBlobService
-except ImportError:
- # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
- # newer and more stable versions of those libraries. Most of `azure` operators and hooks work
- # fine together with `snowflake` because the deprecated library does not overlap with the
- # new libraries except the `blob` classes. So while `azure` works fine for most cases
- # blob is the only exception
- # Solution to that is being worked on in https://github.com/apache/airflow/pull/12188
- # Once this is merged, this should remove the ImportError handling
- BlockBlobService = None
+
+from typing import Any, Dict, List, Optional
+
+from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
+from azure.identity import ClientSecretCredential
+from azure.storage.blob import BlobClient, BlobServiceClient, ContainerClient, StorageStreamDownloader
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
@@ -52,6 +48,8 @@ class WasbHook(BaseHook):
:param wasb_conn_id: Reference to the wasb connection.
:type wasb_conn_id: str
+ :param public_read: Whether an anonymous public read access should be used. default is False
+ :type public_read: bool
"""
conn_name_attr = 'wasb_conn_id'
@@ -59,18 +57,69 @@ class WasbHook(BaseHook):
conn_type = 'wasb'
hook_name = 'Azure Blob Storage'
- def __init__(self, wasb_conn_id: str = default_conn_name) -> None:
+ def __init__(self, wasb_conn_id: str = default_conn_name, public_read: bool = False) -> None:
super().__init__()
self.conn_id = wasb_conn_id
+ self.public_read = public_read
self.connection = self.get_conn()
- def get_conn(self) -> BlockBlobService:
- """Return the BlockBlobService object."""
+ def get_conn(self) -> BlobServiceClient: # pylint: disable=too-many-return-statements
+ """Return the BlobServiceClient object."""
conn = self.get_connection(self.conn_id)
- service_options = conn.extra_dejson
- return BlockBlobService(account_name=conn.login, account_key=conn.password, **service_options)
+ extra = conn.extra_dejson or {}
+
+ if self.public_read:
+ # Here we use anonymous public read
+ # more info
+ # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
+ return BlobServiceClient(account_url=conn.host)
+
+ if extra.get('connection_string'):
+ # connection_string auth takes priority
+ return BlobServiceClient.from_connection_string(extra.get('connection_string'))
+ if extra.get('shared_access_key'):
+ # using shared access key
+ return BlobServiceClient(account_url=conn.host, credential=extra.get('shared_access_key'))
+ if extra.get('tenant_id'):
+ # use Active Directory auth
+ app_id = conn.login
+ app_secret = conn.password
+ token_credential = ClientSecretCredential(extra.get('tenant_id'), app_id, app_secret)
+ return BlobServiceClient(account_url=conn.host, credential=token_credential)
+ sas_token = extra.get('sas_token')
+ if sas_token and sas_token.startswith('https'):
+ return BlobServiceClient(account_url=extra.get('sas_token'))
+ if sas_token and not sas_token.startswith('https'):
+ return BlobServiceClient(account_url=f"https://{conn.login}.blob.core.windows.net/" + sas_token)
+ else:
+ # Fall back to old auth
+ return BlobServiceClient(
+ account_url=f"https://{conn.login}.blob.core.windows.net/", credential=conn.password, **extra
+ )
+
+ def _get_container_client(self, container_name: str) -> ContainerClient:
+ """
+ Instantiates a container client
+
+ :param container_name: The name of the container
+ :type container_name: str
+ :return: ContainerClient
+ """
+ return self.connection.get_container_client(container_name)
- def check_for_blob(self, container_name, blob_name, **kwargs):
+ def _get_blob_client(self, container_name: str, blob_name: str) -> BlobClient:
+ """
+ Instantiates a blob client
+
+ :param container_name: The name of the blob container
+ :type container_name: str
+ :param blob_name: The name of the blob. This need not exist
+ :type blob_name: str
+ """
+ container_client = self.create_container(container_name)
+ return container_client.get_blob_client(blob_name)
+
+ def check_for_blob(self, container_name: str, blob_name: str, **kwargs) -> bool:
"""
Check if a blob exists on Azure Blob Storage.
@@ -78,15 +127,18 @@ class WasbHook(BaseHook):
:type container_name: str
:param blob_name: Name of the blob.
:type blob_name: str
- :param kwargs: Optional keyword arguments that
- `BlockBlobService.exists()` takes.
+ :param kwargs: Optional keyword arguments that ``BlobClient.get_blob_properties`` takes.
:type kwargs: object
:return: True if the blob exists, False otherwise.
:rtype: bool
"""
- return self.connection.exists(container_name, blob_name, **kwargs)
+ try:
+ self._get_blob_client(container_name, blob_name).get_blob_properties(**kwargs)
+ except ResourceNotFoundError:
+ return False
+ return True
- def check_for_prefix(self, container_name: str, prefix: str, **kwargs) -> bool:
+ def check_for_prefix(self, container_name: str, prefix: str, **kwargs):
"""
Check if a prefix exists on Azure Blob storage.
@@ -94,31 +146,43 @@ class WasbHook(BaseHook):
:type container_name: str
:param prefix: Prefix of the blob.
:type prefix: str
- :param kwargs: Optional keyword arguments that
- `BlockBlobService.list_blobs()` takes.
+ :param kwargs: Optional keyword arguments that ``ContainerClient.walk_blobs`` takes
:type kwargs: object
:return: True if blobs matching the prefix exist, False otherwise.
:rtype: bool
"""
- matches = self.connection.list_blobs(container_name, prefix, num_results=1, **kwargs)
- return len(list(matches)) > 0
+ blobs = self.get_blobs_list(container_name=container_name, prefix=prefix, **kwargs)
+ return len(blobs) > 0
- def get_blobs_list(self, container_name: str, prefix: str, **kwargs) -> list:
+ def get_blobs_list(
+ self,
+ container_name: str,
+ prefix: Optional[str] = None,
+ include: Optional[List[str]] = None,
+ delimiter: Optional[str] = '/',
+ **kwargs,
+ ) -> List:
"""
- Return a list of blobs from path defined in prefix param
+ List blobs in a given container
- :param container_name: Name of the container.
+ :param container_name: The name of the container
:type container_name: str
- :param prefix: Prefix of the blob.
+ :param prefix: Filters the results to return only blobs whose names
+ begin with the specified prefix.
:type prefix: str
- :param kwargs: Optional keyword arguments that
- `BlockBlobService.list_blobs()` takes (num_results, include,
- delimiter, marker, timeout)
- :type kwargs: object
- :return: List of blobs.
- :rtype: list(azure.storage.common.models.ListGenerator)
+ :param include: Specifies one or more additional datasets to include in the
+ response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``,
+ ``copy``, ``deleted``.
+ :type include: List[str]
+ :param delimiter: filters objects based on the delimiter (e.g. '.csv')
+ :type delimiter: str
"""
- return self.connection.list_blobs(container_name, prefix, **kwargs)
+ container = self._get_container_client(container_name)
+ blob_list = []
+ blobs = container.walk_blobs(name_starts_with=prefix, include=include, delimiter=delimiter, **kwargs)
+ for blob in blobs:
+ blob_list.append(blob.name)
+ return blob_list
def load_file(self, file_path: str, container_name: str, blob_name: str, **kwargs) -> None:
"""
@@ -130,12 +194,11 @@ class WasbHook(BaseHook):
:type container_name: str
:param blob_name: Name of the blob.
:type blob_name: str
- :param kwargs: Optional keyword arguments that
- `BlockBlobService.create_blob_from_path()` takes.
+ :param kwargs: Optional keyword arguments that ``BlobClient.upload_blob()`` takes.
:type kwargs: object
"""
- # Reorder the argument order from airflow.providers.amazon.aws.hooks.s3.load_file.
- self.connection.create_blob_from_path(container_name, blob_name, file_path, **kwargs)
+ with open(file_path, 'rb') as data:
+ self.upload(container_name=container_name, blob_name=blob_name, data=data, **kwargs)
def load_string(self, string_data: str, container_name: str, blob_name: str, **kwargs) -> None:
"""
@@ -147,12 +210,11 @@ class WasbHook(BaseHook):
:type container_name: str
:param blob_name: Name of the blob.
:type blob_name: str
- :param kwargs: Optional keyword arguments that
- `BlockBlobService.create_blob_from_text()` takes.
+ :param kwargs: Optional keyword arguments that ``BlobClient.upload_blob()`` takes.
:type kwargs: object
"""
# Reorder the argument order from airflow.providers.amazon.aws.hooks.s3.load_string.
- self.connection.create_blob_from_text(container_name, blob_name, string_data, **kwargs)
+ self.upload(container_name, blob_name, string_data, **kwargs)
def get_file(self, file_path: str, container_name: str, blob_name: str, **kwargs):
"""
@@ -164,11 +226,12 @@ class WasbHook(BaseHook):
:type container_name: str
:param blob_name: Name of the blob.
:type blob_name: str
- :param kwargs: Optional keyword arguments that
- `BlockBlobService.create_blob_from_path()` takes.
+ :param kwargs: Optional keyword arguments that `BlobClient.download_blob()` takes.
:type kwargs: object
"""
- return self.connection.get_blob_to_path(container_name, blob_name, file_path, **kwargs)
+ with open(file_path, "wb") as fileblob:
+ stream = self.download(container_name=container_name, blob_name=blob_name, **kwargs)
+ fileblob.write(stream.readall())
def read_file(self, container_name: str, blob_name: str, **kwargs):
"""
@@ -178,11 +241,100 @@ class WasbHook(BaseHook):
:type container_name: str
:param blob_name: Name of the blob.
:type blob_name: str
- :param kwargs: Optional keyword arguments that
- `BlockBlobService.create_blob_from_path()` takes.
+ :param kwargs: Optional keyword arguments that `BlobClient.download_blob` takes.
:type kwargs: object
"""
- return self.connection.get_blob_to_text(container_name, blob_name, **kwargs).content
+ return self.download(container_name, blob_name, **kwargs).readall()
+
+ def upload(
+ self,
+ container_name,
+ blob_name,
+ data,
+ blob_type: str = 'BlockBlob',
+ length: Optional[int] = None,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ """
+ Creates a new blob from a data source with automatic chunking.
+
+ :param container_name: The name of the container to upload data
+ :type container_name: str
+ :param blob_name: The name of the blob to upload. This need not exist in the container
+ :type blob_name: str
+ :param data: The blob data to upload
+ :param blob_type: The type of the blob. This can be either ``BlockBlob``,
+ ``PageBlob`` or ``AppendBlob``. The default value is ``BlockBlob``.
+ :type blob_type: storage.BlobType
+ :param length: Number of bytes to read from the stream. This is optional,
+ but should be supplied for optimal performance.
+ :type length: int
+ """
+ blob_client = self._get_blob_client(container_name, blob_name)
+ return blob_client.upload_blob(data, blob_type, length=length, **kwargs)
+
+ def download(
+ self, container_name, blob_name, offset: Optional[int] = None, length: Optional[int] = None, **kwargs
+ ) -> StorageStreamDownloader:
+ """
+ Downloads a blob to the StorageStreamDownloader
+
+ :param container_name: The name of the container containing the blob
+ :type container_name: str
+ :param blob_name: The name of the blob to download
+ :type blob_name: str
+ :param offset: Start of byte range to use for downloading a section of the blob.
+ Must be set if length is provided.
+ :type offset: int
+ :param length: Number of bytes to read from the stream.
+ :type length: int
+ """
+ blob_client = self._get_blob_client(container_name, blob_name)
+ return blob_client.download_blob(offset=offset, length=length, **kwargs)
+
+ def create_container(self, container_name: str) -> ContainerClient:
+ """
+ Create container object if not already existing
+
+ :param container_name: The name of the container to create
+ :type container_name: str
+ """
+ container_client = self._get_container_client(container_name)
+ try:
+ self.log.info('Attempting to create container: %s', container_name)
+ container_client.create_container()
+ self.log.info("Created container: %s", container_name)
+ return container_client
+ except ResourceExistsError:
+ self.log.info("Container %s already exists", container_name)
+ return container_client
+
+ def delete_container(self, container_name: str) -> None:
+ """
+ Delete a container object
+
+ :param container_name: The name of the container
+ :type container_name: str
+ """
+ try:
+ self.log.info('Attempting to delete container: %s', container_name)
+ self._get_container_client(container_name).delete_container()
+ self.log.info('Deleted container: %s', container_name)
+ except ResourceNotFoundError:
+ self.log.info('Container %s not found', container_name)
+
+ def delete_blobs(self, container_name: str, *blobs, **kwargs) -> None:
+ """
+ Marks the specified blobs or snapshots for deletion.
+
+ :param container_name: The name of the container containing the blobs
+ :type container_name: str
+ :param blobs: The blobs to delete. This can be a single blob, or multiple values
+ can be supplied, where each value is either the name of the blob (str) or BlobProperties.
+ :type blobs: Union[str, BlobProperties]
+ """
+ self._get_container_client(container_name).delete_blobs(*blobs, **kwargs)
+ self.log.info("Deleted blobs: %s", blobs)
def delete_file(
self,
@@ -204,22 +356,16 @@ class WasbHook(BaseHook):
:param ignore_if_missing: if True, then return success even if the
blob does not exist.
:type ignore_if_missing: bool
- :param kwargs: Optional keyword arguments that
- `BlockBlobService.create_blob_from_path()` takes.
+ :param kwargs: Optional keyword arguments that ``ContainerClient.delete_blobs()`` takes.
:type kwargs: object
"""
if is_prefix:
- blobs_to_delete = [
- blob.name for blob in self.connection.list_blobs(container_name, prefix=blob_name, **kwargs)
- ]
+ blobs_to_delete = self.get_blobs_list(container_name, prefix=blob_name, **kwargs)
elif self.check_for_blob(container_name, blob_name):
blobs_to_delete = [blob_name]
else:
blobs_to_delete = []
-
if not ignore_if_missing and len(blobs_to_delete) == 0:
raise AirflowException(f'Blob(s) not found: {blob_name}')
- for blob_uri in blobs_to_delete:
- self.log.info("Deleting blob: %s", blob_uri)
- self.connection.delete_blob(container_name, blob_uri, delete_snapshots='include', **kwargs)
+ self.delete_blobs(container_name, *blobs_to_delete, **kwargs)
diff --git a/breeze b/breeze
index 8732405..8204184 100755
--- a/breeze
+++ b/breeze
@@ -2411,7 +2411,7 @@ function breeze::flag_packages() {
${FORMATTED_PACKAGE_FORMATS}
- Default: ${_breeze_default_package_formats:=}
+ Default: ${_breeze_default_package_format:=}
"
}
diff --git a/scripts/docker/install_airflow.sh b/scripts/docker/install_airflow.sh
index cf6c8ed..91f3860 100755
--- a/scripts/docker/install_airflow.sh
+++ b/scripts/docker/install_airflow.sh
@@ -63,6 +63,9 @@ function install_airflow() {
pip install ${AIRFLOW_INSTALL_EDITABLE_FLAG} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_INSTALL_VERSION}"
fi
+ # Work around to install azure-storage-blob
+ pip uninstall azure-storage azure-storage-blob azure-storage-file --yes
+ pip install azure-storage-blob azure-storage-file
# make sure correct PIP version is used
pip install ${AIRFLOW_INSTALL_USER_FLAG} --upgrade "pip==${AIRFLOW_PIP_VERSION}"
pip check || ${CONTINUE_ON_PIP_CHECK_FAILURE}
@@ -79,6 +82,9 @@ function install_airflow() {
pip install ${AIRFLOW_INSTALL_USER_FLAG} --upgrade --upgrade-strategy only-if-needed \
${AIRFLOW_INSTALL_EDITABLE_FLAG} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_INSTALL_VERSION}" \
+ # Work around to install azure-storage-blob
+ pip uninstall azure-storage azure-storage-blob azure-storage-file --yes
+ pip install azure-storage-blob azure-storage-file
# make sure correct PIP version is used
pip install ${AIRFLOW_INSTALL_USER_FLAG} --upgrade "pip==${AIRFLOW_PIP_VERSION}"
pip check || ${CONTINUE_ON_PIP_CHECK_FAILURE}
diff --git a/scripts/in_container/_in_container_utils.sh b/scripts/in_container/_in_container_utils.sh
index d18cd33..d291032 100644
--- a/scripts/in_container/_in_container_utils.sh
+++ b/scripts/in_container/_in_container_utils.sh
@@ -310,7 +310,8 @@ function reinstall_azure_storage_blob() {
echo
echo "Reinstalling azure-storage-blob"
echo
- pip install azure-storage-blob --no-deps --force-reinstall
+ pip uninstall azure-storage azure-storage-blob azure-storage-file --yes
+ pip install azure-storage-blob azure-storage-file --no-deps --force-reinstall
group_end
}
diff --git a/setup.py b/setup.py
index 380fde7..ba28546 100644
--- a/setup.py
+++ b/setup.py
@@ -218,7 +218,7 @@ azure = [
'azure-mgmt-containerinstance>=1.5.0,<2.0',
'azure-mgmt-datalake-store>=0.5.0',
'azure-mgmt-resource>=2.2.0',
- 'azure-storage>=0.34.0, <0.37.0',
+ 'azure-storage-file>=2.1.0',
]
cassandra = [
'cassandra-driver>=3.13.0,<3.21.0',
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 16556a1..65c2401 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -20,155 +20,245 @@
import json
import unittest
-from collections import namedtuple
from unittest import mock
import pytest
+from azure.storage.blob import BlobServiceClient
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
from airflow.utils import db
-try:
- from azure.storage.blob import BlockBlobService
-except ImportError:
- # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
- # newer and more stable versions of those libraries. Most of `azure` operators and hooks work
- # fine together with `snowflake` because the deprecated library does not overlap with the
- # new libraries except the `blob` classes. So while `azure` works fine for most cases
- # blob is the only exception
- # Solution to that is being worked on in https://github.com/apache/airflow/pull/12188
- # Once this is merged, we can remove the xfail below and this ImportError handling
- BlockBlobService = None
+# Example connection_string in the semicolon-separated key=value format
+CONN_STRING = (
+ 'DefaultEndpointsProtocol=https;AccountName=testname;AccountKey=wK7BOz;EndpointSuffix=core.windows.net'
+)
+
+ACCESS_KEY_STRING = "AccountName=name;skdkskd"
-@pytest.mark.xfail
class TestWasbHook(unittest.TestCase):
def setUp(self):
db.merge_conn(Connection(conn_id='wasb_test_key', conn_type='wasb', login='login', password='key'))
+ self.connection_type = 'wasb'
+ self.connection_string_id = 'azure_test_connection_string'
+ self.shared_key_conn_id = 'azure_shared_key_test'
+ self.ad_conn_id = 'azure_AD_test'
+ self.sas_conn_id = 'sas_token_id'
+ self.public_read_conn_id = 'pub_read_id'
+
+ db.merge_conn(
+ Connection(
+ conn_id=self.public_read_conn_id,
+ conn_type=self.connection_type,
+ host='https://accountname.blob.core.windows.net',
+ )
+ )
+
+ db.merge_conn(
+ Connection(
+ conn_id=self.connection_string_id,
+ conn_type=self.connection_type,
+ extra=json.dumps({'connection_string': CONN_STRING}),
+ )
+ )
db.merge_conn(
Connection(
- conn_id='wasb_test_sas_token',
- conn_type='wasb',
- login='login',
+ conn_id=self.shared_key_conn_id,
+ conn_type=self.connection_type,
+ host='https://accountname.blob.core.windows.net',
+ extra=json.dumps({'shared_access_key': 'token'}),
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=self.ad_conn_id,
+ conn_type=self.connection_type,
+ extra=json.dumps(
+ {'tenant_id': 'token', 'application_id': 'appID', 'application_secret': "appsecret"}
+ ),
+ )
+ )
+ db.merge_conn(
+ Connection(
+ conn_id=self.sas_conn_id,
+ conn_type=self.connection_type,
extra=json.dumps({'sas_token': 'token'}),
)
)
def test_key(self):
hook = WasbHook(wasb_conn_id='wasb_test_key')
- assert hook.conn_id == 'wasb_test_key'
- assert isinstance(hook.connection, BlockBlobService)
+ self.assertEqual(hook.conn_id, 'wasb_test_key')
+ self.assertIsInstance(hook.connection, BlobServiceClient)
+
+ def test_public_read(self):
+ hook = WasbHook(wasb_conn_id=self.public_read_conn_id, public_read=True)
+ assert isinstance(hook.get_conn(), BlobServiceClient)
+
+ def test_connection_string(self):
+ hook = WasbHook(wasb_conn_id=self.connection_string_id)
+ assert hook.conn_id == self.connection_string_id
+ assert isinstance(hook.get_conn(), BlobServiceClient)
- def test_sas_token(self):
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- assert hook.conn_id == 'wasb_test_sas_token'
- assert isinstance(hook.connection, BlockBlobService)
+ def test_shared_key_connection(self):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ assert isinstance(hook.get_conn(), BlobServiceClient)
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
+ def test_sas_token_connection(self):
+ hook = WasbHook(wasb_conn_id=self.sas_conn_id)
+ assert isinstance(hook.get_conn(), BlobServiceClient)
+
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
def test_check_for_blob(self, mock_service):
- mock_instance = mock_service.return_value
- mock_instance.exists.return_value = True
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- assert hook.check_for_blob('container', 'blob', timeout=3)
- mock_instance.exists.assert_called_once_with('container', 'blob', timeout=3)
-
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_check_for_blob_empty(self, mock_service):
- mock_service.return_value.exists.return_value = False
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- assert not hook.check_for_blob('container', 'blob')
-
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_check_for_prefix(self, mock_service):
- mock_instance = mock_service.return_value
- mock_instance.list_blobs.return_value = iter(['blob_1'])
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ assert hook.check_for_blob(container_name='mycontainer', blob_name='myblob')
+ mock_container_client = mock_service.return_value.get_container_client
+ mock_container_client.assert_called_once_with('mycontainer')
+ mock_container_client.return_value.get_blob_client.assert_called_once_with('myblob')
+ mock_container_client.return_value.get_blob_client.return_value.get_blob_properties.assert_called()
+
+ @mock.patch.object(WasbHook, 'get_blobs_list')
+ def test_check_for_prefix(self, get_blobs_list):
+ get_blobs_list.return_value = ['blobs']
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
assert hook.check_for_prefix('container', 'prefix', timeout=3)
- mock_instance.list_blobs.assert_called_once_with('container', 'prefix', num_results=1, timeout=3)
-
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_check_for_prefix_empty(self, mock_service):
- mock_instance = mock_service.return_value
- mock_instance.list_blobs.return_value = iter([])
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- assert not hook.check_for_prefix('container', 'prefix')
-
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_load_file(self, mock_service):
- mock_instance = mock_service.return_value
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- hook.load_file('path', 'container', 'blob', max_connections=1)
- mock_instance.create_blob_from_path.assert_called_once_with(
- 'container', 'blob', 'path', max_connections=1
+ get_blobs_list.assert_called_once_with(container_name='container', prefix='prefix', timeout=3)
+
+ @mock.patch.object(WasbHook, 'get_blobs_list')
+ def test_check_for_prefix_empty(self, get_blobs_list):
+ get_blobs_list.return_value = []
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ assert not hook.check_for_prefix('container', 'prefix', timeout=3)
+ get_blobs_list.assert_called_once_with(container_name='container', prefix='prefix', timeout=3)
+
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ def test_get_blobs_list(self, mock_service):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook.get_blobs_list(container_name='mycontainer', prefix='my', include=None, delimiter='/')
+ mock_service.return_value.get_container_client.assert_called_once_with('mycontainer')
+ mock_service.return_value.get_container_client.return_value.walk_blobs.assert_called_once_with(
+ name_starts_with='my', include=None, delimiter='/'
)
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_load_string(self, mock_service):
- mock_instance = mock_service.return_value
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+ @mock.patch.object(WasbHook, 'upload')
+ def test_load_file(self, mock_upload):
+ with mock.patch("builtins.open", mock.mock_open(read_data="data")):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook.load_file('path', 'container', 'blob', max_connections=1)
+ mock_upload.assert_called()
+
+ @mock.patch.object(WasbHook, 'upload')
+ def test_load_string(self, mock_upload):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.load_string('big string', 'container', 'blob', max_connections=1)
- mock_instance.create_blob_from_text.assert_called_once_with(
- 'container', 'blob', 'big string', max_connections=1
- )
+ mock_upload.assert_called_once_with('container', 'blob', 'big string', max_connections=1)
+
+ @mock.patch.object(WasbHook, 'download')
+ def test_get_file(self, mock_download):
+ with mock.patch("builtins.open", mock.mock_open(read_data="data")):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook.get_file('path', 'container', 'blob', max_connections=1)
+ mock_download.assert_called_once_with(container_name='container', blob_name='blob', max_connections=1)
+ mock_download.return_value.readall.assert_called()
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_get_file(self, mock_service):
- mock_instance = mock_service.return_value
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- hook.get_file('path', 'container', 'blob', max_connections=1)
- mock_instance.get_blob_to_path.assert_called_once_with('container', 'blob', 'path', max_connections=1)
-
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_read_file(self, mock_service):
- mock_instance = mock_service.return_value
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ @mock.patch.object(WasbHook, 'download')
+ def test_read_file(self, mock_download, mock_service):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.read_file('container', 'blob', max_connections=1)
- mock_instance.get_blob_to_text.assert_called_once_with('container', 'blob', max_connections=1)
+ mock_download.assert_called_once_with('container', 'blob', max_connections=1)
+
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ def test_upload(self, mock_service):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook.upload(
+ container_name='mycontainer', blob_name='myblob', data=b'mydata', blob_type='BlockBlob', length=4
+ )
+ mock_cn_client = mock_service.return_value.get_container_client
+ mock_cn_client.assert_called_once_with('mycontainer')
+ mock_cn_client.return_value.get_blob_client.assert_called_once_with('myblob')
+ mock_cn_client.return_value.get_blob_client.return_value.upload_blob.assert_called_once_with(
+ b'mydata', 'BlockBlob', length=4
+ )
+
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ def test_download(self, mock_service):
+ container_client = mock_service.return_value.get_container_client
+ blob_client = container_client.return_value.get_blob_client
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook.download(container_name='mycontainer', blob_name='myblob', offset=2, length=4)
+ container_client.assert_called_once_with('mycontainer')
+ blob_client.assert_called_once_with('myblob')
+ blob_client.return_value.download_blob.assert_called_once_with(offset=2, length=4)
+
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ def test_get_container_client(self, mock_service):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook._get_container_client('mycontainer')
+ mock_service.return_value.get_container_client.assert_called_once_with('mycontainer')
+
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ def test_get_blob_client(self, mock_service):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook._get_blob_client(container_name='mycontainer', blob_name='myblob')
+ mock_instance = mock_service.return_value.get_container_client
+ mock_instance.assert_called_once_with('mycontainer')
+ mock_instance.return_value.get_blob_client.assert_called_once_with('myblob')
+
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ def test_create_container(self, mock_service):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook.create_container(container_name='mycontainer')
+ mock_instance = mock_service.return_value.get_container_client
+ mock_instance.assert_called_once_with('mycontainer')
+ mock_instance.return_value.create_container.assert_called()
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_delete_single_blob(self, mock_service):
- mock_instance = mock_service.return_value
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ def test_delete_container(self, mock_service):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+ hook.delete_container('mycontainer')
+ mock_service.return_value.get_container_client.assert_called_once_with('mycontainer')
+ mock_service.return_value.get_container_client.return_value.delete_container.assert_called()
+
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ @mock.patch.object(WasbHook, 'delete_blobs')
+ def test_delete_single_blob(self, delete_blobs, mock_service):
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.delete_file('container', 'blob', is_prefix=False)
- mock_instance.delete_blob.assert_called_once_with('container', 'blob', delete_snapshots='include')
-
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_delete_multiple_blobs(self, mock_service):
- mock_instance = mock_service.return_value
- Blob = namedtuple('Blob', ['name'])
- mock_instance.list_blobs.return_value = iter([Blob('blob_prefix/blob1'), Blob('blob_prefix/blob2')])
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+ delete_blobs.assert_called_once_with('container', 'blob')
+
+ @mock.patch.object(WasbHook, 'delete_blobs')
+ @mock.patch.object(WasbHook, 'get_blobs_list')
+ @mock.patch.object(WasbHook, 'check_for_blob')
+ def test_delete_multiple_blobs(self, mock_check, mock_get_blobslist, mock_delete_blobs):
+ mock_check.return_value = False
+ mock_get_blobslist.return_value = ['blob_prefix/blob1', 'blob_prefix/blob2']
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.delete_file('container', 'blob_prefix', is_prefix=True)
- mock_instance.delete_blob.assert_any_call(
- 'container', 'blob_prefix/blob1', delete_snapshots='include'
- )
- mock_instance.delete_blob.assert_any_call(
- 'container', 'blob_prefix/blob2', delete_snapshots='include'
+ mock_get_blobslist.assert_called_once_with('container', prefix='blob_prefix')
+ mock_delete_blobs.assert_any_call(
+ 'container',
+ 'blob_prefix/blob1',
+ 'blob_prefix/blob2',
)
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_delete_nonexisting_blob_fails(self, mock_service):
- mock_instance = mock_service.return_value
- mock_instance.exists.return_value = False
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+ @mock.patch.object(WasbHook, 'get_blobs_list')
+ @mock.patch.object(WasbHook, 'check_for_blob')
+ def test_delete_nonexisting_blob_fails(self, mock_check, mock_getblobs, mock_service):
+ mock_getblobs.return_value = []
+ mock_check.return_value = False
with pytest.raises(Exception) as ctx:
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.delete_file('container', 'nonexisting_blob', is_prefix=False, ignore_if_missing=False)
assert isinstance(ctx.value, AirflowException)
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_delete_multiple_nonexisting_blobs_fails(self, mock_service):
- mock_instance = mock_service.return_value
- mock_instance.list_blobs.return_value = iter([])
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+ @mock.patch.object(WasbHook, 'get_blobs_list')
+ def test_delete_multiple_nonexisting_blobs_fails(self, mock_getblobs):
+ mock_getblobs.return_value = []
with pytest.raises(Exception) as ctx:
+ hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
hook.delete_file('container', 'nonexisting_blob_prefix', is_prefix=True, ignore_if_missing=False)
assert isinstance(ctx.value, AirflowException)
-
- @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
- def test_get_blobs_list(self, mock_service):
- mock_instance = mock_service.return_value
- hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
- hook.get_blobs_list('container', 'prefix', num_results=1, timeout=3)
- mock_instance.list_blobs.assert_called_once_with('container', 'prefix', num_results=1, timeout=3)
diff --git a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
index f859649..8e517e0 100644
--- a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
+++ b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
@@ -19,7 +19,6 @@ import unittest
from datetime import datetime
from unittest import mock
-import pytest
from azure.common import AzureHttpError
from airflow.models import DAG, TaskInstance
@@ -55,7 +54,7 @@ class TestWasbTaskHandler(unittest.TestCase):
self.addCleanup(self.dag.clear)
@conf_vars({('logging', 'remote_log_conn_id'): 'wasb_default'})
- @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService")
+ @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
def test_hook(self, mock_service):
assert isinstance(self.wasb_task_handler.hook, WasbHook)
@@ -86,14 +85,6 @@ class TestWasbTaskHandler(unittest.TestCase):
self.wasb_task_handler.set_context(self.ti)
assert self.wasb_task_handler.upload_on_close
- # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
- # newer and more stable versions of those libraries. Most of `azure` operators and hooks work
- # fine together with `snowflake` because the deprecated library does not overlap with the
- # new libraries except the `blob` classes. So while `azure` works fine for most cases
- # blob is the only exception
- # Solution to that is being worked on in https://github.com/apache/airflow/pull/12188
- # Once this is merged, we can remove the xfail
- @pytest.mark.xfail
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook")
def test_wasb_log_exists(self, mock_hook):
instance = mock_hook.return_value
@@ -103,14 +94,6 @@ class TestWasbTaskHandler(unittest.TestCase):
self.container_name, self.remote_log_location
)
- # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
- # newer and more stable versions of those libraries. Most of `azure` operators and hooks work
- # fine together with `snowflake` because the deprecated library does not overlap with the
- # new libraries except the `blob` classes. So while `azure` works fine for most cases
- # blob is the only exception
- # Solution to that is being worked on in https://github.com/apache/airflow/pull/12188
- # Once this is merged, we can remove the xfail
- @pytest.mark.xfail
@mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook")
def test_wasb_read(self, mock_hook):
mock_hook.return_value.read_file.return_value = 'Log line'