You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/12/20 08:05:14 UTC
[airflow] branch main updated: Add link for EMR Steps Sensor logs (#28180)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fefcb1d567 Add link for EMR Steps Sensor logs (#28180)
fefcb1d567 is described below
commit fefcb1d567d8d605f7ec9b7d408831d656736541
Author: Syed Hussaain <10...@users.noreply.github.com>
AuthorDate: Tue Dec 20 00:05:05 2022 -0800
Add link for EMR Steps Sensor logs (#28180)
---
airflow/providers/amazon/aws/hooks/s3.py | 2 +-
airflow/providers/amazon/aws/links/emr.py | 8 +++++++
airflow/providers/amazon/aws/sensors/emr.py | 26 ++++++++++++++++------
airflow/providers/amazon/provider.yaml | 1 +
tests/providers/amazon/aws/hooks/test_s3.py | 8 +++++++
.../providers/amazon/aws/sensors/test_emr_base.py | 3 ++-
.../amazon/aws/sensors/test_emr_job_flow.py | 9 +++++---
7 files changed, 45 insertions(+), 12 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index e5b1bad580..1e2235f24a 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -152,7 +152,7 @@ class S3Hook(AwsBaseHook):
:return: the parsed bucket name and key
"""
format = s3url.split("//")
- if format[0].lower() == "s3:":
+ if re.match(r"s3[na]?:", format[0], re.IGNORECASE):
parsed_url = urlsplit(s3url)
if not parsed_url.netloc:
raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')
diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py
index aa739567fb..83e190f663 100644
--- a/airflow/providers/amazon/aws/links/emr.py
+++ b/airflow/providers/amazon/aws/links/emr.py
@@ -27,3 +27,11 @@ class EmrClusterLink(BaseAwsLink):
format_str = (
BASE_AWS_CONSOLE_LINK + "/elasticmapreduce/home?region={region_name}#cluster-details:{job_flow_id}"
)
+
+
+class EmrLogsLink(BaseAwsLink):
+ """Helper class for constructing AWS EMR Logs Link"""
+
+ name = "EMR Cluster Logs"
+ key = "emr_logs"
+ format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/"
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 04e2897716..d1cd0949e0 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -21,6 +21,8 @@ from typing import TYPE_CHECKING, Any, Iterable, Sequence
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.links.emr import EmrLogsLink
from airflow.sensors.base import BaseSensorOperator, poke_mode_only
if TYPE_CHECKING:
@@ -61,7 +63,7 @@ class EmrBaseSensor(BaseSensorOperator):
return self.hook
def poke(self, context: Context):
- response = self.get_emr_response()
+ response = self.get_emr_response(context=context)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
self.log.info("Bad HTTP response: %s", response)
@@ -78,7 +80,7 @@ class EmrBaseSensor(BaseSensorOperator):
return False
- def get_emr_response(self) -> dict[str, Any]:
+ def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Make an API call with boto3 and get response.
@@ -329,7 +331,7 @@ class EmrNotebookExecutionSensor(EmrBaseSensor):
self.target_states = target_states or self.COMPLETED_STATES
self.failed_states = failed_states or self.FAILURE_STATES
- def get_emr_response(self) -> dict[str, Any]:
+ def get_emr_response(self, context: Context) -> dict[str, Any]:
emr_client = self.get_hook().get_conn()
self.log.info("Poking notebook %s", self.notebook_execution_id)
@@ -382,6 +384,7 @@ class EmrJobFlowSensor(EmrBaseSensor):
template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
template_ext: Sequence[str] = ()
+ operator_extra_links = (EmrLogsLink(),)
def __init__(
self,
@@ -396,7 +399,7 @@ class EmrJobFlowSensor(EmrBaseSensor):
self.target_states = target_states or ["TERMINATED"]
self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]
- def get_emr_response(self) -> dict[str, Any]:
+ def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Make an API call with boto3 and get cluster-level details.
@@ -406,9 +409,18 @@ class EmrJobFlowSensor(EmrBaseSensor):
:return: response
"""
emr_client = self.get_hook().get_conn()
-
self.log.info("Poking cluster %s", self.job_flow_id)
- return emr_client.describe_cluster(ClusterId=self.job_flow_id)
+ response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
+ log_uri = S3Hook.parse_s3_url(response["Cluster"]["LogUri"])
+ EmrLogsLink.persist(
+ context=context,
+ operator=self,
+ region_name=self.get_hook().conn_region_name,
+ aws_partition=self.get_hook().conn_partition,
+ job_flow_id=self.job_flow_id,
+ log_uri="/".join(log_uri),
+ )
+ return response
@staticmethod
def state_from_response(response: dict[str, Any]) -> str:
@@ -476,7 +488,7 @@ class EmrStepSensor(EmrBaseSensor):
self.target_states = target_states or ["COMPLETED"]
self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]
- def get_emr_response(self) -> dict[str, Any]:
+ def get_emr_response(self, context: Context) -> dict[str, Any]:
"""
Make an API call with boto3 and get details about the cluster step.
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 3f7b793a1c..230bc135ae 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -558,6 +558,7 @@ extra-links:
- airflow.providers.amazon.aws.links.batch.BatchJobDetailsLink
- airflow.providers.amazon.aws.links.batch.BatchJobQueueLink
- airflow.providers.amazon.aws.links.emr.EmrClusterLink
+ - airflow.providers.amazon.aws.links.emr.EmrLogsLink
- airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink
connection-types:
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 48cfb9912f..ff231482da 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -71,6 +71,14 @@ class TestAwsS3Hook:
parsed = S3Hook.parse_s3_url("s3://test/this/is/not/a-real-key.txt")
assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
+ def test_parse_s3_url_s3a_style(self):
+ parsed = S3Hook.parse_s3_url("s3a://test/this/is/not/a-real-key.txt")
+ assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
+
+ def test_parse_s3_url_s3n_style(self):
+ parsed = S3Hook.parse_s3_url("s3n://test/this/is/not/a-real-key.txt")
+ assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
+
def test_parse_s3_url_path_style(self):
parsed = S3Hook.parse_s3_url("https://s3.us-west-2.amazonaws.com/DOC-EXAMPLE-BUCKET1/test.jpg")
assert parsed == ("DOC-EXAMPLE-BUCKET1", "test.jpg"), "Incorrect parsing of the s3 url"
diff --git a/tests/providers/amazon/aws/sensors/test_emr_base.py b/tests/providers/amazon/aws/sensors/test_emr_base.py
index b0dfd66233..f6b4351833 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_base.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_base.py
@@ -21,6 +21,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor
+from airflow.utils.context import Context
TARGET_STATE = "TARGET_STATE"
FAILED_STATE = "FAILED_STATE"
@@ -40,7 +41,7 @@ class EmrBaseSensorSubclass(EmrBaseSensor):
self.failed_states = [FAILED_STATE]
self.response = {} # will be set in tests
- def get_emr_response(self):
+ def get_emr_response(self, context: Context):
return self.response
@staticmethod
diff --git a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
index 87a80d6a01..a10e68abb8 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
@@ -199,6 +199,9 @@ class TestEmrJobFlowSensor:
# Mock out the emr_client creator
self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
+ # Mock context used in execute function
+ self.mock_ctx = MagicMock()
+
def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self):
self.mock_emr_client.describe_cluster.side_effect = [
DESCRIBE_CLUSTER_STARTING_RETURN,
@@ -210,7 +213,7 @@ class TestEmrJobFlowSensor:
task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default"
)
- operator.execute(None)
+ operator.execute(self.mock_ctx)
# make sure we called three times
assert self.mock_emr_client.describe_cluster.call_count == 3
@@ -230,7 +233,7 @@ class TestEmrJobFlowSensor:
)
with pytest.raises(AirflowException):
- operator.execute(None)
+ operator.execute(self.mock_ctx)
# make sure we called twice
assert self.mock_emr_client.describe_cluster.call_count == 2
@@ -256,7 +259,7 @@ class TestEmrJobFlowSensor:
target_states=["RUNNING", "WAITING"],
)
- operator.execute(None)
+ operator.execute(self.mock_ctx)
# make sure we called three times
assert self.mock_emr_client.describe_cluster.call_count == 3