You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by po...@apache.org on 2021/01/23 12:52:29 UTC

[airflow] branch master updated: Upgrade azure blob to v12 (#12188)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 94b1531  Upgrade azure blob to v12 (#12188)
94b1531 is described below

commit 94b1531230231c57610d720e59563ccd98e7ecb2
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Sat Jan 23 13:52:13 2021 +0100

    Upgrade azure blob to v12 (#12188)
---
 BREEZE.rst                                         |   4 +-
 airflow/providers/microsoft/azure/hooks/wasb.py    | 262 ++++++++++++++----
 breeze                                             |   2 +-
 scripts/docker/install_airflow.sh                  |   6 +
 scripts/in_container/_in_container_utils.sh        |   3 +-
 setup.py                                           |   2 +-
 tests/providers/microsoft/azure/hooks/test_wasb.py | 308 +++++++++++++--------
 .../microsoft/azure/log/test_wasb_task_handler.py  |  19 +-
 8 files changed, 416 insertions(+), 190 deletions(-)

diff --git a/BREEZE.rst b/BREEZE.rst
index 298be56..3dd95df 100644
--- a/BREEZE.rst
+++ b/BREEZE.rst
@@ -1705,7 +1705,7 @@ This is the current syntax for  `./breeze <./breeze>`_:
 
                  wheel,sdist,both
 
-          Default: 
+          Default: wheel
 
   -v, --verbose
           Show verbose information about executed docker, kind, kubectl, helm commands. Useful for
@@ -2132,7 +2132,7 @@ This is the current syntax for  `./breeze <./breeze>`_:
 
                  wheel,sdist,both
 
-          Default: 
+          Default: wheel
 
   -S, --version-suffix-for-pypi SUFFIX
           Adds optional suffix to the version in the generated backport package. It can be used
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py
index af1f850..7758ef0 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -23,18 +23,14 @@ It communicate via the Window Azure Storage Blob protocol. Make sure that a
 Airflow connection of type `wasb` exists. Authorization can be done by supplying a
 login (=Storage account name) and password (=KEY), or login and SAS token in the extra
 field (see connection `wasb_default` for an example).
+
 """
-try:
-    from azure.storage.blob import BlockBlobService
-except ImportError:
-    # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
-    # newer and more stable versions of those libraries. Most of `azure` operators and hooks work
-    # fine together with `snowflake` because the deprecated library does not overlap with the
-    # new libraries except the `blob` classes. So while `azure` works fine for most cases
-    # blob is the only exception
-    # Solution to that is being worked on in https://github.com/apache/airflow/pull/12188
-    # Once this is merged, this should remove the ImportError handling
-    BlockBlobService = None
+
+from typing import Any, Dict, List, Optional
+
+from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
+from azure.identity import ClientSecretCredential
+from azure.storage.blob import BlobClient, BlobServiceClient, ContainerClient, StorageStreamDownloader
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
@@ -52,6 +48,8 @@ class WasbHook(BaseHook):
 
     :param wasb_conn_id: Reference to the wasb connection.
     :type wasb_conn_id: str
+    :param public_read: Whether an anonymous public read access should be used. default is False
+    :type public_read: bool
     """
 
     conn_name_attr = 'wasb_conn_id'
@@ -59,18 +57,69 @@ class WasbHook(BaseHook):
     conn_type = 'wasb'
     hook_name = 'Azure Blob Storage'
 
-    def __init__(self, wasb_conn_id: str = default_conn_name) -> None:
+    def __init__(self, wasb_conn_id: str = default_conn_name, public_read: bool = False) -> None:
         super().__init__()
         self.conn_id = wasb_conn_id
+        self.public_read = public_read
         self.connection = self.get_conn()
 
-    def get_conn(self) -> BlockBlobService:
-        """Return the BlockBlobService object."""
+    def get_conn(self) -> BlobServiceClient:  # pylint: disable=too-many-return-statements
+        """Return the BlobServiceClient object."""
         conn = self.get_connection(self.conn_id)
-        service_options = conn.extra_dejson
-        return BlockBlobService(account_name=conn.login, account_key=conn.password, **service_options)
+        extra = conn.extra_dejson or {}
+
+        if self.public_read:
+            # Here we use anonymous public read
+            # more info
+            # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
+            return BlobServiceClient(account_url=conn.host)
+
+        if extra.get('connection_string'):
+            # connection_string auth takes priority
+            return BlobServiceClient.from_connection_string(extra.get('connection_string'))
+        if extra.get('shared_access_key'):
+            # using shared access key
+            return BlobServiceClient(account_url=conn.host, credential=extra.get('shared_access_key'))
+        if extra.get('tenant_id'):
+            # use Active Directory auth
+            app_id = conn.login
+            app_secret = conn.password
+            token_credential = ClientSecretCredential(extra.get('tenant_id'), app_id, app_secret)
+            return BlobServiceClient(account_url=conn.host, credential=token_credential)
+        sas_token = extra.get('sas_token')
+        if sas_token and sas_token.startswith('https'):
+            return BlobServiceClient(account_url=extra.get('sas_token'))
+        if sas_token and not sas_token.startswith('https'):
+            return BlobServiceClient(account_url=f"https://{conn.login}.blob.core.windows.net/" + sas_token)
+        else:
+            # Fall back to old auth
+            return BlobServiceClient(
+                account_url=f"https://{conn.login}.blob.core.windows.net/", credential=conn.password, **extra
+            )
+
+    def _get_container_client(self, container_name: str) -> ContainerClient:
+        """
+        Instantiates a container client
+
+        :param container_name: The name of the container
+        :type container_name: str
+        :return: ContainerClient
+        """
+        return self.connection.get_container_client(container_name)
 
-    def check_for_blob(self, container_name, blob_name, **kwargs):
+    def _get_blob_client(self, container_name: str, blob_name: str) -> BlobClient:
+        """
+        Instantiates a blob client
+
+        :param container_name: The name of the blob container
+        :type container_name: str
+        :param blob_name: The name of the blob. This needs not be existing
+        :type blob_name: str
+        """
+        container_client = self.create_container(container_name)
+        return container_client.get_blob_client(blob_name)
+
+    def check_for_blob(self, container_name: str, blob_name: str, **kwargs) -> bool:
         """
         Check if a blob exists on Azure Blob Storage.
 
@@ -78,15 +127,18 @@ class WasbHook(BaseHook):
         :type container_name: str
         :param blob_name: Name of the blob.
         :type blob_name: str
-        :param kwargs: Optional keyword arguments that
-            `BlockBlobService.exists()` takes.
+        :param kwargs: Optional keyword arguments for ``BlobClient.get_blob_properties`` takes.
         :type kwargs: object
         :return: True if the blob exists, False otherwise.
         :rtype: bool
         """
-        return self.connection.exists(container_name, blob_name, **kwargs)
+        try:
+            self._get_blob_client(container_name, blob_name).get_blob_properties(**kwargs)
+        except ResourceNotFoundError:
+            return False
+        return True
 
-    def check_for_prefix(self, container_name: str, prefix: str, **kwargs) -> bool:
+    def check_for_prefix(self, container_name: str, prefix: str, **kwargs):
         """
         Check if a prefix exists on Azure Blob storage.
 
@@ -94,31 +146,43 @@ class WasbHook(BaseHook):
         :type container_name: str
         :param prefix: Prefix of the blob.
         :type prefix: str
-        :param kwargs: Optional keyword arguments that
-            `BlockBlobService.list_blobs()` takes.
+        :param kwargs: Optional keyword arguments that ``ContainerClient.walk_blobs`` takes
         :type kwargs: object
         :return: True if blobs matching the prefix exist, False otherwise.
         :rtype: bool
         """
-        matches = self.connection.list_blobs(container_name, prefix, num_results=1, **kwargs)
-        return len(list(matches)) > 0
+        blobs = self.get_blobs_list(container_name=container_name, prefix=prefix, **kwargs)
+        return len(blobs) > 0
 
-    def get_blobs_list(self, container_name: str, prefix: str, **kwargs) -> list:
+    def get_blobs_list(
+        self,
+        container_name: str,
+        prefix: Optional[str] = None,
+        include: Optional[List[str]] = None,
+        delimiter: Optional[str] = '/',
+        **kwargs,
+    ) -> List:
         """
-        Return a list of blobs from path defined in prefix param
+        List blobs in a given container
 
-        :param container_name: Name of the container.
+        :param container_name: The name of the container
         :type container_name: str
-        :param prefix: Prefix of the blob.
+        :param prefix: Filters the results to return only blobs whose names
+            begin with the specified prefix.
         :type prefix: str
-        :param kwargs: Optional keyword arguments that
-            `BlockBlobService.list_blobs()` takes (num_results, include,
-            delimiter, marker, timeout)
-        :type kwargs: object
-        :return: List of blobs.
-        :rtype: list(azure.storage.common.models.ListGenerator)
+        :param include: Specifies one or more additional datasets to include in the
+            response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``,
+            ``copy`, ``deleted``.
+        :type include: List[str]
+        :param delimiter: filters objects based on the delimiter (for e.g '.csv')
+        :type delimiter: str
         """
-        return self.connection.list_blobs(container_name, prefix, **kwargs)
+        container = self._get_container_client(container_name)
+        blob_list = []
+        blobs = container.walk_blobs(name_starts_with=prefix, include=include, delimiter=delimiter, **kwargs)
+        for blob in blobs:
+            blob_list.append(blob.name)
+        return blob_list
 
     def load_file(self, file_path: str, container_name: str, blob_name: str, **kwargs) -> None:
         """
@@ -130,12 +194,11 @@ class WasbHook(BaseHook):
         :type container_name: str
         :param blob_name: Name of the blob.
         :type blob_name: str
-        :param kwargs: Optional keyword arguments that
-            `BlockBlobService.create_blob_from_path()` takes.
+        :param kwargs: Optional keyword arguments that ``BlobClient.upload_blob()`` takes.
         :type kwargs: object
         """
-        # Reorder the argument order from airflow.providers.amazon.aws.hooks.s3.load_file.
-        self.connection.create_blob_from_path(container_name, blob_name, file_path, **kwargs)
+        with open(file_path, 'rb') as data:
+            self.upload(container_name=container_name, blob_name=blob_name, data=data, **kwargs)
 
     def load_string(self, string_data: str, container_name: str, blob_name: str, **kwargs) -> None:
         """
@@ -147,12 +210,11 @@ class WasbHook(BaseHook):
         :type container_name: str
         :param blob_name: Name of the blob.
         :type blob_name: str
-        :param kwargs: Optional keyword arguments that
-            `BlockBlobService.create_blob_from_text()` takes.
+        :param kwargs: Optional keyword arguments that ``BlobClient.upload()`` takes.
         :type kwargs: object
         """
         # Reorder the argument order from airflow.providers.amazon.aws.hooks.s3.load_string.
-        self.connection.create_blob_from_text(container_name, blob_name, string_data, **kwargs)
+        self.upload(container_name, blob_name, string_data, **kwargs)
 
     def get_file(self, file_path: str, container_name: str, blob_name: str, **kwargs):
         """
@@ -164,11 +226,12 @@ class WasbHook(BaseHook):
         :type container_name: str
         :param blob_name: Name of the blob.
         :type blob_name: str
-        :param kwargs: Optional keyword arguments that
-            `BlockBlobService.create_blob_from_path()` takes.
+        :param kwargs: Optional keyword arguments that `BlobClient.download_blob()` takes.
         :type kwargs: object
         """
-        return self.connection.get_blob_to_path(container_name, blob_name, file_path, **kwargs)
+        with open(file_path, "wb") as fileblob:
+            stream = self.download(container_name=container_name, blob_name=blob_name, **kwargs)
+            fileblob.write(stream.readall())
 
     def read_file(self, container_name: str, blob_name: str, **kwargs):
         """
@@ -178,11 +241,100 @@ class WasbHook(BaseHook):
         :type container_name: str
         :param blob_name: Name of the blob.
         :type blob_name: str
-        :param kwargs: Optional keyword arguments that
-            `BlockBlobService.create_blob_from_path()` takes.
+        :param kwargs: Optional keyword arguments that `BlobClient.download_blob` takes.
         :type kwargs: object
         """
-        return self.connection.get_blob_to_text(container_name, blob_name, **kwargs).content
+        return self.download(container_name, blob_name, **kwargs).readall()
+
+    def upload(
+        self,
+        container_name,
+        blob_name,
+        data,
+        blob_type: str = 'BlockBlob',
+        length: Optional[int] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """
+        Creates a new blob from a data source with automatic chunking.
+
+        :param container_name: The name of the container to upload data
+        :type container_name: str
+        :param blob_name: The name of the blob to upload. This need not exist in the container
+        :type blob_name: str
+        :param data: The blob data to upload
+        :param blob_type: The type of the blob. This can be either ``BlockBlob``,
+            ``PageBlob`` or ``AppendBlob``. The default value is ``BlockBlob``.
+        :type blob_type: storage.BlobType
+        :param length: Number of bytes to read from the stream. This is optional,
+            but should be supplied for optimal performance.
+        :type length: int
+        """
+        blob_client = self._get_blob_client(container_name, blob_name)
+        return blob_client.upload_blob(data, blob_type, length=length, **kwargs)
+
+    def download(
+        self, container_name, blob_name, offset: Optional[int] = None, length: Optional[int] = None, **kwargs
+    ) -> StorageStreamDownloader:
+        """
+        Downloads a blob to the StorageStreamDownloader
+
+        :param container_name: The name of the container containing the blob
+        :type container_name: str
+        :param blob_name: The name of the blob to download
+        :type blob_name: str
+        :param offset: Start of byte range to use for downloading a section of the blob.
+            Must be set if length is provided.
+        :type offset: int
+        :param length: Number of bytes to read from the stream.
+        :type length: int
+        """
+        blob_client = self._get_blob_client(container_name, blob_name)
+        return blob_client.download_blob(offset=offset, length=length, **kwargs)
+
+    def create_container(self, container_name: str) -> ContainerClient:
+        """
+        Create container object if not already existing
+
+        :param container_name: The name of the container to create
+        :type container_name: str
+        """
+        container_client = self._get_container_client(container_name)
+        try:
+            self.log.info('Attempting to create container: %s', container_name)
+            container_client.create_container()
+            self.log.info("Created container: %s", container_name)
+            return container_client
+        except ResourceExistsError:
+            self.log.info("Container %s already exists", container_name)
+            return container_client
+
+    def delete_container(self, container_name: str) -> None:
+        """
+        Delete a container object
+
+        :param container_name: The name of the container
+        :type container_name: str
+        """
+        try:
+            self.log.info('Attempting to delete container: %s', container_name)
+            self._get_container_client(container_name).delete_container()
+            self.log.info('Deleted container: %s', container_name)
+        except ResourceNotFoundError:
+            self.log.info('Container %s not found', container_name)
+
+    def delete_blobs(self, container_name: str, *blobs, **kwargs) -> None:
+        """
+        Marks the specified blobs or snapshots for deletion.
+
+        :param container_name: The name of the container containing the blobs
+        :type container_name: str
+        :param blobs: The blobs to delete. This can be a single blob, or multiple values
+            can be supplied, where each value is either the name of the blob (str) or BlobProperties.
+        :type blobs: Union[str, BlobProperties]
+        """
+        self._get_container_client(container_name).delete_blobs(*blobs, **kwargs)
+        self.log.info("Deleted blobs: %s", blobs)
 
     def delete_file(
         self,
@@ -204,22 +356,16 @@ class WasbHook(BaseHook):
         :param ignore_if_missing: if True, then return success even if the
             blob does not exist.
         :type ignore_if_missing: bool
-        :param kwargs: Optional keyword arguments that
-            `BlockBlobService.create_blob_from_path()` takes.
+        :param kwargs: Optional keyword arguments that ``ContainerClient.delete_blobs()`` takes.
         :type kwargs: object
         """
         if is_prefix:
-            blobs_to_delete = [
-                blob.name for blob in self.connection.list_blobs(container_name, prefix=blob_name, **kwargs)
-            ]
+            blobs_to_delete = self.get_blobs_list(container_name, prefix=blob_name, **kwargs)
         elif self.check_for_blob(container_name, blob_name):
             blobs_to_delete = [blob_name]
         else:
             blobs_to_delete = []
-
         if not ignore_if_missing and len(blobs_to_delete) == 0:
             raise AirflowException(f'Blob(s) not found: {blob_name}')
 
-        for blob_uri in blobs_to_delete:
-            self.log.info("Deleting blob: %s", blob_uri)
-            self.connection.delete_blob(container_name, blob_uri, delete_snapshots='include', **kwargs)
+        self.delete_blobs(container_name, *blobs_to_delete, **kwargs)
diff --git a/breeze b/breeze
index 8732405..8204184 100755
--- a/breeze
+++ b/breeze
@@ -2411,7 +2411,7 @@ function breeze::flag_packages() {
 
 ${FORMATTED_PACKAGE_FORMATS}
 
-        Default: ${_breeze_default_package_formats:=}
+        Default: ${_breeze_default_package_format:=}
 
 "
 }
diff --git a/scripts/docker/install_airflow.sh b/scripts/docker/install_airflow.sh
index cf6c8ed..91f3860 100755
--- a/scripts/docker/install_airflow.sh
+++ b/scripts/docker/install_airflow.sh
@@ -63,6 +63,9 @@ function install_airflow() {
             pip install ${AIRFLOW_INSTALL_EDITABLE_FLAG} \
                 "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_INSTALL_VERSION}"
         fi
+        # Work around to install azure-storage-blob
+        pip uninstall azure-storage azure-storage-blob azure-storage-file --yes
+        pip install azure-storage-blob azure-storage-file
         # make sure correct PIP version is used
         pip install ${AIRFLOW_INSTALL_USER_FLAG} --upgrade "pip==${AIRFLOW_PIP_VERSION}"
         pip check || ${CONTINUE_ON_PIP_CHECK_FAILURE}
@@ -79,6 +82,9 @@ function install_airflow() {
         pip install ${AIRFLOW_INSTALL_USER_FLAG} --upgrade --upgrade-strategy only-if-needed \
             ${AIRFLOW_INSTALL_EDITABLE_FLAG} \
             "${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_INSTALL_VERSION}" \
+        # Work around to install azure-storage-blob
+        pip uninstall azure-storage azure-storage-blob azure-storage-file --yes
+        pip install azure-storage-blob azure-storage-file
         # make sure correct PIP version is used
         pip install ${AIRFLOW_INSTALL_USER_FLAG} --upgrade "pip==${AIRFLOW_PIP_VERSION}"
         pip check || ${CONTINUE_ON_PIP_CHECK_FAILURE}
diff --git a/scripts/in_container/_in_container_utils.sh b/scripts/in_container/_in_container_utils.sh
index d18cd33..d291032 100644
--- a/scripts/in_container/_in_container_utils.sh
+++ b/scripts/in_container/_in_container_utils.sh
@@ -310,7 +310,8 @@ function reinstall_azure_storage_blob() {
     echo
     echo "Reinstalling azure-storage-blob"
     echo
-    pip install azure-storage-blob --no-deps --force-reinstall
+    pip uninstall azure-storage azure-storage-blob azure-storage-file --yes
+    pip install azure-storage-blob azure-storage-file --no-deps --force-reinstall
     group_end
 }
 
diff --git a/setup.py b/setup.py
index 380fde7..ba28546 100644
--- a/setup.py
+++ b/setup.py
@@ -218,7 +218,7 @@ azure = [
     'azure-mgmt-containerinstance>=1.5.0,<2.0',
     'azure-mgmt-datalake-store>=0.5.0',
     'azure-mgmt-resource>=2.2.0',
-    'azure-storage>=0.34.0, <0.37.0',
+    'azure-storage-file>=2.1.0',
 ]
 cassandra = [
     'cassandra-driver>=3.13.0,<3.21.0',
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index 16556a1..65c2401 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -20,155 +20,245 @@
 
 import json
 import unittest
-from collections import namedtuple
 from unittest import mock
 
 import pytest
+from azure.storage.blob import BlobServiceClient
 
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
 from airflow.utils import db
 
-try:
-    from azure.storage.blob import BlockBlobService
-except ImportError:
-    # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
-    # newer and more stable versions of those libraries. Most of `azure` operators and hooks work
-    # fine together with `snowflake` because the deprecated library does not overlap with the
-    # new libraries except the `blob` classes. So while `azure` works fine for most cases
-    # blob is the only exception
-    # Solution to that is being worked on in https://github.com/apache/airflow/pull/12188
-    # Once this is merged, we can remove the xfail below and this ImportError handling
-    BlockBlobService = None
+# connection_string has a format
+CONN_STRING = (
+    'DefaultEndpointsProtocol=https;AccountName=testname;AccountKey=wK7BOz;EndpointSuffix=core.windows.net'
+)
+
+ACCESS_KEY_STRING = "AccountName=name;skdkskd"
 
 
-@pytest.mark.xfail
 class TestWasbHook(unittest.TestCase):
     def setUp(self):
         db.merge_conn(Connection(conn_id='wasb_test_key', conn_type='wasb', login='login', password='key'))
+        self.connection_type = 'wasb'
+        self.connection_string_id = 'azure_test_connection_string'
+        self.shared_key_conn_id = 'azure_shared_key_test'
+        self.ad_conn_id = 'azure_AD_test'
+        self.sas_conn_id = 'sas_token_id'
+        self.public_read_conn_id = 'pub_read_id'
+
+        db.merge_conn(
+            Connection(
+                conn_id=self.public_read_conn_id,
+                conn_type=self.connection_type,
+                host='https://accountname.blob.core.windows.net',
+            )
+        )
+
+        db.merge_conn(
+            Connection(
+                conn_id=self.connection_string_id,
+                conn_type=self.connection_type,
+                extra=json.dumps({'connection_string': CONN_STRING}),
+            )
+        )
         db.merge_conn(
             Connection(
-                conn_id='wasb_test_sas_token',
-                conn_type='wasb',
-                login='login',
+                conn_id=self.shared_key_conn_id,
+                conn_type=self.connection_type,
+                host='https://accountname.blob.core.windows.net',
+                extra=json.dumps({'shared_access_key': 'token'}),
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=self.ad_conn_id,
+                conn_type=self.connection_type,
+                extra=json.dumps(
+                    {'tenant_id': 'token', 'application_id': 'appID', 'application_secret': "appsecret"}
+                ),
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=self.sas_conn_id,
+                conn_type=self.connection_type,
                 extra=json.dumps({'sas_token': 'token'}),
             )
         )
 
     def test_key(self):
         hook = WasbHook(wasb_conn_id='wasb_test_key')
-        assert hook.conn_id == 'wasb_test_key'
-        assert isinstance(hook.connection, BlockBlobService)
+        self.assertEqual(hook.conn_id, 'wasb_test_key')
+        self.assertIsInstance(hook.connection, BlobServiceClient)
+
+    def test_public_read(self):
+        hook = WasbHook(wasb_conn_id=self.public_read_conn_id, public_read=True)
+        assert isinstance(hook.get_conn(), BlobServiceClient)
+
+    def test_connection_string(self):
+        hook = WasbHook(wasb_conn_id=self.connection_string_id)
+        assert hook.conn_id == self.connection_string_id
+        assert isinstance(hook.get_conn(), BlobServiceClient)
 
-    def test_sas_token(self):
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
-        assert hook.conn_id == 'wasb_test_sas_token'
-        assert isinstance(hook.connection, BlockBlobService)
+    def test_shared_key_connection(self):
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        assert isinstance(hook.get_conn(), BlobServiceClient)
 
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
+    def test_sas_token_connection(self):
+        hook = WasbHook(wasb_conn_id=self.sas_conn_id)
+        assert isinstance(hook.get_conn(), BlobServiceClient)
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
     def test_check_for_blob(self, mock_service):
-        mock_instance = mock_service.return_value
-        mock_instance.exists.return_value = True
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
-        assert hook.check_for_blob('container', 'blob', timeout=3)
-        mock_instance.exists.assert_called_once_with('container', 'blob', timeout=3)
-
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_check_for_blob_empty(self, mock_service):
-        mock_service.return_value.exists.return_value = False
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
-        assert not hook.check_for_blob('container', 'blob')
-
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_check_for_prefix(self, mock_service):
-        mock_instance = mock_service.return_value
-        mock_instance.list_blobs.return_value = iter(['blob_1'])
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        assert hook.check_for_blob(container_name='mycontainer', blob_name='myblob')
+        mock_container_client = mock_service.return_value.get_container_client
+        mock_container_client.assert_called_once_with('mycontainer')
+        mock_container_client.return_value.get_blob_client.assert_called_once_with('myblob')
+        mock_container_client.return_value.get_blob_client.return_value.get_blob_properties.assert_called()
+
+    @mock.patch.object(WasbHook, 'get_blobs_list')
+    def test_check_for_prefix(self, get_blobs_list):
+        get_blobs_list.return_value = ['blobs']
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
         assert hook.check_for_prefix('container', 'prefix', timeout=3)
-        mock_instance.list_blobs.assert_called_once_with('container', 'prefix', num_results=1, timeout=3)
-
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_check_for_prefix_empty(self, mock_service):
-        mock_instance = mock_service.return_value
-        mock_instance.list_blobs.return_value = iter([])
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
-        assert not hook.check_for_prefix('container', 'prefix')
-
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_load_file(self, mock_service):
-        mock_instance = mock_service.return_value
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
-        hook.load_file('path', 'container', 'blob', max_connections=1)
-        mock_instance.create_blob_from_path.assert_called_once_with(
-            'container', 'blob', 'path', max_connections=1
+        get_blobs_list.assert_called_once_with(container_name='container', prefix='prefix', timeout=3)
+
+    @mock.patch.object(WasbHook, 'get_blobs_list')
+    def test_check_for_prefix_empty(self, get_blobs_list):
+        get_blobs_list.return_value = []
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        assert not hook.check_for_prefix('container', 'prefix', timeout=3)
+        get_blobs_list.assert_called_once_with(container_name='container', prefix='prefix', timeout=3)
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    def test_get_blobs_list(self, mock_service):
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        hook.get_blobs_list(container_name='mycontainer', prefix='my', include=None, delimiter='/')
+        mock_service.return_value.get_container_client.assert_called_once_with('mycontainer')
+        mock_service.return_value.get_container_client.return_value.walk_blobs.assert_called_once_with(
+            name_starts_with='my', include=None, delimiter='/'
         )
 
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_load_string(self, mock_service):
-        mock_instance = mock_service.return_value
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+    @mock.patch.object(WasbHook, 'upload')
+    def test_load_file(self, mock_upload):
+        with mock.patch("builtins.open", mock.mock_open(read_data="data")):
+            hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+            hook.load_file('path', 'container', 'blob', max_connections=1)
+        mock_upload.assert_called()
+
+    @mock.patch.object(WasbHook, 'upload')
+    def test_load_string(self, mock_upload):  # string and extra kwargs are passed straight through to upload()
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
         hook.load_string('big string', 'container', 'blob', max_connections=1)
-        mock_instance.create_blob_from_text.assert_called_once_with(
-            'container', 'blob', 'big string', max_connections=1
-        )
+        mock_upload.assert_called_once_with('container', 'blob', 'big string', max_connections=1)
+
+    @mock.patch.object(WasbHook, 'download')
+    def test_get_file(self, mock_download):  # download() result is consumed via readall(); open() is mocked
+        with mock.patch("builtins.open", mock.mock_open(read_data="data")):
+            hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+            hook.get_file('path', 'container', 'blob', max_connections=1)
+        mock_download.assert_called_once_with(container_name='container', blob_name='blob', max_connections=1)
+        mock_download.return_value.readall.assert_called()
 
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_get_file(self, mock_service):
-        mock_instance = mock_service.return_value
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
-        hook.get_file('path', 'container', 'blob', max_connections=1)
-        mock_instance.get_blob_to_path.assert_called_once_with('container', 'blob', 'path', max_connections=1)
-
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_read_file(self, mock_service):
-        mock_instance = mock_service.return_value
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    @mock.patch.object(WasbHook, 'download')
+    def test_read_file(self, mock_download, mock_service):  # patch mocks are injected bottom-up
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
         hook.read_file('container', 'blob', max_connections=1)
-        mock_instance.get_blob_to_text.assert_called_once_with('container', 'blob', max_connections=1)
+        mock_download.assert_called_once_with('container', 'blob', max_connections=1)  # read_file delegates to download()
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    def test_upload(self, mock_service):  # container -> blob client, then data/blob_type/kwargs go to upload_blob
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        hook.upload(
+            container_name='mycontainer', blob_name='myblob', data=b'mydata', blob_type='BlockBlob', length=4
+        )
+        mock_cn_client = mock_service.return_value.get_container_client
+        mock_cn_client.assert_called_once_with('mycontainer')
+        mock_cn_client.return_value.get_blob_client.assert_called_once_with('myblob')
+        mock_cn_client.return_value.get_blob_client.return_value.upload_blob.assert_called_once_with(
+            b'mydata', 'BlockBlob', length=4
+        )
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    def test_download(self, mock_service):  # offset/length are forwarded to BlobClient.download_blob
+        container_client = mock_service.return_value.get_container_client
+        blob_client = container_client.return_value.get_blob_client
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        hook.download(container_name='mycontainer', blob_name='myblob', offset=2, length=4)
+        container_client.assert_called_once_with('mycontainer')
+        blob_client.assert_called_once_with('myblob')
+        blob_client.return_value.download_blob.assert_called_once_with(offset=2, length=4)
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    def test_get_container_client(self, mock_service):  # asks the service client for the named container
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        hook._get_container_client('mycontainer')
+        mock_service.return_value.get_container_client.assert_called_once_with('mycontainer')
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    def test_get_blob_client(self, mock_service):  # service -> container client -> blob client
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        hook._get_blob_client(container_name='mycontainer', blob_name='myblob')
+        mock_instance = mock_service.return_value.get_container_client
+        mock_instance.assert_called_once_with('mycontainer')
+        mock_instance.return_value.get_blob_client.assert_called_once_with('myblob')
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    def test_create_container(self, mock_service):  # goes through ContainerClient.create_container
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        hook.create_container(container_name='mycontainer')
+        mock_instance = mock_service.return_value.get_container_client
+        mock_instance.assert_called_once_with('mycontainer')
+        mock_instance.return_value.create_container.assert_called()
 
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_delete_single_blob(self, mock_service):
-        mock_instance = mock_service.return_value
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    def test_delete_container(self, mock_service):  # goes through ContainerClient.delete_container
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
+        hook.delete_container('mycontainer')
+        mock_service.return_value.get_container_client.assert_called_once_with('mycontainer')
+        mock_service.return_value.get_container_client.return_value.delete_container.assert_called()
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    @mock.patch.object(WasbHook, 'delete_blobs')
+    def test_delete_single_blob(self, delete_blobs, mock_service):  # one existing blob -> one delete_blobs call
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)  # NOTE(review): check_for_blob unpatched; relies on mock
         hook.delete_file('container', 'blob', is_prefix=False)
-        mock_instance.delete_blob.assert_called_once_with('container', 'blob', delete_snapshots='include')
-
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_delete_multiple_blobs(self, mock_service):
-        mock_instance = mock_service.return_value
-        Blob = namedtuple('Blob', ['name'])
-        mock_instance.list_blobs.return_value = iter([Blob('blob_prefix/blob1'), Blob('blob_prefix/blob2')])
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+        delete_blobs.assert_called_once_with('container', 'blob')
+
+    @mock.patch.object(WasbHook, 'delete_blobs')
+    @mock.patch.object(WasbHook, 'get_blobs_list')
+    @mock.patch.object(WasbHook, 'check_for_blob')
+    def test_delete_multiple_blobs(self, mock_check, mock_get_blobslist, mock_delete_blobs):
+        mock_check.return_value = False
+        mock_get_blobslist.return_value = ['blob_prefix/blob1', 'blob_prefix/blob2']
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)
         hook.delete_file('container', 'blob_prefix', is_prefix=True)
-        mock_instance.delete_blob.assert_any_call(
-            'container', 'blob_prefix/blob1', delete_snapshots='include'
-        )
-        mock_instance.delete_blob.assert_any_call(
-            'container', 'blob_prefix/blob2', delete_snapshots='include'
+        mock_get_blobslist.assert_called_once_with('container', prefix='blob_prefix')
+        mock_delete_blobs.assert_called_once_with(  # all prefix matches deleted in one batched call
+            'container',
+            'blob_prefix/blob1',
+            'blob_prefix/blob2',
         )
 
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_delete_nonexisting_blob_fails(self, mock_service):
-        mock_instance = mock_service.return_value
-        mock_instance.exists.return_value = False
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    @mock.patch.object(WasbHook, 'get_blobs_list')
+    @mock.patch.object(WasbHook, 'check_for_blob')
+    def test_delete_nonexisting_blob_fails(self, mock_check, mock_getblobs, mock_service):
+        mock_getblobs.return_value = []
+        mock_check.return_value = False
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)  # setup errors must not satisfy raises
         with pytest.raises(Exception) as ctx:
             hook.delete_file('container', 'nonexisting_blob', is_prefix=False, ignore_if_missing=False)
         assert isinstance(ctx.value, AirflowException)
 
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_delete_multiple_nonexisting_blobs_fails(self, mock_service):
-        mock_instance = mock_service.return_value
-        mock_instance.list_blobs.return_value = iter([])
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
+    @mock.patch.object(WasbHook, 'get_blobs_list')
+    def test_delete_multiple_nonexisting_blobs_fails(self, mock_getblobs):
+        mock_getblobs.return_value = []
+        hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)  # setup errors must not satisfy raises
         with pytest.raises(Exception) as ctx:
             hook.delete_file('container', 'nonexisting_blob_prefix', is_prefix=True, ignore_if_missing=False)
         assert isinstance(ctx.value, AirflowException)
-
-    @mock.patch('airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService', autospec=True)
-    def test_get_blobs_list(self, mock_service):
-        mock_instance = mock_service.return_value
-        hook = WasbHook(wasb_conn_id='wasb_test_sas_token')
-        hook.get_blobs_list('container', 'prefix', num_results=1, timeout=3)
-        mock_instance.list_blobs.assert_called_once_with('container', 'prefix', num_results=1, timeout=3)
diff --git a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
index f859649..8e517e0 100644
--- a/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
+++ b/tests/providers/microsoft/azure/log/test_wasb_task_handler.py
@@ -19,7 +19,6 @@ import unittest
 from datetime import datetime
 from unittest import mock
 
-import pytest
 from azure.common import AzureHttpError
 
 from airflow.models import DAG, TaskInstance
@@ -55,7 +54,7 @@ class TestWasbTaskHandler(unittest.TestCase):
         self.addCleanup(self.dag.clear)
 
     @conf_vars({('logging', 'remote_log_conn_id'): 'wasb_default'})
-    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlockBlobService")
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")  # presumably avoids a real client
     def test_hook(self, mock_service):
         assert isinstance(self.wasb_task_handler.hook, WasbHook)
 
@@ -86,14 +85,6 @@ class TestWasbTaskHandler(unittest.TestCase):
         self.wasb_task_handler.set_context(self.ti)
         assert self.wasb_task_handler.upload_on_close
 
-    # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
-    # newer and more stable versions of those libraries. Most of `azure` operators and hooks work
-    # fine together with `snowflake` because the deprecated library does not overlap with the
-    # new libraries except the `blob` classes. So while `azure` works fine for most cases
-    # blob is the only exception
-    # Solution to that is being worked on in https://github.com/apache/airflow/pull/12188
-    # Once this is merged, we can remove the xfail
-    @pytest.mark.xfail
     @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook")
     def test_wasb_log_exists(self, mock_hook):
         instance = mock_hook.return_value
@@ -103,14 +94,6 @@ class TestWasbTaskHandler(unittest.TestCase):
             self.container_name, self.remote_log_location
         )
 
-    # The `azure` provider uses legacy `azure-storage` library, where `snowflake` uses the
-    # newer and more stable versions of those libraries. Most of `azure` operators and hooks work
-    # fine together with `snowflake` because the deprecated library does not overlap with the
-    # new libraries except the `blob` classes. So while `azure` works fine for most cases
-    # blob is the only exception
-    # Solution to that is being worked on in https://github.com/apache/airflow/pull/12188
-    # Once this is merged, we can remove the xfail
-    @pytest.mark.xfail
     @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.WasbHook")
     def test_wasb_read(self, mock_hook):
         mock_hook.return_value.read_file.return_value = 'Log line'