You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@airflow.apache.org by po...@apache.org on 2022/01/09 21:31:29 UTC
[airflow] branch main updated: Fix mypy in providers/aws/hooks (#20353)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 341b461 Fix mypy in providers/aws/hooks (#20353)
341b461 is described below
commit 341b461e4fbd9ae5961ef9448c8f08e1686ee5e4
Author: Kanthi <su...@gmail.com>
AuthorDate: Sun Jan 9 16:30:59 2022 -0500
Fix mypy in providers/aws/hooks (#20353)
---
airflow/providers/amazon/aws/hooks/athena.py | 2 ++
airflow/providers/amazon/aws/hooks/base_aws.py | 23 ++++++++++++-----------
airflow/providers/amazon/aws/hooks/ec2.py | 9 +++++----
airflow/providers/amazon/aws/hooks/eks.py | 5 ++++-
airflow/providers/amazon/aws/hooks/emr.py | 2 +-
5 files changed, 24 insertions(+), 17 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index c7a91a5..39ebe62 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -247,6 +247,8 @@ class AthenaHook(AwsBaseHook):
except KeyError:
self.log.error("Error retrieving OutputLocation")
raise
+ else:
+ raise
else:
raise ValueError("Invalid Query execution id")
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 8f9db32..2d004e5 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -61,6 +61,7 @@ class _SessionFactory(LoggingMixin):
self.region_name = region_name
self.config = config
self.extra_config = self.conn.extra_dejson
+
self.basic_session: Optional[boto3.session.Session] = None
self.role_arn: Optional[str] = None
@@ -129,30 +130,30 @@ class _SessionFactory(LoggingMixin):
)
session = botocore.session.get_session()
session._credentials = credentials
+
if self.basic_session is None:
raise RuntimeError("The basic session should be created here!")
+
region_name = self.basic_session.region_name
session.set_config_variable("region", region_name)
+
return boto3.session.Session(botocore_session=session, **session_kwargs)
def _refresh_credentials(self) -> Dict[str, Any]:
self.log.info('Refreshing credentials')
assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
sts_session = self.basic_session
+
+ if sts_session is None:
+ raise RuntimeError(
+ "Session should be initialized when refresh credentials with assume_role is used!"
+ )
+
+ sts_client = sts_session.client("sts", config=self.config)
+
if assume_role_method == 'assume_role':
- if sts_session is None:
- raise RuntimeError(
- "Session should be initialized when refresh credentials with assume_role is used!"
- )
- sts_client = sts_session.client("sts", config=self.config)
sts_response = self._assume_role(sts_client=sts_client)
elif assume_role_method == 'assume_role_with_saml':
- if sts_session is None:
- raise RuntimeError(
- "Session should be initialized when refresh "
- "credentials with assume_role_with_saml is used!"
- )
- sts_client = sts_session.client("sts", config=self.config)
sts_response = self._assume_role_with_saml(sts_client=sts_client)
else:
raise NotImplementedError(f'assume_role_method={assume_role_method} not expected')
diff --git a/airflow/providers/amazon/aws/hooks/ec2.py b/airflow/providers/amazon/aws/hooks/ec2.py
index 5dba964..9d97292 100644
--- a/airflow/providers/amazon/aws/hooks/ec2.py
+++ b/airflow/providers/amazon/aws/hooks/ec2.py
@@ -19,6 +19,7 @@
import functools
import time
+from typing import List, Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -69,7 +70,7 @@ class EC2Hook(AwsBaseHook):
super().__init__(*args, **kwargs)
- def get_instance(self, instance_id: str, filters: list = None):
+ def get_instance(self, instance_id: str, filters: Optional[List] = None):
"""
Get EC2 instance by id and return it.
@@ -122,7 +123,7 @@ class EC2Hook(AwsBaseHook):
return self.conn.terminate_instances(InstanceIds=instance_ids)
@only_client_type
- def describe_instances(self, filters: list = None, instance_ids: list = None):
+ def describe_instances(self, filters: Optional[List] = None, instance_ids: Optional[List] = None):
"""
Describe EC2 instances, optionally applying filters and selective instance ids
@@ -139,7 +140,7 @@ class EC2Hook(AwsBaseHook):
return self.conn.describe_instances(Filters=filters, InstanceIds=instance_ids)
@only_client_type
- def get_instances(self, filters: list = None, instance_ids: list = None) -> list:
+ def get_instances(self, filters: Optional[List] = None, instance_ids: Optional[List] = None) -> list:
"""
Get list of instance details, optionally applying filters and selective instance ids
@@ -154,7 +155,7 @@ class EC2Hook(AwsBaseHook):
]
@only_client_type
- def get_instance_ids(self, filters: list = None) -> list:
+ def get_instance_ids(self, filters: Optional[List] = None) -> list:
"""
Get list of instance ids, optionally applying filters to fetch selective instances
diff --git a/airflow/providers/amazon/aws/hooks/eks.py b/airflow/providers/amazon/aws/hooks/eks.py
index f34e98d..63f934e 100644
--- a/airflow/providers/amazon/aws/hooks/eks.py
+++ b/airflow/providers/amazon/aws/hooks/eks.py
@@ -366,6 +366,7 @@ class EksHook(AwsBaseHook):
except ClientError as ex:
if ex.response.get("Error").get("Code") == "ResourceNotFoundException":
return ClusterStates.NONEXISTENT
+ raise
def get_fargate_profile_state(self, clusterName: str, fargateProfileName: str) -> FargateProfileStates:
"""
@@ -392,6 +393,7 @@ class EksHook(AwsBaseHook):
except ClientError as ex:
if ex.response.get("Error").get("Code") == "ResourceNotFoundException":
return FargateProfileStates.NONEXISTENT
+ raise
def get_nodegroup_state(self, clusterName: str, nodegroupName: str) -> NodegroupStates:
"""
@@ -416,6 +418,7 @@ class EksHook(AwsBaseHook):
except ClientError as ex:
if ex.response.get("Error").get("Code") == "ResourceNotFoundException":
return NodegroupStates.NONEXISTENT
+ raise
def list_clusters(
self,
@@ -493,7 +496,7 @@ class EksHook(AwsBaseHook):
:return: A List of the combined results of the provided API call.
:rtype: List
"""
- name_collection = []
+ name_collection: List = []
token = DEFAULT_PAGINATION_TOKEN
while token is not None:
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index bfec7ec..22e3d19 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -118,7 +118,7 @@ class EmrContainerHook(AwsBaseHook):
)
SUCCESS_STATES = ("COMPLETED",)
- def __init__(self, *args: Any, virtual_cluster_id: str = None, **kwargs: Any) -> None:
+ def __init__(self, *args: Any, virtual_cluster_id: Optional[str] = None, **kwargs: Any) -> None:
super().__init__(client_type="emr-containers", *args, **kwargs) # type: ignore
self.virtual_cluster_id = virtual_cluster_id