You are viewing a plain-text version of this content. (The canonical link to the original was not preserved in this copy.)
Posted to commits@airflow.apache.org by po...@apache.org on 2022/01/09 21:31:29 UTC

[airflow] branch main updated: Fix mypy in providers/aws/hooks (#20353)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 341b461  Fix mypy in providers/aws/hooks (#20353)
341b461 is described below

commit 341b461e4fbd9ae5961ef9448c8f08e1686ee5e4
Author: Kanthi <su...@gmail.com>
AuthorDate: Sun Jan 9 16:30:59 2022 -0500

    Fix mypy in providers/aws/hooks (#20353)
---
 airflow/providers/amazon/aws/hooks/athena.py   |  2 ++
 airflow/providers/amazon/aws/hooks/base_aws.py | 23 ++++++++++++-----------
 airflow/providers/amazon/aws/hooks/ec2.py      |  9 +++++----
 airflow/providers/amazon/aws/hooks/eks.py      |  5 ++++-
 airflow/providers/amazon/aws/hooks/emr.py      |  2 +-
 5 files changed, 24 insertions(+), 17 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index c7a91a5..39ebe62 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -247,6 +247,8 @@ class AthenaHook(AwsBaseHook):
                 except KeyError:
                     self.log.error("Error retrieving OutputLocation")
                     raise
+            else:
+                raise
         else:
             raise ValueError("Invalid Query execution id")
 
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 8f9db32..2d004e5 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -61,6 +61,7 @@ class _SessionFactory(LoggingMixin):
         self.region_name = region_name
         self.config = config
         self.extra_config = self.conn.extra_dejson
+
         self.basic_session: Optional[boto3.session.Session] = None
         self.role_arn: Optional[str] = None
 
@@ -129,30 +130,30 @@ class _SessionFactory(LoggingMixin):
             )
         session = botocore.session.get_session()
         session._credentials = credentials
+
         if self.basic_session is None:
             raise RuntimeError("The basic session should be created here!")
+
         region_name = self.basic_session.region_name
         session.set_config_variable("region", region_name)
+
         return boto3.session.Session(botocore_session=session, **session_kwargs)
 
     def _refresh_credentials(self) -> Dict[str, Any]:
         self.log.info('Refreshing credentials')
         assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
         sts_session = self.basic_session
+
+        if sts_session is None:
+            raise RuntimeError(
+                "Session should be initialized when refresh credentials with assume_role is used!"
+            )
+
+        sts_client = sts_session.client("sts", config=self.config)
+
         if assume_role_method == 'assume_role':
-            if sts_session is None:
-                raise RuntimeError(
-                    "Session should be initialized when refresh credentials with assume_role is used!"
-                )
-            sts_client = sts_session.client("sts", config=self.config)
             sts_response = self._assume_role(sts_client=sts_client)
         elif assume_role_method == 'assume_role_with_saml':
-            if sts_session is None:
-                raise RuntimeError(
-                    "Session should be initialized when refresh "
-                    "credentials with assume_role_with_saml is used!"
-                )
-            sts_client = sts_session.client("sts", config=self.config)
             sts_response = self._assume_role_with_saml(sts_client=sts_client)
         else:
             raise NotImplementedError(f'assume_role_method={assume_role_method} not expected')
diff --git a/airflow/providers/amazon/aws/hooks/ec2.py b/airflow/providers/amazon/aws/hooks/ec2.py
index 5dba964..9d97292 100644
--- a/airflow/providers/amazon/aws/hooks/ec2.py
+++ b/airflow/providers/amazon/aws/hooks/ec2.py
@@ -19,6 +19,7 @@
 
 import functools
 import time
+from typing import List, Optional
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -69,7 +70,7 @@ class EC2Hook(AwsBaseHook):
 
         super().__init__(*args, **kwargs)
 
-    def get_instance(self, instance_id: str, filters: list = None):
+    def get_instance(self, instance_id: str, filters: Optional[List] = None):
         """
         Get EC2 instance by id and return it.
 
@@ -122,7 +123,7 @@ class EC2Hook(AwsBaseHook):
         return self.conn.terminate_instances(InstanceIds=instance_ids)
 
     @only_client_type
-    def describe_instances(self, filters: list = None, instance_ids: list = None):
+    def describe_instances(self, filters: Optional[List] = None, instance_ids: Optional[List] = None):
         """
         Describe EC2 instances, optionally applying filters and selective instance ids
 
@@ -139,7 +140,7 @@ class EC2Hook(AwsBaseHook):
         return self.conn.describe_instances(Filters=filters, InstanceIds=instance_ids)
 
     @only_client_type
-    def get_instances(self, filters: list = None, instance_ids: list = None) -> list:
+    def get_instances(self, filters: Optional[List] = None, instance_ids: Optional[List] = None) -> list:
         """
         Get list of instance details, optionally applying filters and selective instance ids
 
@@ -154,7 +155,7 @@ class EC2Hook(AwsBaseHook):
         ]
 
     @only_client_type
-    def get_instance_ids(self, filters: list = None) -> list:
+    def get_instance_ids(self, filters: Optional[List] = None) -> list:
         """
         Get list of instance ids, optionally applying filters to fetch selective instances
 
diff --git a/airflow/providers/amazon/aws/hooks/eks.py b/airflow/providers/amazon/aws/hooks/eks.py
index f34e98d..63f934e 100644
--- a/airflow/providers/amazon/aws/hooks/eks.py
+++ b/airflow/providers/amazon/aws/hooks/eks.py
@@ -366,6 +366,7 @@ class EksHook(AwsBaseHook):
         except ClientError as ex:
             if ex.response.get("Error").get("Code") == "ResourceNotFoundException":
                 return ClusterStates.NONEXISTENT
+            raise
 
     def get_fargate_profile_state(self, clusterName: str, fargateProfileName: str) -> FargateProfileStates:
         """
@@ -392,6 +393,7 @@ class EksHook(AwsBaseHook):
         except ClientError as ex:
             if ex.response.get("Error").get("Code") == "ResourceNotFoundException":
                 return FargateProfileStates.NONEXISTENT
+            raise
 
     def get_nodegroup_state(self, clusterName: str, nodegroupName: str) -> NodegroupStates:
         """
@@ -416,6 +418,7 @@ class EksHook(AwsBaseHook):
         except ClientError as ex:
             if ex.response.get("Error").get("Code") == "ResourceNotFoundException":
                 return NodegroupStates.NONEXISTENT
+            raise
 
     def list_clusters(
         self,
@@ -493,7 +496,7 @@ class EksHook(AwsBaseHook):
         :return: A List of the combined results of the provided API call.
         :rtype: List
         """
-        name_collection = []
+        name_collection: List = []
         token = DEFAULT_PAGINATION_TOKEN
 
         while token is not None:
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index bfec7ec..22e3d19 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -118,7 +118,7 @@ class EmrContainerHook(AwsBaseHook):
     )
     SUCCESS_STATES = ("COMPLETED",)
 
-    def __init__(self, *args: Any, virtual_cluster_id: str = None, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, virtual_cluster_id: Optional[str] = None, **kwargs: Any) -> None:
         super().__init__(client_type="emr-containers", *args, **kwargs)  # type: ignore
         self.virtual_cluster_id = virtual_cluster_id