Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/25 16:40:17 UTC
[airflow] branch master updated: Enable Black on Providers Packages (#10543)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new fdd9b6f Enable Black on Providers Packages (#10543)
fdd9b6f is described below
commit fdd9b6f65b608c516b8a062b058972d9a45ec9e3
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Tue Aug 25 17:39:04 2020 +0100
Enable Black on Providers Packages (#10543)
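For readers skimming the diff below, this is the kind of rewrite Black applies across the provider packages: short multi-line calls are collapsed onto one line when they fit within the configured line length, and a trailing comma is added to calls that stay multi-line. A minimal, hypothetical sketch (the function is a stand-in for an operator constructor, not taken verbatim from the diff):

# Illustrative sketch only: a plain function standing in for an operator constructor.
def make_task(aws_conn_id, task_id, task_arn):
    return {"aws_conn_id": aws_conn_id, "task_id": task_id, "task_arn": task_arn}

# Before Black: the call is spread over several short lines with no trailing comma.
# task = make_task(
#     aws_conn_id="aws_default",
#     task_id="datasync_task_1",
#     task_arn="my_task_arn"
# )

# After Black: the call fits within the configured line length, so it is collapsed.
task = make_task(aws_conn_id="aws_default", task_id="datasync_task_1", task_arn="my_task_arn")
print(task)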
---
.flake8 | 2 +-
.pre-commit-config.yaml | 5 +-
.../amazon/aws/example_dags/example_datasync_1.py | 17 +-
.../amazon/aws/example_dags/example_datasync_2.py | 29 +-
.../amazon/aws/example_dags/example_ecs_fargate.py | 7 +-
.../example_emr_job_flow_automatic_steps.py | 14 +-
.../example_emr_job_flow_manual_steps.py | 18 +-
.../example_google_api_to_s3_transfer_advanced.py | 21 +-
.../example_google_api_to_s3_transfer_basic.py | 9 +-
.../example_dags/example_imap_attachment_to_s3.py | 7 +-
.../amazon/aws/example_dags/example_s3_bucket.py | 15 +-
.../aws/example_dags/example_s3_to_redshift.py | 24 +-
airflow/providers/amazon/aws/hooks/athena.py | 61 +-
airflow/providers/amazon/aws/hooks/aws_dynamodb.py | 4 +-
airflow/providers/amazon/aws/hooks/base_aws.py | 20 +-
airflow/providers/amazon/aws/hooks/batch_client.py | 16 +-
.../providers/amazon/aws/hooks/batch_waiters.py | 11 +-
airflow/providers/amazon/aws/hooks/datasync.py | 28 +-
airflow/providers/amazon/aws/hooks/ec2.py | 17 +-
airflow/providers/amazon/aws/hooks/emr.py | 4 +-
airflow/providers/amazon/aws/hooks/glue.py | 46 +-
airflow/providers/amazon/aws/hooks/glue_catalog.py | 12 +-
airflow/providers/amazon/aws/hooks/kinesis.py | 5 +-
.../providers/amazon/aws/hooks/lambda_function.py | 14 +-
airflow/providers/amazon/aws/hooks/logs.py | 12 +-
airflow/providers/amazon/aws/hooks/redshift.py | 34 +-
airflow/providers/amazon/aws/hooks/s3.py | 273 +++---
airflow/providers/amazon/aws/hooks/sagemaker.py | 283 +++---
airflow/providers/amazon/aws/hooks/ses.py | 2 +-
airflow/providers/amazon/aws/hooks/sns.py | 9 +-
airflow/providers/amazon/aws/hooks/sqs.py | 10 +-
.../providers/amazon/aws/hooks/step_function.py | 12 +-
.../amazon/aws/log/cloudwatch_task_handler.py | 16 +-
.../providers/amazon/aws/log/s3_task_handler.py | 8 +-
airflow/providers/amazon/aws/operators/athena.py | 33 +-
airflow/providers/amazon/aws/operators/batch.py | 11 +-
.../amazon/aws/operators/cloud_formation.py | 16 +-
airflow/providers/amazon/aws/operators/datasync.py | 90 +-
.../amazon/aws/operators/ec2_start_instance.py | 24 +-
.../amazon/aws/operators/ec2_stop_instance.py | 24 +-
airflow/providers/amazon/aws/operators/ecs.py | 73 +-
.../amazon/aws/operators/emr_add_steps.py | 22 +-
.../amazon/aws/operators/emr_create_job_flow.py | 24 +-
.../amazon/aws/operators/emr_modify_cluster.py | 13 +-
.../amazon/aws/operators/emr_terminate_job_flow.py | 7 +-
airflow/providers/amazon/aws/operators/glue.py | 58 +-
.../providers/amazon/aws/operators/s3_bucket.py | 27 +-
.../amazon/aws/operators/s3_copy_object.py | 33 +-
.../amazon/aws/operators/s3_delete_objects.py | 9 +-
.../amazon/aws/operators/s3_file_transform.py | 53 +-
airflow/providers/amazon/aws/operators/s3_list.py | 18 +-
.../amazon/aws/operators/sagemaker_base.py | 11 +-
.../amazon/aws/operators/sagemaker_endpoint.py | 39 +-
.../aws/operators/sagemaker_endpoint_config.py | 20 +-
.../amazon/aws/operators/sagemaker_model.py | 13 +-
.../amazon/aws/operators/sagemaker_processing.py | 35 +-
.../amazon/aws/operators/sagemaker_training.py | 29 +-
.../amazon/aws/operators/sagemaker_transform.py | 25 +-
.../amazon/aws/operators/sagemaker_tuning.py | 22 +-
airflow/providers/amazon/aws/operators/sns.py | 17 +-
airflow/providers/amazon/aws/operators/sqs.py | 28 +-
.../step_function_get_execution_output.py | 1 +
.../aws/operators/step_function_start_execution.py | 15 +-
.../amazon/aws/secrets/secrets_manager.py | 13 +-
.../amazon/aws/secrets/systems_manager.py | 9 +-
airflow/providers/amazon/aws/sensors/athena.py | 25 +-
.../amazon/aws/sensors/cloud_formation.py | 16 +-
.../amazon/aws/sensors/ec2_instance_state.py | 24 +-
airflow/providers/amazon/aws/sensors/emr_base.py | 15 +-
.../providers/amazon/aws/sensors/emr_job_flow.py | 17 +-
airflow/providers/amazon/aws/sensors/emr_step.py | 34 +-
airflow/providers/amazon/aws/sensors/glue.py | 11 +-
.../amazon/aws/sensors/glue_catalog_partition.py | 35 +-
airflow/providers/amazon/aws/sensors/redshift.py | 10 +-
airflow/providers/amazon/aws/sensors/s3_key.py | 30 +-
.../amazon/aws/sensors/s3_keys_unchanged.py | 52 +-
airflow/providers/amazon/aws/sensors/s3_prefix.py | 16 +-
.../providers/amazon/aws/sensors/sagemaker_base.py | 9 +-
.../amazon/aws/sensors/sagemaker_endpoint.py | 4 +-
.../amazon/aws/sensors/sagemaker_training.py | 30 +-
.../amazon/aws/sensors/sagemaker_transform.py | 4 +-
.../amazon/aws/sensors/sagemaker_tuning.py | 4 +-
airflow/providers/amazon/aws/sensors/sqs.py | 29 +-
.../amazon/aws/sensors/step_function_execution.py | 9 +-
.../amazon/aws/transfers/dynamodb_to_s3.py | 26 +-
.../providers/amazon/aws/transfers/gcs_to_s3.py | 60 +-
.../amazon/aws/transfers/google_api_to_s3.py | 13 +-
.../amazon/aws/transfers/hive_to_dynamodb.py | 44 +-
.../amazon/aws/transfers/imap_attachment_to_s3.py | 31 +-
.../providers/amazon/aws/transfers/mongo_to_s3.py | 38 +-
.../providers/amazon/aws/transfers/mysql_to_s3.py | 33 +-
.../amazon/aws/transfers/redshift_to_s3.py | 46 +-
.../amazon/aws/transfers/s3_to_redshift.py | 40 +-
.../providers/amazon/aws/transfers/s3_to_sftp.py | 10 +-
.../providers/amazon/aws/transfers/sftp_to_s3.py | 17 +-
.../example_dags/example_cassandra_dag.py | 2 +-
.../providers/apache/cassandra/hooks/cassandra.py | 26 +-
.../providers/apache/cassandra/sensors/record.py | 1 +
.../providers/apache/cassandra/sensors/table.py | 1 +
airflow/providers/apache/druid/hooks/druid.py | 33 +-
airflow/providers/apache/druid/operators/druid.py | 18 +-
.../apache/druid/operators/druid_check.py | 6 +-
.../apache/druid/transfers/hive_to_druid.py | 55 +-
airflow/providers/apache/hdfs/hooks/hdfs.py | 44 +-
airflow/providers/apache/hdfs/hooks/webhdfs.py | 23 +-
airflow/providers/apache/hdfs/sensors/hdfs.py | 66 +-
airflow/providers/apache/hdfs/sensors/web_hdfs.py | 8 +-
.../hive/example_dags/example_twitter_dag.py | 52 +-
airflow/providers/apache/hive/hooks/hive.py | 312 +++---
airflow/providers/apache/hive/operators/hive.py | 68 +-
.../providers/apache/hive/operators/hive_stats.py | 63 +-
.../apache/hive/sensors/hive_partition.py | 37 +-
.../apache/hive/sensors/metastore_partition.py | 20 +-
.../apache/hive/sensors/named_hive_partition.py | 28 +-
.../apache/hive/transfers/hive_to_mysql.py | 37 +-
.../apache/hive/transfers/hive_to_samba.py | 20 +-
.../apache/hive/transfers/mssql_to_hive.py | 28 +-
.../apache/hive/transfers/mysql_to_hive.py | 47 +-
.../providers/apache/hive/transfers/s3_to_hive.py | 105 +-
.../apache/hive/transfers/vertica_to_hive.py | 29 +-
.../apache/kylin/example_dags/example_kylin_dag.py | 8 +-
airflow/providers/apache/kylin/hooks/kylin.py | 23 +-
.../providers/apache/kylin/operators/kylin_cube.py | 71 +-
.../apache/livy/example_dags/example_livy.py | 15 +-
airflow/providers/apache/livy/hooks/livy.py | 61 +-
airflow/providers/apache/livy/operators/livy.py | 7 +-
airflow/providers/apache/livy/sensors/livy.py | 5 +-
.../apache/pig/example_dags/example_pig.py | 9 +-
airflow/providers/apache/pig/hooks/pig.py | 14 +-
airflow/providers/apache/pig/operators/pig.py | 22 +-
airflow/providers/apache/pinot/hooks/pinot.py | 111 ++-
.../apache/spark/example_dags/example_spark_dag.py | 15 +-
airflow/providers/apache/spark/hooks/spark_jdbc.py | 102 +-
.../apache/spark/hooks/spark_jdbc_script.py | 168 ++--
airflow/providers/apache/spark/hooks/spark_sql.py | 36 +-
.../providers/apache/spark/hooks/spark_submit.py | 218 ++---
.../providers/apache/spark/operators/spark_jdbc.py | 65 +-
.../providers/apache/spark/operators/spark_sql.py | 62 +-
.../apache/spark/operators/spark_submit.py | 81 +-
airflow/providers/apache/sqoop/hooks/sqoop.py | 186 ++--
airflow/providers/apache/sqoop/operators/sqoop.py | 126 +--
airflow/providers/celery/sensors/celery_queue.py | 17 +-
airflow/providers/cloudant/hooks/cloudant.py | 5 +-
.../kubernetes/example_dags/example_kubernetes.py | 92 +-
.../example_dags/example_spark_kubernetes.py | 5 +-
.../providers/cncf/kubernetes/hooks/kubernetes.py | 36 +-
.../cncf/kubernetes/operators/spark_kubernetes.py | 16 +-
.../cncf/kubernetes/sensors/spark_kubernetes.py | 16 +-
.../databricks/example_dags/example_databricks.py | 25 +-
airflow/providers/databricks/hooks/databricks.py | 64 +-
.../providers/databricks/operators/databricks.py | 90 +-
airflow/providers/datadog/hooks/datadog.py | 61 +-
airflow/providers/datadog/sensors/datadog.py | 24 +-
.../dingding/example_dags/example_dingding.py | 77 +-
airflow/providers/dingding/hooks/dingding.py | 53 +-
airflow/providers/dingding/operators/dingding.py | 24 +-
airflow/providers/discord/hooks/discord_webhook.py | 46 +-
.../providers/discord/operators/discord_webhook.py | 28 +-
.../docker/example_dags/example_docker.py | 20 +-
.../docker/example_dags/example_docker_swarm.py | 4 +-
airflow/providers/docker/hooks/docker.py | 22 +-
airflow/providers/docker/operators/docker.py | 91 +-
airflow/providers/docker/operators/docker_swarm.py | 15 +-
.../providers/elasticsearch/hooks/elasticsearch.py | 14 +-
.../providers/elasticsearch/log/es_task_handler.py | 42 +-
airflow/providers/exasol/hooks/exasol.py | 10 +-
airflow/providers/exasol/operators/exasol.py | 24 +-
airflow/providers/facebook/ads/hooks/ads.py | 29 +-
airflow/providers/ftp/hooks/ftp.py | 11 +-
airflow/providers/ftp/sensors/ftp.py | 14 +-
airflow/providers/google/__init__.py | 11 +-
.../google/ads/example_dags/example_ads.py | 4 +-
airflow/providers/google/ads/hooks/ads.py | 11 +-
airflow/providers/google/ads/operators/ads.py | 24 +-
.../providers/google/ads/transfers/ads_to_gcs.py | 31 +-
.../_internal_client/secret_manager_client.py | 19 +-
.../example_automl_nl_text_classification.py | 22 +-
.../example_automl_nl_text_extraction.py | 22 +-
.../example_automl_nl_text_sentiment.py | 22 +-
.../cloud/example_dags/example_automl_tables.py | 43 +-
.../example_dags/example_automl_translation.py | 27 +-
...ple_automl_video_intelligence_classification.py | 18 +-
.../example_automl_video_intelligence_tracking.py | 21 +-
.../example_automl_vision_classification.py | 22 +-
.../example_automl_vision_object_detection.py | 18 +-
.../cloud/example_dags/example_bigquery_dts.py | 18 +-
.../example_dags/example_bigquery_operations.py | 49 +-
.../cloud/example_dags/example_bigquery_queries.py | 38 +-
.../cloud/example_dags/example_bigquery_sensors.py | 25 +-
.../example_dags/example_bigquery_to_bigquery.py | 8 +-
.../cloud/example_dags/example_bigquery_to_gcs.py | 16 +-
.../example_dags/example_bigquery_transfer.py | 16 +-
.../google/cloud/example_dags/example_bigtable.py | 43 +-
.../cloud/example_dags/example_cloud_build.py | 5 +-
.../example_dags/example_cloud_memorystore.py | 15 +-
.../google/cloud/example_dags/example_cloud_sql.py | 185 +---
.../cloud/example_dags/example_cloud_sql_query.py | 215 ++---
.../example_cloud_storage_transfer_service_aws.py | 41 +-
.../example_cloud_storage_transfer_service_gcp.py | 38 +-
.../google/cloud/example_dags/example_compute.py | 37 +-
.../cloud/example_dags/example_compute_igm.py | 75 +-
.../cloud/example_dags/example_datacatalog.py | 32 +-
.../google/cloud/example_dags/example_dataflow.py | 48 +-
.../cloud/example_dags/example_datafusion.py | 54 +-
.../google/cloud/example_dags/example_dataprep.py | 4 +-
.../google/cloud/example_dags/example_dataproc.py | 30 +-
.../google/cloud/example_dags/example_datastore.py | 41 +-
.../google/cloud/example_dags/example_dlp.py | 12 +-
.../example_dags/example_facebook_ads_to_gcs.py | 32 +-
.../google/cloud/example_dags/example_functions.py | 46 +-
.../google/cloud/example_dags/example_gcs.py | 38 +-
.../cloud/example_dags/example_gcs_to_bigquery.py | 18 +-
.../cloud/example_dags/example_gcs_to_gcs.py | 22 +-
.../example_dags/example_kubernetes_engine.py | 14 +-
.../cloud/example_dags/example_life_sciences.py | 52 +-
.../cloud/example_dags/example_local_to_gcs.py | 7 +-
.../google/cloud/example_dags/example_mlengine.py | 76 +-
.../cloud/example_dags/example_natural_language.py | 18 +-
.../cloud/example_dags/example_postgres_to_gcs.py | 6 +-
.../cloud/example_dags/example_presto_to_gcs.py | 4 +-
.../google/cloud/example_dags/example_pubsub.py | 39 +-
.../cloud/example_dags/example_sftp_to_gcs.py | 4 +-
.../cloud/example_dags/example_sheets_to_gcs.py | 4 +-
.../google/cloud/example_dags/example_spanner.py | 79 +-
.../cloud/example_dags/example_speech_to_text.py | 4 +-
.../cloud/example_dags/example_stackdriver.py | 97 +-
.../google/cloud/example_dags/example_tasks.py | 4 +-
.../google/cloud/example_dags/example_translate.py | 3 +-
.../cloud/example_dags/example_translate_speech.py | 4 +-
.../example_dags/example_video_intelligence.py | 19 +-
.../google/cloud/example_dags/example_vision.py | 30 +-
airflow/providers/google/cloud/hooks/automl.py | 88 +-
airflow/providers/google/cloud/hooks/bigquery.py | 951 ++++++++++---------
.../providers/google/cloud/hooks/bigquery_dts.py | 32 +-
airflow/providers/google/cloud/hooks/bigtable.py | 60 +-
.../providers/google/cloud/hooks/cloud_build.py | 8 +-
.../google/cloud/hooks/cloud_memorystore.py | 4 +-
airflow/providers/google/cloud/hooks/cloud_sql.py | 366 ++++---
.../cloud/hooks/cloud_storage_transfer_service.py | 69 +-
airflow/providers/google/cloud/hooks/compute.py | 207 ++--
.../providers/google/cloud/hooks/datacatalog.py | 20 +-
airflow/providers/google/cloud/hooks/dataflow.py | 165 ++--
airflow/providers/google/cloud/hooks/datafusion.py | 119 +--
airflow/providers/google/cloud/hooks/dataproc.py | 97 +-
airflow/providers/google/cloud/hooks/datastore.py | 125 +--
airflow/providers/google/cloud/hooks/dlp.py | 29 +-
airflow/providers/google/cloud/hooks/functions.py | 41 +-
airflow/providers/google/cloud/hooks/gcs.py | 195 ++--
airflow/providers/google/cloud/hooks/gdm.py | 37 +-
airflow/providers/google/cloud/hooks/kms.py | 7 +-
.../google/cloud/hooks/kubernetes_engine.py | 85 +-
.../providers/google/cloud/hooks/life_sciences.py | 34 +-
airflow/providers/google/cloud/hooks/mlengine.py | 163 ++--
.../google/cloud/hooks/natural_language.py | 29 +-
airflow/providers/google/cloud/hooks/pubsub.py | 75 +-
.../providers/google/cloud/hooks/secret_manager.py | 16 +-
airflow/providers/google/cloud/hooks/spanner.py | 126 +--
.../providers/google/cloud/hooks/speech_to_text.py | 6 +-
.../providers/google/cloud/hooks/stackdriver.py | 113 +--
airflow/providers/google/cloud/hooks/tasks.py | 87 +-
.../providers/google/cloud/hooks/text_to_speech.py | 14 +-
airflow/providers/google/cloud/hooks/translate.py | 6 +-
.../google/cloud/hooks/video_intelligence.py | 7 +-
airflow/providers/google/cloud/hooks/vision.py | 80 +-
.../providers/google/cloud/log/gcs_task_handler.py | 15 +-
.../google/cloud/log/stackdriver_task_handler.py | 27 +-
airflow/providers/google/cloud/operators/automl.py | 213 +++--
.../providers/google/cloud/operators/bigquery.py | 414 ++++----
.../google/cloud/operators/bigquery_dts.py | 31 +-
.../providers/google/cloud/operators/bigtable.py | 290 +++---
.../google/cloud/operators/cloud_build.py | 26 +-
.../google/cloud/operators/cloud_memorystore.py | 124 ++-
.../providers/google/cloud/operators/cloud_sql.py | 695 ++++++++------
.../operators/cloud_storage_transfer_service.py | 174 +++-
.../providers/google/cloud/operators/compute.py | 388 +++++---
.../google/cloud/operators/datacatalog.py | 188 ++--
.../providers/google/cloud/operators/dataflow.py | 125 +--
.../providers/google/cloud/operators/datafusion.py | 148 +--
.../providers/google/cloud/operators/dataprep.py | 4 +-
.../providers/google/cloud/operators/dataproc.py | 517 +++++-----
.../providers/google/cloud/operators/datastore.py | 257 ++---
airflow/providers/google/cloud/operators/dlp.py | 442 +++++----
.../providers/google/cloud/operators/functions.py | 191 ++--
airflow/providers/google/cloud/operators/gcs.py | 297 +++---
.../google/cloud/operators/kubernetes_engine.py | 108 ++-
.../google/cloud/operators/life_sciences.py | 31 +-
.../providers/google/cloud/operators/mlengine.py | 406 ++++----
.../google/cloud/operators/natural_language.py | 60 +-
airflow/providers/google/cloud/operators/pubsub.py | 134 ++-
.../providers/google/cloud/operators/spanner.py | 365 ++++---
.../google/cloud/operators/speech_to_text.py | 19 +-
.../google/cloud/operators/stackdriver.py | 158 +--
airflow/providers/google/cloud/operators/tasks.py | 192 ++--
.../google/cloud/operators/text_to_speech.py | 12 +-
.../providers/google/cloud/operators/translate.py | 22 +-
.../google/cloud/operators/translate_speech.py | 34 +-
.../google/cloud/operators/video_intelligence.py | 54 +-
airflow/providers/google/cloud/operators/vision.py | 319 ++++---
.../google/cloud/secrets/secret_manager.py | 7 +-
airflow/providers/google/cloud/sensors/bigquery.py | 65 +-
.../providers/google/cloud/sensors/bigquery_dts.py | 9 +-
airflow/providers/google/cloud/sensors/bigtable.py | 22 +-
.../sensors/cloud_storage_transfer_service.py | 13 +-
airflow/providers/google/cloud/sensors/gcs.py | 130 ++-
airflow/providers/google/cloud/sensors/pubsub.py | 55 +-
.../google/cloud/transfers/adls_to_gcs.py | 54 +-
.../google/cloud/transfers/bigquery_to_bigquery.py | 59 +-
.../google/cloud/transfers/bigquery_to_gcs.py | 66 +-
.../google/cloud/transfers/bigquery_to_mysql.py | 65 +-
.../google/cloud/transfers/cassandra_to_gcs.py | 53 +-
.../google/cloud/transfers/facebook_ads_to_gcs.py | 23 +-
.../google/cloud/transfers/gcs_to_bigquery.py | 114 ++-
.../providers/google/cloud/transfers/gcs_to_gcs.py | 145 +--
.../google/cloud/transfers/gcs_to_local.py | 53 +-
.../google/cloud/transfers/gcs_to_sftp.py | 51 +-
.../google/cloud/transfers/local_to_gcs.py | 49 +-
.../google/cloud/transfers/mssql_to_gcs.py | 11 +-
.../google/cloud/transfers/mysql_to_gcs.py | 6 +-
.../google/cloud/transfers/postgres_to_gcs.py | 17 +-
.../google/cloud/transfers/presto_to_gcs.py | 6 +-
.../providers/google/cloud/transfers/s3_to_gcs.py | 85 +-
.../google/cloud/transfers/sftp_to_gcs.py | 36 +-
.../google/cloud/transfers/sheets_to_gcs.py | 38 +-
.../providers/google/cloud/transfers/sql_to_gcs.py | 106 ++-
.../google/cloud/utils/credentials_provider.py | 53 +-
.../google/cloud/utils/field_sanitizer.py | 15 +-
.../google/cloud/utils/field_validator.py | 185 ++--
.../google/cloud/utils/mlengine_operator_utils.py | 56 +-
.../cloud/utils/mlengine_prediction_summary.py | 56 +-
.../providers/google/common/hooks/base_google.py | 94 +-
.../providers/google/common/hooks/discovery_api.py | 11 +-
.../firebase/example_dags/example_firestore.py | 4 +-
.../providers/google/firebase/hooks/firestore.py | 7 +-
.../google/firebase/operators/firestore.py | 11 +-
.../example_dags/example_analytics.py | 17 +-
.../example_dags/example_campaign_manager.py | 33 +-
.../example_dags/example_display_video.py | 47 +-
.../example_dags/example_search_ads.py | 13 +-
.../google/marketing_platform/hooks/analytics.py | 30 +-
.../marketing_platform/hooks/campaign_manager.py | 19 +-
.../marketing_platform/hooks/display_video.py | 18 +-
.../google/marketing_platform/hooks/search_ads.py | 9 +-
.../marketing_platform/operators/analytics.py | 73 +-
.../operators/campaign_manager.py | 54 +-
.../marketing_platform/operators/display_video.py | 97 +-
.../marketing_platform/operators/search_ads.py | 32 +-
.../marketing_platform/sensors/campaign_manager.py | 16 +-
.../marketing_platform/sensors/display_video.py | 17 +-
.../marketing_platform/sensors/search_ads.py | 10 +-
.../suite/example_dags/example_gcs_to_sheets.py | 4 +-
.../google/suite/example_dags/example_sheets.py | 4 +-
airflow/providers/google/suite/hooks/drive.py | 4 +-
airflow/providers/google/suite/hooks/sheets.py | 175 ++--
airflow/providers/google/suite/operators/sheets.py | 8 +-
.../google/suite/transfers/gcs_to_gdrive.py | 13 +-
.../google/suite/transfers/gcs_to_sheets.py | 11 +-
airflow/providers/grpc/hooks/grpc.py | 42 +-
airflow/providers/grpc/operators/grpc.py | 27 +-
.../hashicorp/_internal_client/vault_client.py | 118 ++-
airflow/providers/hashicorp/hooks/vault.py | 79 +-
airflow/providers/hashicorp/secrets/vault.py | 5 +-
airflow/providers/http/hooks/http.py | 60 +-
airflow/providers/http/operators/http.py | 36 +-
airflow/providers/http/sensors/http.py | 35 +-
airflow/providers/imap/hooks/imap.py | 98 +-
airflow/providers/imap/sensors/imap_attachment.py | 20 +-
airflow/providers/jdbc/hooks/jdbc.py | 10 +-
airflow/providers/jdbc/operators/jdbc.py | 15 +-
.../example_dags/example_jenkins_job_trigger.py | 15 +-
.../jenkins/operators/jenkins_job_trigger.py | 76 +-
airflow/providers/jira/hooks/jira.py | 34 +-
airflow/providers/jira/operators/jira.py | 20 +-
airflow/providers/jira/sensors/jira.py | 73 +-
.../example_azure_container_instances.py | 2 +-
.../azure/example_dags/example_azure_cosmosdb.py | 2 +-
airflow/providers/microsoft/azure/hooks/adx.py | 41 +-
.../providers/microsoft/azure/hooks/azure_batch.py | 141 ++-
.../azure/hooks/azure_container_instance.py | 14 +-
.../azure/hooks/azure_container_volume.py | 17 +-
.../microsoft/azure/hooks/azure_cosmos.py | 97 +-
.../microsoft/azure/hooks/azure_data_lake.py | 50 +-
.../microsoft/azure/hooks/azure_fileshare.py | 30 +-
.../providers/microsoft/azure/hooks/base_azure.py | 16 +-
airflow/providers/microsoft/azure/hooks/wasb.py | 31 +-
.../microsoft/azure/log/wasb_task_handler.py | 19 +-
.../microsoft/azure/operators/adls_list.py | 15 +-
airflow/providers/microsoft/azure/operators/adx.py | 16 +-
.../microsoft/azure/operators/azure_batch.py | 206 ++--
.../azure/operators/azure_container_instances.py | 116 +--
.../microsoft/azure/operators/azure_cosmos.py | 16 +-
.../microsoft/azure/operators/wasb_delete_blob.py | 29 +-
.../microsoft/azure/sensors/azure_cosmos.py | 15 +-
airflow/providers/microsoft/azure/sensors/wasb.py | 43 +-
.../microsoft/azure/transfers/file_to_wasb.py | 27 +-
.../azure/transfers/oracle_to_azure_data_lake.py | 52 +-
.../providers/microsoft/mssql/operators/mssql.py | 5 +-
.../microsoft/winrm/example_dags/example_winrm.py | 18 +-
airflow/providers/microsoft/winrm/hooks/winrm.py | 52 +-
.../providers/microsoft/winrm/operators/winrm.py | 35 +-
airflow/providers/mongo/hooks/mongo.py | 147 ++-
airflow/providers/mongo/sensors/mongo.py | 14 +-
airflow/providers/mysql/hooks/mysql.py | 51 +-
airflow/providers/mysql/operators/mysql.py | 24 +-
.../providers/mysql/transfers/presto_to_mysql.py | 17 +-
airflow/providers/mysql/transfers/s3_to_mysql.py | 26 +-
.../providers/mysql/transfers/vertica_to_mysql.py | 41 +-
airflow/providers/odbc/hooks/odbc.py | 10 +-
airflow/providers/openfaas/hooks/openfaas.py | 5 +-
airflow/providers/opsgenie/hooks/opsgenie_alert.py | 21 +-
.../providers/opsgenie/operators/opsgenie_alert.py | 53 +-
airflow/providers/oracle/hooks/oracle.py | 22 +-
airflow/providers/oracle/operators/oracle.py | 19 +-
.../providers/oracle/transfers/oracle_to_oracle.py | 24 +-
airflow/providers/pagerduty/hooks/pagerduty.py | 3 +-
.../papermill/example_dags/example_papermill.py | 15 +-
airflow/providers/papermill/operators/papermill.py | 28 +-
airflow/providers/postgres/hooks/postgres.py | 33 +-
airflow/providers/postgres/operators/postgres.py | 19 +-
airflow/providers/presto/hooks/presto.py | 9 +-
.../qubole/example_dags/example_qubole.py | 85 +-
airflow/providers/qubole/hooks/qubole.py | 41 +-
airflow/providers/qubole/hooks/qubole_check.py | 6 +-
airflow/providers/qubole/operators/qubole.py | 56 +-
airflow/providers/qubole/operators/qubole_check.py | 31 +-
airflow/providers/qubole/sensors/qubole.py | 7 +-
airflow/providers/redis/hooks/redis.py | 23 +-
airflow/providers/redis/operators/redis_publish.py | 7 +-
airflow/providers/redis/sensors/redis_key.py | 1 +
airflow/providers/redis/sensors/redis_pub_sub.py | 1 +
.../example_tableau_refresh_workbook.py | 2 +-
airflow/providers/salesforce/hooks/salesforce.py | 38 +-
airflow/providers/salesforce/hooks/tableau.py | 7 +-
.../operators/tableau_refresh_workbook.py | 18 +-
.../salesforce/sensors/tableau_job_status.py | 13 +-
airflow/providers/samba/hooks/samba.py | 3 +-
airflow/providers/segment/hooks/segment.py | 10 +-
.../segment/operators/segment_track_event.py | 31 +-
airflow/providers/sendgrid/utils/emailer.py | 44 +-
airflow/providers/sftp/hooks/sftp.py | 24 +-
airflow/providers/sftp/operators/sftp.py | 64 +-
airflow/providers/sftp/sensors/sftp.py | 1 +
.../example_dags/example_singularity.py | 43 +-
.../providers/singularity/operators/singularity.py | 48 +-
airflow/providers/slack/hooks/slack.py | 8 +-
airflow/providers/slack/hooks/slack_webhook.py | 44 +-
airflow/providers/slack/operators/slack.py | 59 +-
airflow/providers/slack/operators/slack_webhook.py | 45 +-
.../snowflake/example_dags/example_snowflake.py | 19 +-
airflow/providers/snowflake/hooks/snowflake.py | 28 +-
airflow/providers/snowflake/operators/snowflake.py | 32 +-
.../snowflake/transfers/s3_to_snowflake.py | 37 +-
.../snowflake/transfers/snowflake_to_slack.py | 18 +-
airflow/providers/sqlite/operators/sqlite.py | 13 +-
airflow/providers/ssh/hooks/ssh.py | 101 +-
airflow/providers/ssh/operators/ssh.py | 70 +-
airflow/providers/vertica/hooks/vertica.py | 2 +-
airflow/providers/vertica/operators/vertica.py | 6 +-
.../example_dags/example_yandexcloud_dataproc.py | 70 +-
airflow/providers/yandex/hooks/yandex.py | 10 +-
.../providers/yandex/hooks/yandexcloud_dataproc.py | 3 +-
.../yandex/operators/yandexcloud_dataproc.py | 185 ++--
airflow/providers/zendesk/hooks/zendesk.py | 14 +-
tests/providers/amazon/aws/hooks/test_athena.py | 72 +-
.../amazon/aws/hooks/test_aws_dynamodb.py | 31 +-
tests/providers/amazon/aws/hooks/test_base_aws.py | 85 +-
.../amazon/aws/hooks/test_batch_client.py | 59 +-
.../amazon/aws/hooks/test_batch_waiters.py | 32 +-
.../amazon/aws/hooks/test_cloud_formation.py | 24 +-
tests/providers/amazon/aws/hooks/test_datasync.py | 45 +-
tests/providers/amazon/aws/hooks/test_ec2.py | 24 +-
tests/providers/amazon/aws/hooks/test_emr.py | 12 +-
tests/providers/amazon/aws/hooks/test_glue.py | 62 +-
.../amazon/aws/hooks/test_glue_catalog.py | 76 +-
tests/providers/amazon/aws/hooks/test_kinesis.py | 24 +-
.../amazon/aws/hooks/test_lambda_function.py | 12 +-
tests/providers/amazon/aws/hooks/test_logs.py | 25 +-
tests/providers/amazon/aws/hooks/test_redshift.py | 11 +-
tests/providers/amazon/aws/hooks/test_s3.py | 72 +-
tests/providers/amazon/aws/hooks/test_sagemaker.py | 468 ++++-----
tests/providers/amazon/aws/hooks/test_ses.py | 27 +-
tests/providers/amazon/aws/hooks/test_sns.py | 17 +-
tests/providers/amazon/aws/hooks/test_sqs.py | 1 -
.../amazon/aws/hooks/test_step_function.py | 13 +-
.../amazon/aws/log/test_cloudwatch_task_handler.py | 104 +-
.../amazon/aws/log/test_s3_task_handler.py | 45 +-
.../providers/amazon/aws/operators/test_athena.py | 93 +-
tests/providers/amazon/aws/operators/test_batch.py | 12 +-
.../amazon/aws/operators/test_cloud_formation.py | 30 +-
.../amazon/aws/operators/test_datasync.py | 131 +--
.../aws/operators/test_ec2_start_instance.py | 16 +-
.../amazon/aws/operators/test_ec2_stop_instance.py | 16 +-
tests/providers/amazon/aws/operators/test_ecs.py | 166 ++--
.../amazon/aws/operators/test_ecs_system.py | 15 +-
.../amazon/aws/operators/test_emr_add_steps.py | 120 +--
.../aws/operators/test_emr_create_job_flow.py | 83 +-
.../aws/operators/test_emr_modify_cluster.py | 20 +-
.../amazon/aws/operators/test_emr_system.py | 1 +
.../aws/operators/test_emr_terminate_job_flow.py | 10 +-
.../amazon/aws/operators/test_example_s3_bucket.py | 1 +
tests/providers/amazon/aws/operators/test_glue.py | 27 +-
.../amazon/aws/operators/test_s3_bucket.py | 10 +-
.../amazon/aws/operators/test_s3_copy_object.py | 41 +-
.../amazon/aws/operators/test_s3_delete_objects.py | 49 +-
.../amazon/aws/operators/test_s3_file_transform.py | 24 +-
.../providers/amazon/aws/operators/test_s3_list.py | 6 +-
.../amazon/aws/operators/test_sagemaker_base.py | 41 +-
.../aws/operators/test_sagemaker_endpoint.py | 64 +-
.../operators/test_sagemaker_endpoint_config.py | 18 +-
.../amazon/aws/operators/test_sagemaker_model.py | 20 +-
.../aws/operators/test_sagemaker_processing.py | 110 ++-
.../aws/operators/test_sagemaker_training.py | 114 ++-
.../aws/operators/test_sagemaker_transform.py | 67 +-
.../amazon/aws/operators/test_sagemaker_tuning.py | 163 ++--
tests/providers/amazon/aws/operators/test_sns.py | 1 -
tests/providers/amazon/aws/operators/test_sqs.py | 8 +-
.../test_step_function_get_execution_output.py | 21 +-
.../test_step_function_start_execution.py | 11 +-
.../amazon/aws/secrets/test_secrets_manager.py | 17 +-
.../amazon/aws/secrets/test_systems_manager.py | 43 +-
tests/providers/amazon/aws/sensors/test_athena.py | 13 +-
.../amazon/aws/sensors/test_cloud_formation.py | 13 +-
.../amazon/aws/sensors/test_ec2_instance_state.py | 32 +-
.../providers/amazon/aws/sensors/test_emr_base.py | 35 +-
.../amazon/aws/sensors/test_emr_job_flow.py | 132 +--
.../providers/amazon/aws/sensors/test_emr_step.py | 117 +--
tests/providers/amazon/aws/sensors/test_glue.py | 31 +-
.../aws/sensors/test_glue_catalog_partition.py | 45 +-
.../providers/amazon/aws/sensors/test_redshift.py | 44 +-
tests/providers/amazon/aws/sensors/test_s3_key.py | 32 +-
.../amazon/aws/sensors/test_s3_keys_unchanged.py | 27 +-
.../providers/amazon/aws/sensors/test_s3_prefix.py | 11 +-
.../amazon/aws/sensors/test_sagemaker_base.py | 50 +-
.../amazon/aws/sensors/test_sagemaker_endpoint.py | 34 +-
.../amazon/aws/sensors/test_sagemaker_training.py | 36 +-
.../amazon/aws/sensors/test_sagemaker_transform.py | 34 +-
.../amazon/aws/sensors/test_sagemaker_tuning.py | 34 +-
tests/providers/amazon/aws/sensors/test_sqs.py | 54 +-
.../aws/sensors/test_step_function_execution.py | 39 +-
.../amazon/aws/transfers/test_dynamodb_to_s3.py | 10 +-
.../amazon/aws/transfers/test_gcs_to_s3.py | 110 +--
.../amazon/aws/transfers/test_google_api_to_s3.py | 38 +-
.../aws/transfers/test_google_api_to_s3_system.py | 6 +-
.../amazon/aws/transfers/test_hive_to_dynamodb.py | 63 +-
.../aws/transfers/test_imap_attachment_to_s3.py | 7 +-
.../transfers/test_imap_attachment_to_s3_system.py | 6 +-
.../amazon/aws/transfers/test_mongo_to_s3.py | 24 +-
.../amazon/aws/transfers/test_mysql_to_s3.py | 26 +-
.../amazon/aws/transfers/test_redshift_to_s3.py | 32 +-
.../amazon/aws/transfers/test_s3_to_redshift.py | 20 +-
.../amazon/aws/transfers/test_s3_to_sftp.py | 17 +-
.../amazon/aws/transfers/test_sftp_to_s3.py | 16 +-
.../apache/cassandra/hooks/test_cassandra.py | 114 ++-
.../apache/cassandra/sensors/test_table.py | 8 +-
tests/providers/apache/druid/hooks/test_druid.py | 40 +-
.../providers/apache/druid/operators/test_druid.py | 12 +-
.../apache/druid/operators/test_druid_check.py | 7 +-
.../apache/druid/transfers/test_hive_to_druid.py | 50 +-
tests/providers/apache/hdfs/hooks/test_hdfs.py | 33 +-
tests/providers/apache/hdfs/hooks/test_webhdfs.py | 46 +-
tests/providers/apache/hdfs/sensors/test_hdfs.py | 197 ++--
.../providers/apache/hdfs/sensors/test_web_hdfs.py | 13 +-
tests/providers/apache/hive/__init__.py | 1 -
tests/providers/apache/hive/hooks/test_hive.py | 665 +++++++------
tests/providers/apache/hive/operators/test_hive.py | 229 +++--
.../apache/hive/operators/test_hive_stats.py | 246 +++--
tests/providers/apache/hive/sensors/test_hdfs.py | 11 +-
.../apache/hive/sensors/test_hive_partition.py | 18 +-
.../hive/sensors/test_metastore_partition.py | 10 +-
.../hive/sensors/test_named_hive_partition.py | 120 +--
.../apache/hive/transfers/test_hive_to_mysql.py | 58 +-
.../apache/hive/transfers/test_hive_to_samba.py | 37 +-
.../apache/hive/transfers/test_mssql_to_hive.py | 30 +-
.../apache/hive/transfers/test_mysql_to_hive.py | 280 ++++--
.../apache/hive/transfers/test_s3_to_hive.py | 115 ++-
.../apache/hive/transfers/test_vertica_to_hive.py | 27 +-
tests/providers/apache/kylin/hooks/test_kylin.py | 24 +-
.../apache/kylin/operators/test_kylin_cube.py | 89 +-
tests/providers/apache/livy/hooks/test_livy.py | 118 +--
tests/providers/apache/livy/operators/test_livy.py | 47 +-
tests/providers/apache/livy/sensors/test_livy.py | 16 +-
tests/providers/apache/pig/hooks/test_pig.py | 2 +-
tests/providers/apache/pinot/hooks/test_pinot.py | 129 ++-
.../apache/spark/hooks/test_spark_jdbc.py | 73 +-
.../apache/spark/hooks/test_spark_jdbc_script.py | 78 +-
.../providers/apache/spark/hooks/test_spark_sql.py | 107 ++-
.../apache/spark/hooks/test_spark_submit.py | 595 +++++++-----
.../apache/spark/operators/test_spark_jdbc.py | 26 +-
.../apache/spark/operators/test_spark_sql.py | 13 +-
.../apache/spark/operators/test_spark_submit.py | 81 +-
tests/providers/apache/sqoop/hooks/test_sqoop.py | 185 ++--
.../providers/apache/sqoop/operators/test_sqoop.py | 47 +-
.../providers/celery/sensors/test_celery_queue.py | 38 +-
tests/providers/cloudant/hooks/test_cloudant.py | 13 +-
.../cncf/kubernetes/hooks/test_kubernetes.py | 38 +-
.../kubernetes/operators/test_kubernetes_pod.py | 17 +-
.../kubernetes/operators/test_spark_kubernetes.py | 171 ++--
.../operators/test_spark_kubernetes_system.py | 16 +-
.../kubernetes/sensors/test_spark_kubernetes.py | 1002 +++++++++++---------
.../providers/databricks/hooks/test_databricks.py | 177 ++--
.../databricks/operators/test_databricks.py | 328 +++----
tests/providers/datadog/hooks/test_datadog.py | 21 +-
tests/providers/datadog/sensors/test_datadog.py | 77 +-
tests/providers/dingding/hooks/test_dingding.py | 127 +--
.../providers/dingding/operators/test_dingding.py | 15 +-
.../discord/hooks/test_discord_webhook.py | 7 +-
.../discord/operators/test_discord_webhook.py | 13 +-
tests/providers/docker/hooks/test_docker.py | 55 +-
tests/providers/docker/operators/test_docker.py | 155 +--
.../docker/operators/test_docker_swarm.py | 39 +-
.../elasticsearch/hooks/test_elasticsearch.py | 12 +-
.../elasticsearch/log/elasticmock/__init__.py | 6 +-
.../log/elasticmock/fake_elasticsearch.py | 222 +++--
.../elasticsearch/log/test_es_task_handler.py | 123 ++-
tests/providers/exasol/hooks/test_exasol.py | 8 +-
tests/providers/exasol/operators/test_exasol.py | 35 +-
tests/providers/facebook/ads/hooks/test_ads.py | 36 +-
tests/providers/ftp/hooks/test_ftp.py | 16 +-
tests/providers/ftp/sensors/test_ftp.py | 31 +-
tests/providers/google/ads/hooks/test_ads.py | 8 +-
tests/providers/google/ads/operators/test_ads.py | 7 +-
.../google/ads/transfers/test_ads_to_gcs.py | 15 +-
.../_internal_client/test_secret_manager_client.py | 15 +-
tests/providers/google/cloud/hooks/test_automl.py | 68 +-
.../providers/google/cloud/hooks/test_bigquery.py | 785 ++++++---------
.../google/cloud/hooks/test_bigquery_dts.py | 24 +-
.../google/cloud/hooks/test_bigquery_system.py | 7 +-
.../providers/google/cloud/hooks/test_bigtable.py | 150 ++-
.../google/cloud/hooks/test_cloud_build.py | 31 +-
.../google/cloud/hooks/test_cloud_memorystore.py | 98 +-
.../providers/google/cloud/hooks/test_cloud_sql.py | 731 +++++++-------
.../hooks/test_cloud_storage_transfer_service.py | 116 ++-
tests/providers/google/cloud/hooks/test_compute.py | 281 +++---
.../google/cloud/hooks/test_datacatalog.py | 271 ++----
.../providers/google/cloud/hooks/test_dataflow.py | 602 ++++++------
.../google/cloud/hooks/test_datafusion.py | 92 +-
.../providers/google/cloud/hooks/test_dataprep.py | 13 +-
.../providers/google/cloud/hooks/test_dataproc.py | 112 +--
.../providers/google/cloud/hooks/test_datastore.py | 190 ++--
tests/providers/google/cloud/hooks/test_dlp.py | 315 ++----
.../providers/google/cloud/hooks/test_functions.py | 142 ++-
tests/providers/google/cloud/hooks/test_gcs.py | 376 +++-----
tests/providers/google/cloud/hooks/test_gdm.py | 33 +-
tests/providers/google/cloud/hooks/test_kms.py | 14 +-
.../google/cloud/hooks/test_kubernetes_engine.py | 133 ++-
.../google/cloud/hooks/test_life_sciences.py | 89 +-
.../providers/google/cloud/hooks/test_mlengine.py | 622 ++++++------
.../google/cloud/hooks/test_natural_language.py | 34 +-
tests/providers/google/cloud/hooks/test_pubsub.py | 146 ++-
.../google/cloud/hooks/test_secret_manager.py | 16 +-
tests/providers/google/cloud/hooks/test_spanner.py | 170 ++--
.../google/cloud/hooks/test_speech_to_text.py | 5 +-
.../google/cloud/hooks/test_stackdriver.py | 164 ++--
tests/providers/google/cloud/hooks/test_tasks.py | 83 +-
.../google/cloud/hooks/test_text_to_speech.py | 5 +-
.../providers/google/cloud/hooks/test_translate.py | 5 +-
.../google/cloud/hooks/test_video_intelligence.py | 5 +-
tests/providers/google/cloud/hooks/test_vision.py | 45 +-
.../google/cloud/log/test_gcs_task_handler.py | 2 +-
.../cloud/log/test_gcs_task_handler_system.py | 29 +-
.../cloud/log/test_stackdriver_task_handler.py | 122 +--
.../log/test_stackdriver_task_handler_system.py | 33 +-
.../google/cloud/operators/test_automl.py | 63 +-
.../google/cloud/operators/test_bigquery.py | 556 +++++------
.../google/cloud/operators/test_bigquery_dts.py | 3 +-
.../cloud/operators/test_bigquery_dts_system.py | 19 +-
.../google/cloud/operators/test_bigtable.py | 359 ++++---
.../google/cloud/operators/test_bigtable_system.py | 20 +-
.../google/cloud/operators/test_cloud_build.py | 21 +-
.../cloud/operators/test_cloud_build_system.py | 1 +
.../cloud/operators/test_cloud_memorystore.py | 47 +-
.../google/cloud/operators/test_cloud_sql.py | 706 ++++++--------
.../cloud/operators/test_cloud_sql_system.py | 52 +-
.../operators/test_cloud_sql_system_helper.py | 457 +++++----
.../test_cloud_storage_transfer_service.py | 122 ++-
.../test_cloud_storage_transfer_service_system.py | 4 +-
.../google/cloud/operators/test_compute.py | 652 +++++--------
.../cloud/operators/test_compute_system_helper.py | 200 ++--
.../google/cloud/operators/test_datacatalog.py | 96 +-
.../google/cloud/operators/test_dataflow.py | 56 +-
.../google/cloud/operators/test_datafusion.py | 47 +-
.../google/cloud/operators/test_dataprep.py | 4 +-
.../google/cloud/operators/test_dataproc.py | 173 ++--
.../google/cloud/operators/test_dataproc_system.py | 8 +-
.../google/cloud/operators/test_datastore.py | 86 +-
.../cloud/operators/test_datastore_system.py | 1 -
tests/providers/google/cloud/operators/test_dlp.py | 231 ++---
.../google/cloud/operators/test_functions.py | 597 ++++++------
tests/providers/google/cloud/operators/test_gcs.py | 48 +-
.../cloud/operators/test_gcs_system_helper.py | 6 +-
.../cloud/operators/test_kubernetes_engine.py | 175 ++--
.../google/cloud/operators/test_life_sciences.py | 29 +-
.../cloud/operators/test_life_sciences_system.py | 7 +-
.../google/cloud/operators/test_mlengine.py | 292 +++---
.../google/cloud/operators/test_mlengine_system.py | 7 +-
.../google/cloud/operators/test_mlengine_utils.py | 55 +-
.../cloud/operators/test_natural_language.py | 11 +-
.../google/cloud/operators/test_pubsub.py | 84 +-
.../google/cloud/operators/test_spanner.py | 311 +++---
.../google/cloud/operators/test_spanner_system.py | 20 +-
.../google/cloud/operators/test_speech_to_text.py | 3 +-
.../cloud/operators/test_speech_to_text_system.py | 1 -
.../google/cloud/operators/test_stackdriver.py | 144 +--
.../providers/google/cloud/operators/test_tasks.py | 126 +--
.../google/cloud/operators/test_text_to_speech.py | 6 +-
.../cloud/operators/test_text_to_speech_system.py | 1 -
.../google/cloud/operators/test_translate.py | 3 +-
.../cloud/operators/test_translate_speech.py | 32 +-
.../operators/test_translate_speech_system.py | 1 -
.../cloud/operators/test_video_intelligence.py | 18 +-
.../operators/test_video_intelligence_system.py | 8 +-
.../google/cloud/operators/test_vision.py | 113 +--
.../google/cloud/secrets/test_secret_manager.py | 73 +-
.../google/cloud/sensors/test_bigquery.py | 9 +-
.../google/cloud/sensors/test_bigtable.py | 34 +-
tests/providers/google/cloud/sensors/test_gcs.py | 36 +-
.../providers/google/cloud/sensors/test_pubsub.py | 40 +-
.../google/cloud/transfers/test_adls_to_gcs.py | 49 +-
.../cloud/transfers/test_bigquery_to_bigquery.py | 28 +-
.../google/cloud/transfers/test_bigquery_to_gcs.py | 27 +-
.../cloud/transfers/test_bigquery_to_gcs_system.py | 1 -
.../cloud/transfers/test_bigquery_to_mysql.py | 9 +-
.../cloud/transfers/test_cassandra_to_gcs.py | 22 +-
.../cloud/transfers/test_facebook_ads_to_gcs.py | 46 +-
.../transfers/test_facebook_ads_to_gcs_system.py | 20 +-
.../google/cloud/transfers/test_gcs_to_bigquery.py | 41 +-
.../cloud/transfers/test_gcs_to_bigquery_system.py | 1 -
.../google/cloud/transfers/test_gcs_to_gcs.py | 303 +++---
.../cloud/transfers/test_gcs_to_gcs_system.py | 14 +-
.../google/cloud/transfers/test_gcs_to_local.py | 5 +-
.../google/cloud/transfers/test_gcs_to_sftp.py | 28 +-
.../cloud/transfers/test_gcs_to_sftp_system.py | 15 +-
.../google/cloud/transfers/test_local_to_gcs.py | 52 +-
.../cloud/transfers/test_local_to_gcs_system.py | 1 -
.../google/cloud/transfers/test_mssql_to_gcs.py | 33 +-
.../google/cloud/transfers/test_mysql_to_gcs.py | 125 ++-
.../google/cloud/transfers/test_postgres_to_gcs.py | 40 +-
.../google/cloud/transfers/test_presto_to_gcs.py | 8 +-
.../cloud/transfers/test_presto_to_gcs_system.py | 15 +-
.../google/cloud/transfers/test_s3_to_gcs.py | 22 +-
.../google/cloud/transfers/test_sftp_to_gcs.py | 17 +-
.../cloud/transfers/test_sftp_to_gcs_system.py | 8 +-
.../google/cloud/transfers/test_sheets_to_gcs.py | 26 +-
.../cloud/transfers/test_sheets_to_gcs_system.py | 1 -
.../google/cloud/transfers/test_sql_to_gcs.py | 102 +-
.../providers/google/cloud/utils/base_gcp_mock.py | 22 +-
.../google/cloud/utils/gcp_authenticator.py | 1 +
.../cloud/utils/test_credentials_provider.py | 184 ++--
.../google/cloud/utils/test_field_sanitizer.py | 105 +-
.../google/cloud/utils/test_field_validator.py | 64 +-
.../cloud/utils/test_mlengine_operator_utils.py | 161 ++--
.../utils/test_mlengine_prediction_summary.py | 41 +-
.../google/common/hooks/test_base_google.py | 214 ++---
.../google/common/hooks/test_discovery_api.py | 101 +-
.../common/utils/test_id_token_credentials.py | 3 +-
.../google/firebase/hooks/test_firestore.py | 57 +-
.../firebase/operators/test_firestore_system.py | 3 +-
.../marketing_platform/hooks/test_analytics.py | 97 +-
.../hooks/test_campaign_manager.py | 45 +-
.../marketing_platform/hooks/test_display_video.py | 123 +--
.../marketing_platform/hooks/test_search_ads.py | 33 +-
.../marketing_platform/operators/test_analytics.py | 84 +-
.../operators/test_campaign_manager.py | 139 +--
.../operators/test_display_video.py | 186 +---
.../operators/test_display_video_system.py | 3 +-
.../operators/test_search_ads.py | 68 +-
.../operators/test_search_ads_system.py | 3 +-
.../sensors/test_campaign_manager.py | 13 +-
.../sensors/test_display_video.py | 29 +-
.../marketing_platform/sensors/test_search_ads.py | 19 +-
tests/providers/google/suite/hooks/test_drive.py | 4 +-
tests/providers/google/suite/hooks/test_sheets.py | 67 +-
.../google/suite/operators/test_sheets.py | 4 +-
.../google/suite/transfers/test_gcs_to_gdrive.py | 6 +-
.../google/suite/transfers/test_gcs_to_sheets.py | 12 +-
.../suite/transfers/test_gcs_to_sheets_system.py | 1 -
tests/providers/grpc/hooks/test_grpc.py | 103 +-
tests/providers/grpc/operators/test_grpc.py | 18 +-
.../_internal_client/test_vault_client.py | 552 ++++++-----
tests/providers/hashicorp/hooks/test_vault.py | 369 +++----
tests/providers/hashicorp/secrets/test_vault.py | 96 +-
tests/providers/http/hooks/test_http.py | 137 +--
tests/providers/http/operators/test_http.py | 15 +-
tests/providers/http/operators/test_http_system.py | 5 +-
tests/providers/http/sensors/test_http.py | 43 +-
tests/providers/imap/hooks/test_imap.py | 86 +-
.../providers/imap/sensors/test_imap_attachment.py | 5 +-
tests/providers/jdbc/hooks/test_jdbc.py | 23 +-
tests/providers/jdbc/operators/test_jdbc.py | 10 +-
tests/providers/jenkins/hooks/test_jenkins.py | 25 +-
.../jenkins/operators/test_jenkins_job_trigger.py | 122 +--
tests/providers/jira/hooks/test_jira.py | 17 +-
tests/providers/jira/operators/test_jira.py | 65 +-
tests/providers/jira/sensors/test_jira.py | 36 +-
tests/providers/microsoft/azure/hooks/test_adx.py | 171 ++--
.../microsoft/azure/hooks/test_azure_batch.py | 93 +-
.../azure/hooks/test_azure_container_instance.py | 39 +-
.../azure/hooks/test_azure_container_registry.py | 1 -
.../azure/hooks/test_azure_container_volume.py | 17 +-
.../microsoft/azure/hooks/test_azure_cosmos.py | 54 +-
.../microsoft/azure/hooks/test_azure_data_lake.py | 91 +-
.../microsoft/azure/hooks/test_azure_fileshare.py | 59 +-
.../microsoft/azure/hooks/test_base_azure.py | 45 +-
tests/providers/microsoft/azure/hooks/test_wasb.py | 96 +-
.../microsoft/azure/log/test_wasb_task_handler.py | 50 +-
.../microsoft/azure/operators/test_adls_list.py | 17 +-
.../microsoft/azure/operators/test_adx.py | 46 +-
.../microsoft/azure/operators/test_azure_batch.py | 52 +-
.../operators/test_azure_container_instances.py | 115 +--
.../microsoft/azure/operators/test_azure_cosmos.py | 17 +-
.../azure/operators/test_wasb_delete_blob.py | 33 +-
.../providers/microsoft/azure/sensors/test_wasb.py | 56 +-
.../microsoft/azure/transfers/test_file_to_wasb.py | 31 +-
.../transfers/test_oracle_to_azure_data_lake.py | 19 +-
.../microsoft/mssql/operators/test_mssql.py | 4 +-
.../providers/microsoft/winrm/hooks/test_winrm.py | 26 +-
.../microsoft/winrm/operators/test_winrm.py | 10 +-
tests/providers/mongo/hooks/test_mongo.py | 34 +-
tests/providers/mongo/sensors/test_mongo.py | 15 +-
tests/providers/mysql/hooks/test_mysql.py | 85 +-
tests/providers/mysql/operators/test_mysql.py | 37 +-
.../mysql/transfers/test_presto_to_mysql.py | 23 +-
.../providers/mysql/transfers/test_s3_to_mysql.py | 33 +-
.../mysql/transfers/test_vertica_to_mysql.py | 69 +-
tests/providers/odbc/hooks/test_odbc.py | 18 +-
tests/providers/openfaas/hooks/test_openfaas.py | 28 +-
.../opsgenie/hooks/test_opsgenie_alert.py | 39 +-
.../opsgenie/operators/test_opsgenie_alert.py | 25 +-
tests/providers/oracle/hooks/test_oracle.py | 79 +-
tests/providers/oracle/operators/test_oracle.py | 9 +-
.../oracle/transfers/test_oracle_to_oracle.py | 12 +-
tests/providers/pagerduty/hooks/test_pagerduty.py | 48 +-
.../papermill/operators/test_papermill.py | 15 +-
tests/providers/postgres/hooks/test_postgres.py | 72 +-
.../providers/postgres/operators/test_postgres.py | 25 +-
tests/providers/presto/hooks/test_presto.py | 26 +-
tests/providers/qubole/operators/test_qubole.py | 54 +-
.../qubole/operators/test_qubole_check.py | 21 +-
tests/providers/qubole/sensors/test_qubole.py | 14 +-
tests/providers/redis/hooks/test_redis.py | 20 +-
.../redis/operators/test_redis_publish.py | 8 +-
tests/providers/redis/sensors/test_redis_key.py | 11 +-
.../providers/redis/sensors/test_redis_pub_sub.py | 49 +-
.../providers/salesforce/hooks/test_salesforce.py | 7 +-
tests/providers/salesforce/hooks/test_tableau.py | 20 +-
.../operators/test_tableau_refresh_workbook.py | 9 +-
.../salesforce/sensors/test_tableau_job_status.py | 11 +-
tests/providers/samba/hooks/test_samba.py | 25 +-
tests/providers/segment/hooks/test_segment.py | 2 -
.../segment/operators/test_segment_track_event.py | 12 +-
tests/providers/sendgrid/utils/test_emailer.py | 73 +-
tests/providers/sftp/hooks/test_sftp.py | 127 +--
tests/providers/sftp/operators/test_sftp.py | 120 +--
tests/providers/sftp/sensors/test_sftp.py | 43 +-
.../singularity/operators/test_singularity.py | 101 +-
tests/providers/slack/hooks/test_slack.py | 4 +-
tests/providers/slack/hooks/test_slack_webhook.py | 33 +-
tests/providers/slack/operators/test_slack.py | 43 +-
.../slack/operators/test_slack_webhook.py | 30 +-
tests/providers/snowflake/hooks/test_snowflake.py | 89 +-
.../snowflake/operators/test_snowflake.py | 9 +-
.../snowflake/operators/test_snowflake_system.py | 15 +-
.../snowflake/transfers/test_s3_to_snowflake.py | 21 +-
.../snowflake/transfers/test_snowflake_to_slack.py | 20 +-
tests/providers/sqlite/hooks/test_sqlite.py | 2 -
tests/providers/sqlite/operators/test_sqlite.py | 8 +-
tests/providers/ssh/hooks/test_ssh.py | 124 +--
tests/providers/ssh/operators/test_ssh.py | 76 +-
tests/providers/vertica/hooks/test_vertica.py | 18 +-
tests/providers/vertica/operators/test_vertica.py | 8 +-
tests/providers/yandex/hooks/test_yandex.py | 30 +-
.../yandex/hooks/test_yandexcloud_dataproc.py | 36 +-
.../yandex/operators/test_yandexcloud_dataproc.py | 138 ++-
tests/providers/zendesk/hooks/test_zendesk.py | 40 +-
873 files changed, 26384 insertions(+), 29360 deletions(-)
diff --git a/.flake8 b/.flake8
index cffaf32..14de564 100644
--- a/.flake8
+++ b/.flake8
@@ -1,6 +1,6 @@
[flake8]
max-line-length = 110
-ignore = E231,E731,W504,I001,W503
+ignore = E203,E231,E731,W504,I001,W503
exclude = .svn,CVS,.bzr,.hg,.git,__pycache__,.eggs,*.egg,node_modules
format = ${cyan}%(path)s${reset}:${yellow_bold}%(row)d${reset}:${green_bold}%(col)d${reset}: ${red_bold}%(code)s${reset} %(text)s
per-file-ignores =
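The only change above is adding E203 to the ignore list. E203 ("whitespace before ':'") conflicts with how Black formats slices whose bounds are compound expressions, so projects adopting Black typically silence it in flake8. A hedged illustration, not part of the commit:

# Hypothetical example: Black puts spaces around the slice colon when the bounds
# are compound expressions, which flake8 would otherwise report as E203.
numbers = list(range(20))
offset = 3
window = numbers[offset + 1 : offset + 5]  # Black-formatted slice; needs E203 ignored
print(window)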
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3197200..461f36d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -153,7 +153,8 @@ repos:
rev: stable
hooks:
- id: black
- files: api_connexion/.*\.py
+ files: api_connexion/.*\.py|.*providers.*\.py
+ exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$
args: [--config=./pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
@@ -190,7 +191,7 @@ repos:
name: Run isort to sort imports
types: [python]
# To keep consistent with the global isort skip config defined in setup.cfg
- exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py
+ exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py|.*providers.*\.py
- repo: https://github.com/pycqa/pydocstyle
rev: 5.0.2
hooks:
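Taken together, the two hunks above hand the provider sources over to Black: the Black hook's files pattern now also matches any path containing "providers", two files are explicitly excluded, and, mirroring the in-file comment about the global isort skip config, the same provider paths are added to isort's exclude list. A small sketch of how those regexes select paths (assuming pre-commit's usual re.search matching for files/exclude; the paths are illustrative examples):

import re

FILES = re.compile(r"api_connexion/.*\.py|.*providers.*\.py")
EXCLUDE = re.compile(r".*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$")

for path in [
    "airflow/providers/amazon/aws/hooks/s3.py",               # newly covered by Black
    "airflow/api_connexion/endpoints/dag_endpoint.py",        # already covered before this commit
    "airflow/providers/google/common/hooks/base_google.py",   # matched, but explicitly excluded
    "airflow/models/dag.py",                                   # untouched by this hook
]:
    selected = bool(FILES.search(path)) and not EXCLUDE.search(path)
    print(f"{path}: {'run black' if selected else 'skip'}")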
diff --git a/airflow/providers/amazon/aws/example_dags/example_datasync_1.py b/airflow/providers/amazon/aws/example_dags/example_datasync_1.py
index 5e1127a..8b3e278 100644
--- a/airflow/providers/amazon/aws/example_dags/example_datasync_1.py
+++ b/airflow/providers/amazon/aws/example_dags/example_datasync_1.py
@@ -33,16 +33,13 @@ from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator
from airflow.utils.dates import days_ago
# [START howto_operator_datasync_1_args_1]
-TASK_ARN = getenv(
- "TASK_ARN", "my_aws_datasync_task_arn")
+TASK_ARN = getenv("TASK_ARN", "my_aws_datasync_task_arn")
# [END howto_operator_datasync_1_args_1]
# [START howto_operator_datasync_1_args_2]
-SOURCE_LOCATION_URI = getenv(
- "SOURCE_LOCATION_URI", "smb://hostname/directory/")
+SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/")
-DESTINATION_LOCATION_URI = getenv(
- "DESTINATION_LOCATION_URI", "s3://mybucket/prefix")
+DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix")
# [END howto_operator_datasync_1_args_2]
@@ -55,16 +52,12 @@ with models.DAG(
# [START howto_operator_datasync_1_1]
datasync_task_1 = AWSDataSyncOperator(
- aws_conn_id="aws_default",
- task_id="datasync_task_1",
- task_arn=TASK_ARN
+ aws_conn_id="aws_default", task_id="datasync_task_1", task_arn=TASK_ARN
)
# [END howto_operator_datasync_1_1]
with models.DAG(
- "example_datasync_1_2",
- start_date=days_ago(1),
- schedule_interval=None, # Override to match your needs
+ "example_datasync_1_2", start_date=days_ago(1), schedule_interval=None, # Override to match your needs
) as dag:
# [START howto_operator_datasync_1_2]
datasync_task_2 = AWSDataSyncOperator(
diff --git a/airflow/providers/amazon/aws/example_dags/example_datasync_2.py b/airflow/providers/amazon/aws/example_dags/example_datasync_2.py
index c6b8e0e..d4c7091 100644
--- a/airflow/providers/amazon/aws/example_dags/example_datasync_2.py
+++ b/airflow/providers/amazon/aws/example_dags/example_datasync_2.py
@@ -42,40 +42,30 @@ from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator
from airflow.utils.dates import days_ago
# [START howto_operator_datasync_2_args]
-SOURCE_LOCATION_URI = getenv(
- "SOURCE_LOCATION_URI", "smb://hostname/directory/")
+SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/")
-DESTINATION_LOCATION_URI = getenv(
- "DESTINATION_LOCATION_URI", "s3://mybucket/prefix")
+DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix")
default_create_task_kwargs = '{"Name": "Created by Airflow"}'
-CREATE_TASK_KWARGS = json.loads(
- getenv("CREATE_TASK_KWARGS", default_create_task_kwargs)
-)
+CREATE_TASK_KWARGS = json.loads(getenv("CREATE_TASK_KWARGS", default_create_task_kwargs))
default_create_source_location_kwargs = "{}"
CREATE_SOURCE_LOCATION_KWARGS = json.loads(
- getenv("CREATE_SOURCE_LOCATION_KWARGS",
- default_create_source_location_kwargs)
+ getenv("CREATE_SOURCE_LOCATION_KWARGS", default_create_source_location_kwargs)
)
-bucket_access_role_arn = (
- "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"
-)
+bucket_access_role_arn = "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"
default_destination_location_kwargs = """\
{"S3BucketArn": "arn:aws:s3:::mybucket",
"S3Config": {"BucketAccessRoleArn":
"arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"}
}"""
CREATE_DESTINATION_LOCATION_KWARGS = json.loads(
- getenv("CREATE_DESTINATION_LOCATION_KWARGS",
- re.sub(r"[\s+]", '', default_destination_location_kwargs))
+ getenv("CREATE_DESTINATION_LOCATION_KWARGS", re.sub(r"[\s+]", '', default_destination_location_kwargs))
)
default_update_task_kwargs = '{"Name": "Updated by Airflow"}'
-UPDATE_TASK_KWARGS = json.loads(
- getenv("UPDATE_TASK_KWARGS", default_update_task_kwargs)
-)
+UPDATE_TASK_KWARGS = json.loads(getenv("UPDATE_TASK_KWARGS", default_update_task_kwargs))
# [END howto_operator_datasync_2_args]
@@ -92,13 +82,10 @@ with models.DAG(
task_id="datasync_task",
source_location_uri=SOURCE_LOCATION_URI,
destination_location_uri=DESTINATION_LOCATION_URI,
-
create_task_kwargs=CREATE_TASK_KWARGS,
create_source_location_kwargs=CREATE_SOURCE_LOCATION_KWARGS,
create_destination_location_kwargs=CREATE_DESTINATION_LOCATION_KWARGS,
-
update_task_kwargs=UPDATE_TASK_KWARGS,
-
- delete_task_after_execution=True
+ delete_task_after_execution=True,
)
# [END howto_operator_datasync_2]
diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
index 94cecba..cef3560 100644
--- a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
+++ b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
@@ -56,12 +56,7 @@ hello_world = ECSOperator(
task_definition="hello-world",
launch_type="FARGATE",
overrides={
- "containerOverrides": [
- {
- "name": "hello-world-container",
- "command": ["echo", "hello", "world"],
- },
- ],
+ "containerOverrides": [{"name": "hello-world-container", "command": ["echo", "hello", "world"],},],
},
network_configuration={
"awsvpcConfiguration": {
diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py
index 3c52ffc..3077944 100644
--- a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py
+++ b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py
@@ -30,7 +30,7 @@ DEFAULT_ARGS = {
'depends_on_past': False,
'email': ['airflow@example.com'],
'email_on_failure': False,
- 'email_on_retry': False
+ 'email_on_retry': False,
}
# [START howto_operator_emr_automatic_steps_config]
@@ -40,12 +40,8 @@ SPARK_STEPS = [
'ActionOnFailure': 'CONTINUE',
'HadoopJarStep': {
'Jar': 'command-runner.jar',
- 'Args': [
- '/usr/lib/spark/bin/run-example',
- 'SparkPi',
- '10'
- ]
- }
+ 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'],
+ },
}
]
@@ -85,13 +81,13 @@ with DAG(
task_id='create_job_flow',
job_flow_overrides=JOB_FLOW_OVERRIDES,
aws_conn_id='aws_default',
- emr_conn_id='emr_default'
+ emr_conn_id='emr_default',
)
job_sensor = EmrJobFlowSensor(
task_id='check_job_flow',
job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
- aws_conn_id='aws_default'
+ aws_conn_id='aws_default',
)
job_flow_creator >> job_sensor
diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py
index 0b73bd3..1eb857a 100644
--- a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py
+++ b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py
@@ -35,7 +35,7 @@ DEFAULT_ARGS = {
'depends_on_past': False,
'email': ['airflow@example.com'],
'email_on_failure': False,
- 'email_on_retry': False
+ 'email_on_retry': False,
}
SPARK_STEPS = [
@@ -44,12 +44,8 @@ SPARK_STEPS = [
'ActionOnFailure': 'CONTINUE',
'HadoopJarStep': {
'Jar': 'command-runner.jar',
- 'Args': [
- '/usr/lib/spark/bin/run-example',
- 'SparkPi',
- '10'
- ]
- }
+ 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'],
+ },
}
]
@@ -87,27 +83,27 @@ with DAG(
task_id='create_job_flow',
job_flow_overrides=JOB_FLOW_OVERRIDES,
aws_conn_id='aws_default',
- emr_conn_id='emr_default'
+ emr_conn_id='emr_default',
)
step_adder = EmrAddStepsOperator(
task_id='add_steps',
job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
aws_conn_id='aws_default',
- steps=SPARK_STEPS
+ steps=SPARK_STEPS,
)
step_checker = EmrStepSensor(
task_id='watch_step',
job_flow_id="{{ task_instance.xcom_pull('create_job_flow', key='return_value') }}",
step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}",
- aws_conn_id='aws_default'
+ aws_conn_id='aws_default',
)
cluster_remover = EmrTerminateJobFlowOperator(
task_id='remove_cluster',
job_flow_id="{{ task_instance.xcom_pull(task_ids='create_job_flow', key='return_value') }}",
- aws_conn_id='aws_default'
+ aws_conn_id='aws_default',
)
cluster_creator >> step_adder >> step_checker >> cluster_remover
diff --git a/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py b/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py
index f05c5ae..fc14199 100644
--- a/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py
+++ b/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_advanced.py
@@ -74,7 +74,7 @@ with DAG(
dag_id="example_google_api_to_s3_transfer_advanced",
schedule_interval=None,
start_date=days_ago(1),
- tags=['example']
+ tags=['example'],
) as dag:
# [START howto_operator_google_api_to_s3_transfer_advanced_task_1]
task_video_ids_to_s3 = GoogleApiToS3Operator(
@@ -89,21 +89,18 @@ with DAG(
'publishedAfter': YOUTUBE_VIDEO_PUBLISHED_AFTER,
'publishedBefore': YOUTUBE_VIDEO_PUBLISHED_BEFORE,
'type': 'video',
- 'fields': 'items/id/videoId'
+ 'fields': 'items/id/videoId',
},
google_api_response_via_xcom='video_ids_response',
s3_destination_key=f'{s3_directory}/youtube_search_{s3_file_name}.json',
- task_id='video_ids_to_s3'
+ task_id='video_ids_to_s3',
)
# [END howto_operator_google_api_to_s3_transfer_advanced_task_1]
# [START howto_operator_google_api_to_s3_transfer_advanced_task_1_1]
task_check_and_transform_video_ids = BranchPythonOperator(
python_callable=_check_and_transform_video_ids,
- op_args=[
- task_video_ids_to_s3.google_api_response_via_xcom,
- task_video_ids_to_s3.task_id
- ],
- task_id='check_and_transform_video_ids'
+ op_args=[task_video_ids_to_s3.google_api_response_via_xcom, task_video_ids_to_s3.task_id],
+ task_id='check_and_transform_video_ids',
)
# [END howto_operator_google_api_to_s3_transfer_advanced_task_1_1]
# [START howto_operator_google_api_to_s3_transfer_advanced_task_2]
@@ -115,16 +112,14 @@ with DAG(
google_api_endpoint_params={
'part': YOUTUBE_VIDEO_PARTS,
'maxResults': 50,
- 'fields': YOUTUBE_VIDEO_FIELDS
+ 'fields': YOUTUBE_VIDEO_FIELDS,
},
google_api_endpoint_params_via_xcom='video_ids',
s3_destination_key=f'{s3_directory}/youtube_videos_{s3_file_name}.json',
- task_id='video_data_to_s3'
+ task_id='video_data_to_s3',
)
# [END howto_operator_google_api_to_s3_transfer_advanced_task_2]
# [START howto_operator_google_api_to_s3_transfer_advanced_task_2_1]
- task_no_video_ids = DummyOperator(
- task_id='no_video_ids'
- )
+ task_no_video_ids = DummyOperator(task_id='no_video_ids')
# [END howto_operator_google_api_to_s3_transfer_advanced_task_2_1]
task_video_ids_to_s3 >> task_check_and_transform_video_ids >> [task_video_data_to_s3, task_no_video_ids]
diff --git a/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py b/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py
index 515a966..f5c1ec1 100644
--- a/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py
+++ b/airflow/providers/amazon/aws/example_dags/example_google_api_to_s3_transfer_basic.py
@@ -37,19 +37,16 @@ with DAG(
dag_id="example_google_api_to_s3_transfer_basic",
schedule_interval=None,
start_date=days_ago(1),
- tags=['example']
+ tags=['example'],
) as dag:
# [START howto_operator_google_api_to_s3_transfer_basic_task_1]
task_google_sheets_values_to_s3 = GoogleApiToS3Operator(
google_api_service_name='sheets',
google_api_service_version='v4',
google_api_endpoint_path='sheets.spreadsheets.values.get',
- google_api_endpoint_params={
- 'spreadsheetId': GOOGLE_SHEET_ID,
- 'range': GOOGLE_SHEET_RANGE
- },
+ google_api_endpoint_params={'spreadsheetId': GOOGLE_SHEET_ID, 'range': GOOGLE_SHEET_RANGE},
s3_destination_key=S3_DESTINATION_KEY,
task_id='google_sheets_values_to_s3',
- dag=dag
+ dag=dag,
)
# [END howto_operator_google_api_to_s3_transfer_basic_task_1]
diff --git a/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py
index 0c308ba..cfb8b95 100644
--- a/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py
+++ b/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py
@@ -34,10 +34,7 @@ S3_DESTINATION_KEY = getenv("S3_DESTINATION_KEY", "s3://bucket/key.json")
# [END howto_operator_imap_attachment_to_s3_env_variables]
with DAG(
- dag_id="example_imap_attachment_to_s3",
- start_date=days_ago(1),
- schedule_interval=None,
- tags=['example']
+ dag_id="example_imap_attachment_to_s3", start_date=days_ago(1), schedule_interval=None, tags=['example']
) as dag:
# [START howto_operator_imap_attachment_to_s3_task_1]
task_transfer_imap_attachment_to_s3 = ImapAttachmentToS3Operator(
@@ -46,6 +43,6 @@ with DAG(
imap_mail_folder=IMAP_MAIL_FOLDER,
imap_mail_filter=IMAP_MAIL_FILTER,
task_id='transfer_imap_attachment_to_s3',
- dag=dag
+ dag=dag,
)
# [END howto_operator_imap_attachment_to_s3_task_1]
diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_bucket.py b/airflow/providers/amazon/aws/example_dags/example_s3_bucket.py
index 0321cfa..591ba0e 100644
--- a/airflow/providers/amazon/aws/example_dags/example_s3_bucket.py
+++ b/airflow/providers/amazon/aws/example_dags/example_s3_bucket.py
@@ -31,9 +31,7 @@ def upload_keys():
s3_hook = S3Hook()
for i in range(0, 3):
s3_hook.load_string(
- string_data="input",
- key=f"path/data{i}",
- bucket_name=BUCKET_NAME,
+ string_data="input", key=f"path/data{i}", bucket_name=BUCKET_NAME,
)
@@ -46,20 +44,15 @@ with DAG(
) as dag:
create_bucket = S3CreateBucketOperator(
- task_id='s3_bucket_dag_create',
- bucket_name=BUCKET_NAME,
- region_name='us-east-1',
+ task_id='s3_bucket_dag_create', bucket_name=BUCKET_NAME, region_name='us-east-1',
)
add_keys_to_bucket = PythonOperator(
- task_id="s3_bucket_dag_add_keys_to_bucket",
- python_callable=upload_keys
+ task_id="s3_bucket_dag_add_keys_to_bucket", python_callable=upload_keys
)
delete_bucket = S3DeleteBucketOperator(
- task_id='s3_bucket_dag_delete',
- bucket_name=BUCKET_NAME,
- force_delete=True,
+ task_id='s3_bucket_dag_delete', bucket_name=BUCKET_NAME, force_delete=True,
)
create_bucket >> add_keys_to_bucket >> delete_bucket
diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py b/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py
index 2ffccbc..76c79e5 100644
--- a/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py
@@ -47,19 +47,15 @@ def _remove_sample_data_from_s3():
with DAG(
- dag_id="example_s3_to_redshift",
- start_date=days_ago(1),
- schedule_interval=None,
- tags=['example']
+ dag_id="example_s3_to_redshift", start_date=days_ago(1), schedule_interval=None, tags=['example']
) as dag:
setup__task_add_sample_data_to_s3 = PythonOperator(
- python_callable=_add_sample_data_to_s3,
- task_id='setup__add_sample_data_to_s3'
+ python_callable=_add_sample_data_to_s3, task_id='setup__add_sample_data_to_s3'
)
setup__task_create_table = PostgresOperator(
sql=f'CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE}(Id int, Name varchar)',
postgres_conn_id='redshift_default',
- task_id='setup__create_table'
+ task_id='setup__create_table',
)
# [START howto_operator_s3_to_redshift_task_1]
task_transfer_s3_to_redshift = S3ToRedshiftOperator(
@@ -68,22 +64,18 @@ with DAG(
schema="PUBLIC",
table=REDSHIFT_TABLE,
copy_options=['csv'],
- task_id='transfer_s3_to_redshift'
+ task_id='transfer_s3_to_redshift',
)
# [END howto_operator_s3_to_redshift_task_1]
teardown__task_drop_table = PostgresOperator(
sql=f'DROP TABLE IF EXISTS {REDSHIFT_TABLE}',
postgres_conn_id='redshift_default',
- task_id='teardown__drop_table'
+ task_id='teardown__drop_table',
)
teardown__task_remove_sample_data_from_s3 = PythonOperator(
- python_callable=_remove_sample_data_from_s3,
- task_id='teardown__remove_sample_data_from_s3'
+ python_callable=_remove_sample_data_from_s3, task_id='teardown__remove_sample_data_from_s3'
)
- [
- setup__task_add_sample_data_to_s3,
- setup__task_create_table
- ] >> task_transfer_s3_to_redshift >> [
+ [setup__task_add_sample_data_to_s3, setup__task_create_table] >> task_transfer_s3_to_redshift >> [
teardown__task_drop_table,
- teardown__task_remove_sample_data_from_s3
+ teardown__task_remove_sample_data_from_s3,
]
diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index 830bb8f..a7fb947 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -41,22 +41,28 @@ class AWSAthenaHook(AwsBaseHook):
:type sleep_time: int
"""
- INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
- FAILURE_STATES = ('FAILED', 'CANCELLED',)
+ INTERMEDIATE_STATES = (
+ 'QUEUED',
+ 'RUNNING',
+ )
+ FAILURE_STATES = (
+ 'FAILED',
+ 'CANCELLED',
+ )
SUCCESS_STATES = ('SUCCEEDED',)
- def __init__(self,
- *args: Any,
- sleep_time: int = 30,
- **kwargs: Any) -> None:
+ def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None:
super().__init__(client_type='athena', *args, **kwargs) # type: ignore
self.sleep_time = sleep_time
- def run_query(self, query: str,
- query_context: Dict[str, str],
- result_configuration: Dict[str, Any],
- client_request_token: Optional[str] = None,
- workgroup: str = 'primary') -> str:
+ def run_query(
+ self,
+ query: str,
+ query_context: Dict[str, str],
+ result_configuration: Dict[str, Any],
+ client_request_token: Optional[str] = None,
+ workgroup: str = 'primary',
+ ) -> str:
"""
Run Presto query on athena with provided config and return submitted query_execution_id
@@ -76,7 +82,7 @@ class AWSAthenaHook(AwsBaseHook):
'QueryString': query,
'QueryExecutionContext': query_context,
'ResultConfiguration': result_configuration,
- 'WorkGroup': workgroup
+ 'WorkGroup': workgroup,
}
if client_request_token:
params['ClientRequestToken'] = client_request_token
@@ -122,9 +128,9 @@ class AWSAthenaHook(AwsBaseHook):
# The error is being absorbed to implement retries.
return reason # pylint: disable=lost-exception
- def get_query_results(self, query_execution_id: str,
- next_token_id: Optional[str] = None,
- max_results: int = 1000) -> Optional[dict]:
+ def get_query_results(
+ self, query_execution_id: str, next_token_id: Optional[str] = None, max_results: int = 1000
+ ) -> Optional[dict]:
"""
Fetch submitted athena query results. returns none if query is in intermediate state or
failed/cancelled state else dict of query output
@@ -144,19 +150,18 @@ class AWSAthenaHook(AwsBaseHook):
elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
return None
- result_params = {
- 'QueryExecutionId': query_execution_id,
- 'MaxResults': max_results
- }
+ result_params = {'QueryExecutionId': query_execution_id, 'MaxResults': max_results}
if next_token_id:
result_params['NextToken'] = next_token_id
return self.get_conn().get_query_results(**result_params)
- def get_query_results_paginator(self, query_execution_id: str,
- max_items: Optional[int] = None,
- page_size: Optional[int] = None,
- starting_token: Optional[str] = None
- ) -> Optional[PageIterator]:
+ def get_query_results_paginator(
+ self,
+ query_execution_id: str,
+ max_items: Optional[int] = None,
+ page_size: Optional[int] = None,
+ starting_token: Optional[str] = None,
+ ) -> Optional[PageIterator]:
"""
Fetch submitted athena query results. returns none if query is in intermediate state or
failed/cancelled state else a paginator to iterate through pages of results. If you
@@ -184,15 +189,13 @@ class AWSAthenaHook(AwsBaseHook):
'PaginationConfig': {
'MaxItems': max_items,
'PageSize': page_size,
- 'StartingToken': starting_token
-
- }
+ 'StartingToken': starting_token,
+ },
}
paginator = self.get_conn().get_paginator('get_query_results')
return paginator.paginate(**result_params)
- def poll_query_status(self, query_execution_id: str,
- max_tries: Optional[int] = None) -> Optional[str]:
+ def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] = None) -> Optional[str]:
"""
Poll the status of submitted athena query until query state reaches final state.
Returns one of the final states
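For reference, a minimal sketch of calling the AWSAthenaHook API reformatted above; the 'aws_default' connection, database name, and S3 output location are illustrative assumptions, not part of this commit:

    from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook

    hook = AWSAthenaHook(aws_conn_id='aws_default', sleep_time=30)
    query_execution_id = hook.run_query(
        query='SELECT 1',
        query_context={'Database': 'default'},
        result_configuration={'OutputLocation': 's3://my-athena-results/'},  # assumed bucket
    )
    # Poll until the query reaches a terminal state (SUCCEEDED/FAILED/CANCELLED).
    final_state = hook.poll_query_status(query_execution_id, max_tries=10)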
diff --git a/airflow/providers/amazon/aws/hooks/aws_dynamodb.py b/airflow/providers/amazon/aws/hooks/aws_dynamodb.py
index fd5dde3..f197aa7 100644
--- a/airflow/providers/amazon/aws/hooks/aws_dynamodb.py
+++ b/airflow/providers/amazon/aws/hooks/aws_dynamodb.py
@@ -58,7 +58,5 @@ class AwsDynamoDBHook(AwsBaseHook):
return True
except Exception as general_error:
raise AirflowException(
- 'Failed to insert items in dynamodb, error: {error}'.format(
- error=str(general_error)
- )
+ 'Failed to insert items in dynamodb, error: {error}'.format(error=str(general_error))
)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index 350ff7e..c2348d3 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -189,6 +189,7 @@ class _SessionFactory(LoggingMixin):
def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]):
import requests
+
# requests_gssapi will need paramiko > 2.6 since you'll need
# 'gssapi' not 'python-gssapi' from PyPi.
# https://github.com/paramiko/paramiko/pull/1311
@@ -269,7 +270,7 @@ class AwsBaseHook(BaseHook):
region_name: Optional[str] = None,
client_type: Optional[str] = None,
resource_type: Optional[str] = None,
- config: Optional[Config] = None
+ config: Optional[Config] = None,
) -> None:
super().__init__()
self.aws_conn_id = aws_conn_id
@@ -280,9 +281,7 @@ class AwsBaseHook(BaseHook):
self.config = config
if not (self.client_type or self.resource_type):
- raise AirflowException(
- 'Either client_type or resource_type'
- ' must be provided.')
+ raise AirflowException('Either client_type or resource_type' ' must be provided.')
def _get_credentials(self, region_name):
@@ -302,7 +301,7 @@ class AwsBaseHook(BaseHook):
if "config_kwargs" in extra_config:
self.log.info(
"Retrieving config_kwargs from Connection.extra_config['config_kwargs']: %s",
- extra_config["config_kwargs"]
+ extra_config["config_kwargs"],
)
self.config = Config(**extra_config["config_kwargs"])
@@ -318,8 +317,7 @@ class AwsBaseHook(BaseHook):
# http://boto3.readthedocs.io/en/latest/guide/configuration.html
self.log.info(
- "Creating session using boto3 credential strategy region_name=%s",
- region_name,
+ "Creating session using boto3 credential strategy region_name=%s", region_name,
)
session = boto3.session.Session(region_name=region_name)
return session, None
@@ -333,9 +331,7 @@ class AwsBaseHook(BaseHook):
if config is None:
config = self.config
- return session.client(
- client_type, endpoint_url=endpoint_url, config=config, verify=self.verify
- )
+ return session.client(client_type, endpoint_url=endpoint_url, config=config, verify=self.verify)
def get_resource_type(self, resource_type, region_name=None, config=None):
"""Get the underlying boto3 resource using boto3 session"""
@@ -346,9 +342,7 @@ class AwsBaseHook(BaseHook):
if config is None:
config = self.config
- return session.resource(
- resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify
- )
+ return session.resource(resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify)
@cached_property
def conn(self):
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index 37a8ce0..d53c6f5 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -199,11 +199,7 @@ class AwsBatchClientHook(AwsBaseHook):
DEFAULT_DELAY_MAX = 10
def __init__(
- self,
- *args,
- max_retries: Optional[int] = None,
- status_retries: Optional[int] = None,
- **kwargs
+ self, *args, max_retries: Optional[int] = None, status_retries: Optional[int] = None, **kwargs
):
# https://github.com/python/mypy/issues/6799 hence type: ignore
super().__init__(client_type='batch', *args, **kwargs) # type: ignore
@@ -211,7 +207,7 @@ class AwsBatchClientHook(AwsBaseHook):
self.status_retries = status_retries or self.STATUS_RETRIES
@property
- def client(self) -> Union[AwsBatchProtocol, botocore.client.BaseClient]: # noqa: D402
+ def client(self) -> Union[AwsBatchProtocol, botocore.client.BaseClient]: # noqa: D402
"""
An AWS API client for batch services, like ``boto3.client('batch')``
@@ -353,9 +349,7 @@ class AwsBatchClientHook(AwsBaseHook):
return True
if retries >= self.max_retries:
- raise AirflowException(
- "AWS Batch job ({}) status checks exceed max_retries".format(job_id)
- )
+ raise AirflowException("AWS Batch job ({}) status checks exceed max_retries".format(job_id))
retries += 1
pause = self.exponential_delay(retries)
@@ -391,9 +385,7 @@ class AwsBatchClientHook(AwsBaseHook):
if error.get("Code") == "TooManyRequestsException":
pass # allow it to retry, if possible
else:
- raise AirflowException(
- "AWS Batch job ({}) description error: {}".format(job_id, err)
- )
+ raise AirflowException("AWS Batch job ({}) description error: {}".format(job_id, err))
retries += 1
if retries >= self.status_retries:
diff --git a/airflow/providers/amazon/aws/hooks/batch_waiters.py b/airflow/providers/amazon/aws/hooks/batch_waiters.py
index 75bfb58..d4e91d9 100644
--- a/airflow/providers/amazon/aws/hooks/batch_waiters.py
+++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -102,12 +102,7 @@ class AwsBatchWaitersHook(AwsBatchClientHook):
:type region_name: Optional[str]
"""
- def __init__(
- self,
- *args,
- waiter_config: Optional[Dict] = None,
- **kwargs
- ):
+ def __init__(self, *args, waiter_config: Optional[Dict] = None, **kwargs):
super().__init__(*args, **kwargs)
@@ -183,9 +178,7 @@ class AwsBatchWaitersHook(AwsBatchClientHook):
:return: a waiter object for the named AWS batch service
:rtype: botocore.waiter.Waiter
"""
- return botocore.waiter.create_waiter_with_client(
- waiter_name, self.waiter_model, self.client
- )
+ return botocore.waiter.create_waiter_with_client(waiter_name, self.waiter_model, self.client)
def list_waiters(self) -> List[str]:
"""
diff --git a/airflow/providers/amazon/aws/hooks/datasync.py b/airflow/providers/amazon/aws/hooks/datasync.py
index 153a75f..b6ef08e 100644
--- a/airflow/providers/amazon/aws/hooks/datasync.py
+++ b/airflow/providers/amazon/aws/hooks/datasync.py
@@ -58,8 +58,7 @@ class AWSDataSyncHook(AwsBaseHook):
self.tasks = []
# wait_interval_seconds = 0 is used during unit tests
if wait_interval_seconds < 0 or wait_interval_seconds > 15 * 60:
- raise ValueError("Invalid wait_interval_seconds %s" %
- wait_interval_seconds)
+ raise ValueError("Invalid wait_interval_seconds %s" % wait_interval_seconds)
self.wait_interval_seconds = wait_interval_seconds
def create_location(self, location_uri, **create_location_kwargs):
@@ -85,9 +84,7 @@ class AWSDataSyncHook(AwsBaseHook):
self._refresh_locations()
return location["LocationArn"]
- def get_location_arns(
- self, location_uri, case_sensitive=False, ignore_trailing_slash=True
- ):
+ def get_location_arns(self, location_uri, case_sensitive=False, ignore_trailing_slash=True):
"""
Return all LocationArns which match a LocationUri.
@@ -133,9 +130,7 @@ class AWSDataSyncHook(AwsBaseHook):
break
next_token = locations["NextToken"]
- def create_task(
- self, source_location_arn, destination_location_arn, **create_task_kwargs
- ):
+ def create_task(self, source_location_arn, destination_location_arn, **create_task_kwargs):
r"""Create a Task between the specified source and destination LocationArns.
:param str source_location_arn: Source LocationArn. Must exist already.
@@ -147,7 +142,7 @@ class AWSDataSyncHook(AwsBaseHook):
task = self.get_conn().create_task(
SourceLocationArn=source_location_arn,
DestinationLocationArn=destination_location_arn,
- **create_task_kwargs
+ **create_task_kwargs,
)
self._refresh_tasks()
return task["TaskArn"]
@@ -181,9 +176,7 @@ class AWSDataSyncHook(AwsBaseHook):
break
next_token = tasks["NextToken"]
- def get_task_arns_for_location_arns(
- self, source_location_arns, destination_location_arns
- ):
+ def get_task_arns_for_location_arns(self, source_location_arns, destination_location_arns):
"""
Return list of TaskArns for which use any one of the specified
source LocationArns and any one of the specified destination LocationArns.
@@ -224,9 +217,7 @@ class AWSDataSyncHook(AwsBaseHook):
"""
if not task_arn:
raise AirflowBadRequest("task_arn not specified")
- task_execution = self.get_conn().start_task_execution(
- TaskArn=task_arn, **kwargs
- )
+ task_execution = self.get_conn().start_task_execution(TaskArn=task_arn, **kwargs)
return task_execution["TaskExecutionArn"]
def cancel_task_execution(self, task_execution_arn):
@@ -298,9 +289,7 @@ class AWSDataSyncHook(AwsBaseHook):
status = None
iterations = max_iterations
while status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
- task_execution = self.get_conn().describe_task_execution(
- TaskExecutionArn=task_execution_arn
- )
+ task_execution = self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn)
status = task_execution["Status"]
self.log.info("status=%s", status)
iterations -= 1
@@ -318,5 +307,4 @@ class AWSDataSyncHook(AwsBaseHook):
return False
if iterations <= 0:
raise AirflowTaskTimeout("Max iterations exceeded!")
- raise AirflowException("Unknown status: %s" %
- status) # Should never happen
+ raise AirflowException("Unknown status: %s" % status) # Should never happen
diff --git a/airflow/providers/amazon/aws/hooks/ec2.py b/airflow/providers/amazon/aws/hooks/ec2.py
index f8120c3..34517d7 100644
--- a/airflow/providers/amazon/aws/hooks/ec2.py
+++ b/airflow/providers/amazon/aws/hooks/ec2.py
@@ -33,9 +33,7 @@ class EC2Hook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
- def __init__(self,
- *args,
- **kwargs):
+ def __init__(self, *args, **kwargs):
super().__init__(resource_type="ec2", *args, **kwargs)
def get_instance(self, instance_id: str):
@@ -60,10 +58,7 @@ class EC2Hook(AwsBaseHook):
"""
return self.get_instance(instance_id=instance_id).state["Name"]
- def wait_for_state(self,
- instance_id: str,
- target_state: str,
- check_interval: float) -> None:
+ def wait_for_state(self, instance_id: str, target_state: str, check_interval: float) -> None:
"""
Wait EC2 instance until its state is equal to the target_state.
@@ -77,12 +72,8 @@ class EC2Hook(AwsBaseHook):
:return: None
:rtype: None
"""
- instance_state = self.get_instance_state(
- instance_id=instance_id
- )
+ instance_state = self.get_instance_state(instance_id=instance_id)
while instance_state != target_state:
self.log.info("instance state: %s", instance_state)
time.sleep(check_interval)
- instance_state = self.get_instance_state(
- instance_id=instance_id
- )
+ instance_state = self.get_instance_state(instance_id=instance_id)
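A minimal usage sketch for the EC2Hook methods reformatted above; the connection id and instance id are placeholders:

    from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook

    hook = EC2Hook(aws_conn_id='aws_default')
    # Block until the (assumed) instance reaches the target state, polling every 5 seconds.
    hook.wait_for_state(instance_id='i-0123456789abcdef0', target_state='running', check_interval=5.0)
    print(hook.get_instance_state(instance_id='i-0123456789abcdef0'))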
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index 001374e..6bad910 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -47,9 +47,7 @@ class EmrHook(AwsBaseHook):
:return: id of the EMR cluster
"""
- response = self.get_conn().list_clusters(
- ClusterStates=cluster_states
- )
+ response = self.get_conn().list_clusters(ClusterStates=cluster_states)
matching_clusters = list(
filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])
diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py
index 9db925d..dde6362 100644
--- a/airflow/providers/amazon/aws/hooks/glue.py
+++ b/airflow/providers/amazon/aws/hooks/glue.py
@@ -46,19 +46,23 @@ class AwsGlueJobHook(AwsBaseHook):
:param iam_role_name: AWS IAM Role for Glue Job
:type iam_role_name: Optional[str]
"""
+
JOB_POLL_INTERVAL = 6 # polls job status after every JOB_POLL_INTERVAL seconds
- def __init__(self,
- s3_bucket: Optional[str] = None,
- job_name: Optional[str] = None,
- desc: Optional[str] = None,
- concurrent_run_limit: int = 1,
- script_location: Optional[str] = None,
- retry_limit: int = 0,
- num_of_dpus: int = 10,
- region_name: Optional[str] = None,
- iam_role_name: Optional[str] = None,
- *args, **kwargs):
+ def __init__(
+ self,
+ s3_bucket: Optional[str] = None,
+ job_name: Optional[str] = None,
+ desc: Optional[str] = None,
+ concurrent_run_limit: int = 1,
+ script_location: Optional[str] = None,
+ retry_limit: int = 0,
+ num_of_dpus: int = 10,
+ region_name: Optional[str] = None,
+ iam_role_name: Optional[str] = None,
+ *args,
+ **kwargs,
+ ):
self.job_name = job_name
self.desc = desc
self.concurrent_run_limit = concurrent_run_limit
@@ -104,10 +108,7 @@ class AwsGlueJobHook(AwsBaseHook):
try:
job_name = self.get_or_create_glue_job()
- job_run = glue_client.start_job_run(
- JobName=job_name,
- Arguments=script_arguments
- )
+ job_run = glue_client.start_job_run(JobName=job_name, Arguments=script_arguments)
return job_run
except Exception as general_error:
self.log.error("Failed to run aws glue job, error: %s", general_error)
@@ -124,11 +125,7 @@ class AwsGlueJobHook(AwsBaseHook):
:return: State of the Glue job
"""
glue_client = self.get_conn()
- job_run = glue_client.get_job_run(
- JobName=job_name,
- RunId=run_id,
- PredecessorsIncluded=True
- )
+ job_run = glue_client.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
job_run_state = job_run['JobRun']['JobRunState']
return job_run_state
@@ -157,8 +154,8 @@ class AwsGlueJobHook(AwsBaseHook):
raise AirflowException(job_error_message)
else:
self.log.info(
- "Polling for AWS Glue Job %s current run state with status %s",
- job_name, job_run_state)
+ "Polling for AWS Glue Job %s current run state with status %s", job_name, job_run_state
+ )
time.sleep(self.JOB_POLL_INTERVAL)
def get_or_create_glue_job(self) -> str:
@@ -176,8 +173,7 @@ class AwsGlueJobHook(AwsBaseHook):
self.log.info("Job doesnt exist. Now creating and running AWS Glue Job")
if self.s3_bucket is None:
raise AirflowException(
- 'Could not initialize glue job, '
- 'error: Specify Parameter `s3_bucket`'
+ 'Could not initialize glue job, ' 'error: Specify Parameter `s3_bucket`'
)
s3_log_path = f's3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}'
execution_role = self.get_iam_execution_role()
@@ -190,7 +186,7 @@ class AwsGlueJobHook(AwsBaseHook):
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
Command={"Name": "glueetl", "ScriptLocation": self.script_location},
MaxRetries=self.retry_limit,
- AllocatedCapacity=self.num_of_dpus
+ AllocatedCapacity=self.num_of_dpus,
)
return create_job_response['Name']
except Exception as general_error:
diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py
index 5a53328..27fc7c1 100644
--- a/airflow/providers/amazon/aws/hooks/glue_catalog.py
+++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -36,12 +36,7 @@ class AwsGlueCatalogHook(AwsBaseHook):
def __init__(self, *args, **kwargs):
super().__init__(client_type='glue', *args, **kwargs)
- def get_partitions(self,
- database_name,
- table_name,
- expression='',
- page_size=None,
- max_items=None):
+ def get_partitions(self, database_name, table_name, expression='', page_size=None, max_items=None):
"""
Retrieves the partition values for a table.
@@ -68,10 +63,7 @@ class AwsGlueCatalogHook(AwsBaseHook):
paginator = self.get_conn().get_paginator('get_partitions')
response = paginator.paginate(
- DatabaseName=database_name,
- TableName=table_name,
- Expression=expression,
- PaginationConfig=config
+ DatabaseName=database_name, TableName=table_name, Expression=expression, PaginationConfig=config
)
partitions = set()
diff --git a/airflow/providers/amazon/aws/hooks/kinesis.py b/airflow/providers/amazon/aws/hooks/kinesis.py
index 04a50f7..1c8480a 100644
--- a/airflow/providers/amazon/aws/hooks/kinesis.py
+++ b/airflow/providers/amazon/aws/hooks/kinesis.py
@@ -45,9 +45,6 @@ class AwsFirehoseHook(AwsBaseHook):
Write batch records to Kinesis Firehose
"""
- response = self.get_conn().put_record_batch(
- DeliveryStreamName=self.delivery_stream,
- Records=records
- )
+ response = self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records)
return response
diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py
index 2656b7e..a1d9b61 100644
--- a/airflow/providers/amazon/aws/hooks/lambda_function.py
+++ b/airflow/providers/amazon/aws/hooks/lambda_function.py
@@ -42,9 +42,15 @@ class AwsLambdaHook(AwsBaseHook):
:type invocation_type: str
"""
- def __init__(self, function_name,
- log_type='None', qualifier='$LATEST',
- invocation_type='RequestResponse', *args, **kwargs):
+ def __init__(
+ self,
+ function_name,
+ log_type='None',
+ qualifier='$LATEST',
+ invocation_type='RequestResponse',
+ *args,
+ **kwargs,
+ ):
self.function_name = function_name
self.log_type = log_type
self.invocation_type = invocation_type
@@ -61,7 +67,7 @@ class AwsLambdaHook(AwsBaseHook):
InvocationType=self.invocation_type,
LogType=self.log_type,
Payload=payload,
- Qualifier=self.qualifier
+ Qualifier=self.qualifier,
)
return response
diff --git a/airflow/providers/amazon/aws/hooks/logs.py b/airflow/providers/amazon/aws/hooks/logs.py
index f8c536a..1abb83d 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -71,11 +71,13 @@ class AwsLogsHook(AwsBaseHook):
else:
token_arg = {}
- response = self.get_conn().get_log_events(logGroupName=log_group,
- logStreamName=log_stream_name,
- startTime=start_time,
- startFromHead=start_from_head,
- **token_arg)
+ response = self.get_conn().get_log_events(
+ logGroupName=log_group,
+ logStreamName=log_stream_name,
+ startTime=start_time,
+ startFromHead=start_from_head,
+ **token_arg,
+ )
events = response['events']
event_count = len(events)
diff --git a/airflow/providers/amazon/aws/hooks/redshift.py b/airflow/providers/amazon/aws/hooks/redshift.py
index 57f59c8..065e975 100644
--- a/airflow/providers/amazon/aws/hooks/redshift.py
+++ b/airflow/providers/amazon/aws/hooks/redshift.py
@@ -47,17 +47,17 @@ class RedshiftHook(AwsBaseHook):
:type cluster_identifier: str
"""
try:
- response = self.get_conn().describe_clusters(
- ClusterIdentifier=cluster_identifier)['Clusters']
+ response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)['Clusters']
return response[0]['ClusterStatus'] if response else None
except self.get_conn().exceptions.ClusterNotFoundFault:
return 'cluster_not_found'
def delete_cluster( # pylint: disable=invalid-name
- self,
- cluster_identifier: str,
- skip_final_cluster_snapshot: bool = True,
- final_cluster_snapshot_identifier: Optional[str] = None):
+ self,
+ cluster_identifier: str,
+ skip_final_cluster_snapshot: bool = True,
+ final_cluster_snapshot_identifier: Optional[str] = None,
+ ):
"""
Delete a cluster and optionally create a snapshot
@@ -73,7 +73,7 @@ class RedshiftHook(AwsBaseHook):
response = self.get_conn().delete_cluster(
ClusterIdentifier=cluster_identifier,
SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
- FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier
+ FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier,
)
return response['Cluster'] if response['Cluster'] else None
@@ -84,9 +84,7 @@ class RedshiftHook(AwsBaseHook):
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
"""
- response = self.get_conn().describe_cluster_snapshots(
- ClusterIdentifier=cluster_identifier
- )
+ response = self.get_conn().describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
if 'Snapshots' not in response:
return None
snapshots = response['Snapshots']
@@ -94,10 +92,7 @@ class RedshiftHook(AwsBaseHook):
snapshots.sort(key=lambda x: x['SnapshotCreateTime'], reverse=True)
return snapshots
- def restore_from_cluster_snapshot(
- self,
- cluster_identifier: str,
- snapshot_identifier: str) -> str:
+ def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identifier: str) -> str:
"""
Restores a cluster from its snapshot
@@ -107,15 +102,11 @@ class RedshiftHook(AwsBaseHook):
:type snapshot_identifier: str
"""
response = self.get_conn().restore_from_cluster_snapshot(
- ClusterIdentifier=cluster_identifier,
- SnapshotIdentifier=snapshot_identifier
+ ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier
)
return response['Cluster'] if response['Cluster'] else None
- def create_cluster_snapshot(
- self,
- snapshot_identifier: str,
- cluster_identifier: str) -> str:
+ def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier: str) -> str:
"""
Creates a snapshot of a cluster
@@ -125,7 +116,6 @@ class RedshiftHook(AwsBaseHook):
:type cluster_identifier: str
"""
response = self.get_conn().create_cluster_snapshot(
- SnapshotIdentifier=snapshot_identifier,
- ClusterIdentifier=cluster_identifier,
+ SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier,
)
return response['Snapshot'] if response['Snapshot'] else None
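A minimal sketch of the RedshiftHook calls reformatted above; the cluster and snapshot identifiers are placeholders, and the 'available' status check is only illustrative:

    from airflow.providers.amazon.aws.hooks.redshift import RedshiftHook

    hook = RedshiftHook(aws_conn_id='aws_default')
    if hook.cluster_status('my-cluster') == 'available':
        hook.create_cluster_snapshot(snapshot_identifier='my-snapshot', cluster_identifier='my-cluster')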
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 976e7eb..319c5ea 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -87,8 +87,9 @@ def unify_bucket_name_and_key(func: T) -> T:
key_name = get_key_name()
if key_name and 'bucket_name' not in bound_args.arguments:
- bound_args.arguments['bucket_name'], bound_args.arguments[key_name] = \
- S3Hook.parse_s3_url(bound_args.arguments[key_name])
+ bound_args.arguments['bucket_name'], bound_args.arguments[key_name] = S3Hook.parse_s3_url(
+ bound_args.arguments[key_name]
+ )
return func(*bound_args.args, **bound_args.kwargs)
@@ -161,9 +162,7 @@ class S3Hook(AwsBaseHook):
return s3_resource.Bucket(bucket_name)
@provide_bucket_name
- def create_bucket(self,
- bucket_name: Optional[str] = None,
- region_name: Optional[str] = None) -> None:
+ def create_bucket(self, bucket_name: Optional[str] = None, region_name: Optional[str] = None) -> None:
"""
Creates an Amazon S3 bucket.
@@ -177,16 +176,12 @@ class S3Hook(AwsBaseHook):
if region_name == 'us-east-1':
self.get_conn().create_bucket(Bucket=bucket_name)
else:
- self.get_conn().create_bucket(Bucket=bucket_name,
- CreateBucketConfiguration={
- 'LocationConstraint': region_name
- })
+ self.get_conn().create_bucket(
+ Bucket=bucket_name, CreateBucketConfiguration={'LocationConstraint': region_name}
+ )
@provide_bucket_name
- def check_for_prefix(self,
- prefix: str,
- delimiter: str,
- bucket_name: Optional[str] = None) -> bool:
+ def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: Optional[str] = None) -> bool:
"""
Checks that a prefix exists in a bucket
@@ -206,12 +201,14 @@ class S3Hook(AwsBaseHook):
return False if plist is None else prefix in plist
@provide_bucket_name
- def list_prefixes(self,
- bucket_name: Optional[str] = None,
- prefix: Optional[str] = None,
- delimiter: Optional[str] = None,
- page_size: Optional[int] = None,
- max_items: Optional[int] = None) -> Optional[list]:
+ def list_prefixes(
+ self,
+ bucket_name: Optional[str] = None,
+ prefix: Optional[str] = None,
+ delimiter: Optional[str] = None,
+ page_size: Optional[int] = None,
+ max_items: Optional[int] = None,
+ ) -> Optional[list]:
"""
Lists prefixes in a bucket under prefix
@@ -236,10 +233,9 @@ class S3Hook(AwsBaseHook):
}
paginator = self.get_conn().get_paginator('list_objects_v2')
- response = paginator.paginate(Bucket=bucket_name,
- Prefix=prefix,
- Delimiter=delimiter,
- PaginationConfig=config)
+ response = paginator.paginate(
+ Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
+ )
has_results = False
prefixes = []
@@ -254,12 +250,14 @@ class S3Hook(AwsBaseHook):
return None
@provide_bucket_name
- def list_keys(self,
- bucket_name: Optional[str] = None,
- prefix: Optional[str] = None,
- delimiter: Optional[str] = None,
- page_size: Optional[int] = None,
- max_items: Optional[int] = None) -> Optional[list]:
+ def list_keys(
+ self,
+ bucket_name: Optional[str] = None,
+ prefix: Optional[str] = None,
+ delimiter: Optional[str] = None,
+ page_size: Optional[int] = None,
+ max_items: Optional[int] = None,
+ ) -> Optional[list]:
"""
Lists keys in a bucket under prefix and not containing delimiter
@@ -284,10 +282,9 @@ class S3Hook(AwsBaseHook):
}
paginator = self.get_conn().get_paginator('list_objects_v2')
- response = paginator.paginate(Bucket=bucket_name,
- Prefix=prefix,
- Delimiter=delimiter,
- PaginationConfig=config)
+ response = paginator.paginate(
+ Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
+ )
has_results = False
keys = []
@@ -359,13 +356,15 @@ class S3Hook(AwsBaseHook):
@provide_bucket_name
@unify_bucket_name_and_key
- def select_key(self,
- key: str,
- bucket_name: Optional[str] = None,
- expression: Optional[str] = None,
- expression_type: Optional[str] = None,
- input_serialization: Optional[Dict[str, Any]] = None,
- output_serialization: Optional[Dict[str, Any]] = None) -> str:
+ def select_key(
+ self,
+ key: str,
+ bucket_name: Optional[str] = None,
+ expression: Optional[str] = None,
+ expression_type: Optional[str] = None,
+ input_serialization: Optional[Dict[str, Any]] = None,
+ output_serialization: Optional[Dict[str, Any]] = None,
+ ) -> str:
"""
Reads a key with S3 Select.
@@ -402,18 +401,18 @@ class S3Hook(AwsBaseHook):
Expression=expression,
ExpressionType=expression_type,
InputSerialization=input_serialization,
- OutputSerialization=output_serialization)
+ OutputSerialization=output_serialization,
+ )
- return ''.join(event['Records']['Payload'].decode('utf-8')
- for event in response['Payload']
- if 'Records' in event)
+ return ''.join(
+ event['Records']['Payload'].decode('utf-8') for event in response['Payload'] if 'Records' in event
+ )
@provide_bucket_name
@unify_bucket_name_and_key
- def check_for_wildcard_key(self,
- wildcard_key: str,
- bucket_name: Optional[str] = None,
- delimiter: str = '') -> bool:
+ def check_for_wildcard_key(
+ self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = ''
+ ) -> bool:
"""
Checks that a key matching a wildcard expression exists in a bucket
@@ -426,16 +425,16 @@ class S3Hook(AwsBaseHook):
:return: True if a key exists and False if not.
:rtype: bool
"""
- return self.get_wildcard_key(wildcard_key=wildcard_key,
- bucket_name=bucket_name,
- delimiter=delimiter) is not None
+ return (
+ self.get_wildcard_key(wildcard_key=wildcard_key, bucket_name=bucket_name, delimiter=delimiter)
+ is not None
+ )
@provide_bucket_name
@unify_bucket_name_and_key
- def get_wildcard_key(self,
- wildcard_key: str,
- bucket_name: Optional[str] = None,
- delimiter: str = '') -> S3Transfer:
+ def get_wildcard_key(
+ self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = ''
+ ) -> S3Transfer:
"""
Returns a boto3.s3.Object object matching the wildcard expression
@@ -459,14 +458,16 @@ class S3Hook(AwsBaseHook):
@provide_bucket_name
@unify_bucket_name_and_key
- def load_file(self,
- filename: str,
- key: str,
- bucket_name: Optional[str] = None,
- replace: bool = False,
- encrypt: bool = False,
- gzip: bool = False,
- acl_policy: Optional[str] = None) -> None:
+ def load_file(
+ self,
+ filename: str,
+ key: str,
+ bucket_name: Optional[str] = None,
+ replace: bool = False,
+ encrypt: bool = False,
+ gzip: bool = False,
+ acl_policy: Optional[str] = None,
+ ) -> None:
"""
Loads a local file to S3
@@ -511,14 +512,16 @@ class S3Hook(AwsBaseHook):
@provide_bucket_name
@unify_bucket_name_and_key
- def load_string(self,
- string_data: str,
- key: str,
- bucket_name: Optional[str] = None,
- replace: bool = False,
- encrypt: bool = False,
- encoding: Optional[str] = None,
- acl_policy: Optional[str] = None) -> None:
+ def load_string(
+ self,
+ string_data: str,
+ key: str,
+ bucket_name: Optional[str] = None,
+ replace: bool = False,
+ encrypt: bool = False,
+ encoding: Optional[str] = None,
+ acl_policy: Optional[str] = None,
+ ) -> None:
"""
Loads a string to S3
@@ -552,13 +555,15 @@ class S3Hook(AwsBaseHook):
@provide_bucket_name
@unify_bucket_name_and_key
- def load_bytes(self,
- bytes_data: bytes,
- key: str,
- bucket_name: Optional[str] = None,
- replace: bool = False,
- encrypt: bool = False,
- acl_policy: Optional[str] = None) -> None:
+ def load_bytes(
+ self,
+ bytes_data: bytes,
+ key: str,
+ bucket_name: Optional[str] = None,
+ replace: bool = False,
+ encrypt: bool = False,
+ acl_policy: Optional[str] = None,
+ ) -> None:
"""
Loads bytes to S3
@@ -587,13 +592,15 @@ class S3Hook(AwsBaseHook):
@provide_bucket_name
@unify_bucket_name_and_key
- def load_file_obj(self,
- file_obj: BytesIO,
- key: str,
- bucket_name: Optional[str] = None,
- replace: bool = False,
- encrypt: bool = False,
- acl_policy: Optional[str] = None) -> None:
+ def load_file_obj(
+ self,
+ file_obj: BytesIO,
+ key: str,
+ bucket_name: Optional[str] = None,
+ replace: bool = False,
+ encrypt: bool = False,
+ acl_policy: Optional[str] = None,
+ ) -> None:
"""
Loads a file object to S3
@@ -615,13 +622,15 @@ class S3Hook(AwsBaseHook):
"""
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)
- def _upload_file_obj(self,
- file_obj: BytesIO,
- key: str,
- bucket_name: Optional[str] = None,
- replace: bool = False,
- encrypt: bool = False,
- acl_policy: Optional[str] = None) -> None:
+ def _upload_file_obj(
+ self,
+ file_obj: BytesIO,
+ key: str,
+ bucket_name: Optional[str] = None,
+ replace: bool = False,
+ encrypt: bool = False,
+ acl_policy: Optional[str] = None,
+ ) -> None:
if not replace and self.check_for_key(key, bucket_name):
raise ValueError("The key {key} already exists.".format(key=key))
@@ -634,13 +643,15 @@ class S3Hook(AwsBaseHook):
client = self.get_conn()
client.upload_fileobj(file_obj, bucket_name, key, ExtraArgs=extra_args)
- def copy_object(self,
- source_bucket_key: str,
- dest_bucket_key: str,
- source_bucket_name: Optional[str] = None,
- dest_bucket_name: Optional[str] = None,
- source_version_id: Optional[str] = None,
- acl_policy: Optional[str] = None) -> None:
+ def copy_object(
+ self,
+ source_bucket_key: str,
+ dest_bucket_key: str,
+ source_bucket_name: Optional[str] = None,
+ dest_bucket_name: Optional[str] = None,
+ source_version_id: Optional[str] = None,
+ acl_policy: Optional[str] = None,
+ ) -> None:
"""
Creates a copy of an object that is already stored in S3.
@@ -679,26 +690,27 @@ class S3Hook(AwsBaseHook):
else:
parsed_url = urlparse(dest_bucket_key)
if parsed_url.scheme != '' or parsed_url.netloc != '':
- raise AirflowException('If dest_bucket_name is provided, ' +
- 'dest_bucket_key should be relative path ' +
- 'from root level, rather than a full s3:// url')
+ raise AirflowException(
+ 'If dest_bucket_name is provided, '
+ + 'dest_bucket_key should be relative path '
+ + 'from root level, rather than a full s3:// url'
+ )
if source_bucket_name is None:
source_bucket_name, source_bucket_key = self.parse_s3_url(source_bucket_key)
else:
parsed_url = urlparse(source_bucket_key)
if parsed_url.scheme != '' or parsed_url.netloc != '':
- raise AirflowException('If source_bucket_name is provided, ' +
- 'source_bucket_key should be relative path ' +
- 'from root level, rather than a full s3:// url')
-
- copy_source = {'Bucket': source_bucket_name,
- 'Key': source_bucket_key,
- 'VersionId': source_version_id}
- response = self.get_conn().copy_object(Bucket=dest_bucket_name,
- Key=dest_bucket_key,
- CopySource=copy_source,
- ACL=acl_policy)
+ raise AirflowException(
+ 'If source_bucket_name is provided, '
+ + 'source_bucket_key should be relative path '
+ + 'from root level, rather than a full s3:// url'
+ )
+
+ copy_source = {'Bucket': source_bucket_name, 'Key': source_bucket_key, 'VersionId': source_version_id}
+ response = self.get_conn().copy_object(
+ Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, ACL=acl_policy
+ )
return response
@provide_bucket_name
@@ -717,9 +729,7 @@ class S3Hook(AwsBaseHook):
bucket_keys = self.list_keys(bucket_name=bucket_name)
if bucket_keys:
self.delete_objects(bucket=bucket_name, keys=bucket_keys)
- self.conn.delete_bucket(
- Bucket=bucket_name
- )
+ self.conn.delete_bucket(Bucket=bucket_name)
def delete_objects(self, bucket: str, keys: Union[str, list]) -> None:
"""
@@ -745,10 +755,7 @@ class S3Hook(AwsBaseHook):
# For details see:
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects
for chunk in chunks(keys, chunk_size=1000):
- response = s3.delete_objects(
- Bucket=bucket,
- Delete={"Objects": [{"Key": k} for k in chunk]}
- )
+ response = s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]})
deleted_keys = [x['Key'] for x in response.get("Deleted", [])]
self.log.info("Deleted: %s", deleted_keys)
if "Errors" in response:
@@ -757,10 +764,9 @@ class S3Hook(AwsBaseHook):
@provide_bucket_name
@unify_bucket_name_and_key
- def download_file(self,
- key: str,
- bucket_name: Optional[str] = None,
- local_path: Optional[str] = None) -> str:
+ def download_file(
+ self, key: str, bucket_name: Optional[str] = None, local_path: Optional[str] = None
+ ) -> str:
"""
Downloads a file from the S3 location to the local file system.
@@ -786,11 +792,13 @@ class S3Hook(AwsBaseHook):
return local_tmp_file.name
- def generate_presigned_url(self,
- client_method: str,
- params: Optional[dict] = None,
- expires_in: int = 3600,
- http_method: Optional[str] = None) -> Optional[str]:
+ def generate_presigned_url(
+ self,
+ client_method: str,
+ params: Optional[dict] = None,
+ expires_in: int = 3600,
+ http_method: Optional[str] = None,
+ ) -> Optional[str]:
"""
Generate a presigned url given a client, its method, and arguments
@@ -810,10 +818,9 @@ class S3Hook(AwsBaseHook):
s3_client = self.get_conn()
try:
- return s3_client.generate_presigned_url(ClientMethod=client_method,
- Params=params,
- ExpiresIn=expires_in,
- HttpMethod=http_method)
+ return s3_client.generate_presigned_url(
+ ClientMethod=client_method, Params=params, ExpiresIn=expires_in, HttpMethod=http_method
+ )
except ClientError as e:
self.log.error(e.response["Error"]["Message"])
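A minimal sketch of the S3Hook methods whose signatures were reflowed above; the bucket and key names are placeholders:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id='aws_default')
    hook.load_string(string_data='hello', key='example/key.txt', bucket_name='my-bucket', replace=True)
    keys = hook.list_keys(bucket_name='my-bucket', prefix='example/')
    hook.delete_objects(bucket='my-bucket', keys=keys or [])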
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index bb65a55..fb5aed6 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -38,6 +38,7 @@ class LogState:
Enum-style class holding all possible states of CloudWatch log streams.
https://sagemaker.readthedocs.io/en/stable/session.html#sagemaker.session.LogState
"""
+
STARTING = 1
WAIT_IN_PROGRESS = 2
TAILING = 3
@@ -77,12 +78,16 @@ def secondary_training_status_changed(current_job_description, prev_job_descript
if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0:
return False
- prev_job_secondary_status_transitions = prev_job_description.get('SecondaryStatusTransitions') \
- if prev_job_description is not None else None
+ prev_job_secondary_status_transitions = (
+ prev_job_description.get('SecondaryStatusTransitions') if prev_job_description is not None else None
+ )
- last_message = prev_job_secondary_status_transitions[-1]['StatusMessage'] \
- if prev_job_secondary_status_transitions is not None \
- and len(prev_job_secondary_status_transitions) > 0 else ''
+ last_message = (
+ prev_job_secondary_status_transitions[-1]['StatusMessage']
+ if prev_job_secondary_status_transitions is not None
+ and len(prev_job_secondary_status_transitions) > 0
+ else ''
+ )
message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']
@@ -101,18 +106,28 @@ def secondary_training_status_message(job_description, prev_description):
:return: Job status string to be printed.
"""
- if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
- or len(job_description.get('SecondaryStatusTransitions')) == 0:
+ if (
+ job_description is None
+ or job_description.get('SecondaryStatusTransitions') is None
+ or len(job_description.get('SecondaryStatusTransitions')) == 0
+ ):
return ''
- prev_description_secondary_transitions = prev_description.get('SecondaryStatusTransitions')\
- if prev_description is not None else None
- prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])\
- if prev_description_secondary_transitions is not None else 0
+ prev_description_secondary_transitions = (
+ prev_description.get('SecondaryStatusTransitions') if prev_description is not None else None
+ )
+ prev_transitions_num = (
+ len(prev_description['SecondaryStatusTransitions'])
+ if prev_description_secondary_transitions is not None
+ else 0
+ )
current_transitions = job_description['SecondaryStatusTransitions']
- transitions_to_print = current_transitions[-1:] if len(current_transitions) == prev_transitions_num else \
- current_transitions[prev_transitions_num - len(current_transitions):]
+ transitions_to_print = (
+ current_transitions[-1:]
+ if len(current_transitions) == prev_transitions_num
+ else current_transitions[prev_transitions_num - len(current_transitions) :]
+ )
status_strs = []
for transition in transitions_to_print:
@@ -123,7 +138,7 @@ def secondary_training_status_message(job_description, prev_description):
return '\n'.join(status_strs)
-class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
+class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
"""
Interact with Amazon SageMaker.
@@ -133,9 +148,9 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
.. seealso::
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
+
non_terminal_states = {'InProgress', 'Stopping'}
- endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating',
- 'RollingBack', 'Deleting'}
+ endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating', 'RollingBack', 'Deleting'}
failed_states = {'Failed'}
def __init__(self, *args, **kwargs):
@@ -183,11 +198,9 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
self.s3_hook.create_bucket(bucket_name=op['Bucket'])
for op in upload_ops:
if op['Tar']:
- self.tar_and_s3_upload(op['Path'], op['Key'],
- op['Bucket'])
+ self.tar_and_s3_upload(op['Path'], op['Key'], op['Bucket'])
else:
- self.s3_hook.load_file(op['Path'], op['Key'],
- op['Bucket'])
+ self.s3_hook.load_file(op['Path'], op['Key'], op['Bucket'])
def check_s3_url(self, s3url):
"""
@@ -199,17 +212,18 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
"""
bucket, key = S3Hook.parse_s3_url(s3url)
if not self.s3_hook.check_for_bucket(bucket_name=bucket):
- raise AirflowException(
- "The input S3 Bucket {} does not exist ".format(bucket))
- if key and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)\
- and not self.s3_hook.check_for_prefix(
- prefix=key, bucket_name=bucket, delimiter='/'):
+ raise AirflowException("The input S3 Bucket {} does not exist ".format(bucket))
+ if (
+ key
+ and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)
+ and not self.s3_hook.check_for_prefix(prefix=key, bucket_name=bucket, delimiter='/')
+ ):
# check if s3 key exists in the case user provides a single file
# or if s3 prefix exists in the case user provides multiple files in
# a prefix
- raise AirflowException("The input S3 Key "
- "or Prefix {} does not exist in the Bucket {}"
- .format(s3url, bucket))
+ raise AirflowException(
+ "The input S3 Key " "or Prefix {} does not exist in the Bucket {}".format(s3url, bucket)
+ )
return True
def check_training_config(self, training_config):
@@ -240,10 +254,12 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
This method is deprecated.
Please use :py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead.
"""
- warnings.warn("Method `get_log_conn` has been deprecated. "
- "Please use `airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead.",
- category=DeprecationWarning,
- stacklevel=2)
+ warnings.warn(
+ "Method `get_log_conn` has been deprecated. "
+ "Please use `airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
return self.logs_hook.get_conn()
@@ -253,11 +269,13 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
Please use
:py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead.
"""
- warnings.warn("Method `log_stream` has been deprecated. "
- "Please use "
- "`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead.",
- category=DeprecationWarning,
- stacklevel=2)
+ warnings.warn(
+ "Method `log_stream` has been deprecated. "
+ "Please use "
+ "`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip)
@@ -277,8 +295,10 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
:return: A tuple of (stream number, cloudwatch log event).
"""
positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
- event_iters = [self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip)
- for s in streams]
+ event_iters = [
+ self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip)
+ for s in streams
+ ]
events = []
for event_stream in event_iters:
if not event_stream:
@@ -297,8 +317,9 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
except StopIteration:
events[i] = None
- def create_training_job(self, config, wait_for_completion=True, print_log=True,
- check_interval=30, max_ingestion_time=None):
+ def create_training_job(
+ self, config, wait_for_completion=True, print_log=True, check_interval=30, max_ingestion_time=None
+ ):
"""
Create a training job
@@ -320,28 +341,31 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
response = self.get_conn().create_training_job(**config)
if print_log:
- self.check_training_status_with_log(config['TrainingJobName'],
- self.non_terminal_states,
- self.failed_states,
- wait_for_completion,
- check_interval, max_ingestion_time
- )
+ self.check_training_status_with_log(
+ config['TrainingJobName'],
+ self.non_terminal_states,
+ self.failed_states,
+ wait_for_completion,
+ check_interval,
+ max_ingestion_time,
+ )
elif wait_for_completion:
- describe_response = self.check_status(config['TrainingJobName'],
- 'TrainingJobStatus',
- self.describe_training_job,
- check_interval, max_ingestion_time
- )
-
- billable_time = \
- (describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']) * \
- describe_response['ResourceConfig']['InstanceCount']
+ describe_response = self.check_status(
+ config['TrainingJobName'],
+ 'TrainingJobStatus',
+ self.describe_training_job,
+ check_interval,
+ max_ingestion_time,
+ )
+
+ billable_time = (
+ describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']
+ ) * describe_response['ResourceConfig']['InstanceCount']
self.log.info('Billable seconds: %d', int(billable_time.total_seconds()) + 1)
return response
- def create_tuning_job(self, config, wait_for_completion=True,
- check_interval=30, max_ingestion_time=None):
+ def create_tuning_job(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None):
"""
Create a tuning job
@@ -363,15 +387,18 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
response = self.get_conn().create_hyper_parameter_tuning_job(**config)
if wait_for_completion:
- self.check_status(config['HyperParameterTuningJobName'],
- 'HyperParameterTuningJobStatus',
- self.describe_tuning_job,
- check_interval, max_ingestion_time
- )
+ self.check_status(
+ config['HyperParameterTuningJobName'],
+ 'HyperParameterTuningJobStatus',
+ self.describe_tuning_job,
+ check_interval,
+ max_ingestion_time,
+ )
return response
- def create_transform_job(self, config, wait_for_completion=True,
- check_interval=30, max_ingestion_time=None):
+ def create_transform_job(
+ self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None
+ ):
"""
Create a transform job
@@ -393,15 +420,18 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
response = self.get_conn().create_transform_job(**config)
if wait_for_completion:
- self.check_status(config['TransformJobName'],
- 'TransformJobStatus',
- self.describe_transform_job,
- check_interval, max_ingestion_time
- )
+ self.check_status(
+ config['TransformJobName'],
+ 'TransformJobStatus',
+ self.describe_transform_job,
+ check_interval,
+ max_ingestion_time,
+ )
return response
- def create_processing_job(self, config, wait_for_completion=True,
- check_interval=30, max_ingestion_time=None):
+ def create_processing_job(
+ self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None
+ ):
"""
Create a processing job
@@ -421,11 +451,13 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
response = self.get_conn().create_processing_job(**config)
if wait_for_completion:
- self.check_status(config['ProcessingJobName'],
- 'ProcessingJobStatus',
- self.describe_processing_job,
- check_interval, max_ingestion_time
- )
+ self.check_status(
+ config['ProcessingJobName'],
+ 'ProcessingJobStatus',
+ self.describe_processing_job,
+ check_interval,
+ max_ingestion_time,
+ )
return response
def create_model(self, config):
@@ -450,8 +482,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
return self.get_conn().create_endpoint_config(**config)
- def create_endpoint(self, config, wait_for_completion=True,
- check_interval=30, max_ingestion_time=None):
+ def create_endpoint(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None):
"""
Create an endpoint
@@ -471,16 +502,17 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
response = self.get_conn().create_endpoint(**config)
if wait_for_completion:
- self.check_status(config['EndpointName'],
- 'EndpointStatus',
- self.describe_endpoint,
- check_interval, max_ingestion_time,
- non_terminal_states=self.endpoint_non_terminal_states
- )
+ self.check_status(
+ config['EndpointName'],
+ 'EndpointStatus',
+ self.describe_endpoint,
+ check_interval,
+ max_ingestion_time,
+ non_terminal_states=self.endpoint_non_terminal_states,
+ )
return response
- def update_endpoint(self, config, wait_for_completion=True,
- check_interval=30, max_ingestion_time=None):
+ def update_endpoint(self, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None):
"""
Update an endpoint
@@ -500,12 +532,14 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
response = self.get_conn().update_endpoint(**config)
if wait_for_completion:
- self.check_status(config['EndpointName'],
- 'EndpointStatus',
- self.describe_endpoint,
- check_interval, max_ingestion_time,
- non_terminal_states=self.endpoint_non_terminal_states
- )
+ self.check_status(
+ config['EndpointName'],
+ 'EndpointStatus',
+ self.describe_endpoint,
+ check_interval,
+ max_ingestion_time,
+ non_terminal_states=self.endpoint_non_terminal_states,
+ )
return response
def describe_training_job(self, name):
@@ -519,9 +553,16 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
return self.get_conn().describe_training_job(TrainingJobName=name)
- def describe_training_job_with_log(self, job_name, positions, stream_names,
- instance_count, state, last_description,
- last_describe_job_call):
+ def describe_training_job_with_log(
+ self,
+ job_name,
+ positions,
+ stream_names,
+ instance_count,
+ state,
+ last_description,
+ last_describe_job_call,
+ ):
"""
Return the training job info associated with job_name and print CloudWatch logs
"""
@@ -536,11 +577,12 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
logGroupName=log_group,
logStreamNamePrefix=job_name + '/',
orderBy='LogStreamName',
- limit=instance_count
+ limit=instance_count,
)
stream_names = [s['logStreamName'] for s in streams['logStreams']]
- positions.update([(s, Position(timestamp=0, skip=0))
- for s in stream_names if s not in positions])
+ positions.update(
+ [(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions]
+ )
except logs_conn.exceptions.ResourceNotFoundException:
# On the very first training job run on an account, there's no log group until
# the container starts logging, so ignore any errors thrown about that
@@ -638,10 +680,9 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
return self.get_conn().describe_endpoint(EndpointName=name)
- def check_status(self, job_name, key,
- describe_function, check_interval,
- max_ingestion_time,
- non_terminal_states=None):
+ def check_status(
+ self, job_name, key, describe_function, check_interval, max_ingestion_time, non_terminal_states=None
+ ):
"""
Check status of a SageMaker job
@@ -677,8 +718,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
try:
response = describe_function(job_name)
status = response[key]
- self.log.info('Job still running for %s seconds... '
- 'current status is %s', sec, status)
+ self.log.info('Job still running for %s seconds... ' 'current status is %s', sec, status)
except KeyError:
raise AirflowException('Could not get status of the SageMaker job')
except ClientError:
@@ -699,8 +739,15 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
response = describe_function(job_name)
return response
- def check_training_status_with_log(self, job_name, non_terminal_states, failed_states,
- wait_for_completion, check_interval, max_ingestion_time):
+ def check_training_status_with_log(
+ self,
+ job_name,
+ non_terminal_states,
+ failed_states,
+ wait_for_completion,
+ check_interval,
+ max_ingestion_time,
+ ):
"""
Display the logs for a given training job, optionally tailing them until the
job is complete.
@@ -730,7 +777,7 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
status = description['TrainingJobStatus']
stream_names = [] # The list of log streams
- positions = {} # The current position in each stream, map of stream name -> position
+ positions = {} # The current position in each stream, map of stream name -> position
job_already_completed = status not in non_terminal_states
@@ -763,10 +810,15 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
time.sleep(check_interval)
sec += check_interval
- state, last_description, last_describe_job_call = \
- self.describe_training_job_with_log(job_name, positions, stream_names,
- instance_count, state, last_description,
- last_describe_job_call)
+ state, last_description, last_describe_job_call = self.describe_training_job_with_log(
+ job_name,
+ positions,
+ stream_names,
+ instance_count,
+ state,
+ last_description,
+ last_describe_job_call,
+ )
if state == LogState.COMPLETE:
break
@@ -779,13 +831,14 @@ class SageMakerHook(AwsBaseHook): # pylint: disable=too-many-public-methods
if status in failed_states:
reason = last_description.get('FailureReason', '(No reason provided)')
raise AirflowException('Error training {}: {} Reason: {}'.format(job_name, status, reason))
- billable_time = (last_description['TrainingEndTime'] - last_description['TrainingStartTime']) \
- * instance_count
+ billable_time = (
+ last_description['TrainingEndTime'] - last_description['TrainingStartTime']
+ ) * instance_count
self.log.info('Billable seconds: %d', int(billable_time.total_seconds()) + 1)
def list_training_jobs(
self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs
- ) -> List[Dict]: # noqa: D402
+ ) -> List[Dict]: # noqa: D402
"""
This method wraps boto3's list_training_jobs(). The training job name and max results are configurable
via arguments. Other arguments are not, and should be provided via kwargs. Note boto3 expects these in
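
For orientation only (not part of the diff above): a minimal sketch of calling the reformatted create_training_job, assuming a config dict shaped like boto3's CreateTrainingJob request; every name, image, ARN and bucket below is hypothetical.

    from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

    hook = SageMakerHook(aws_conn_id='aws_default')
    config = {
        'TrainingJobName': 'example-training-job',                    # hypothetical job name
        'AlgorithmSpecification': {
            'TrainingImage': '123456789012.dkr.ecr.us-east-1.amazonaws.com/algo:latest',  # hypothetical image
            'TrainingInputMode': 'File',
        },
        'RoleArn': 'arn:aws:iam::123456789012:role/SageMakerRole',    # hypothetical role
        'OutputDataConfig': {'S3OutputPath': 's3://example-bucket/output/'},
        'ResourceConfig': {'InstanceType': 'ml.m5.large', 'InstanceCount': 1, 'VolumeSizeInGB': 10},
        'StoppingCondition': {'MaxRuntimeInSeconds': 3600},
    }
    # Waits for the job and streams CloudWatch logs, as in check_training_status_with_log above.
    response = hook.create_training_job(config, wait_for_completion=True, print_log=True, check_interval=30)
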
diff --git a/airflow/providers/amazon/aws/hooks/ses.py b/airflow/providers/amazon/aws/hooks/ses.py
index 2ee8171..3844b71 100644
--- a/airflow/providers/amazon/aws/hooks/ses.py
+++ b/airflow/providers/amazon/aws/hooks/ses.py
@@ -52,7 +52,7 @@ class SESHook(AwsBaseHook):
mime_charset: str = 'utf-8',
reply_to: Optional[str] = None,
return_path: Optional[str] = None,
- custom_headers: Optional[Dict[str, Any]] = None
+ custom_headers: Optional[Dict[str, Any]] = None,
) -> dict:
"""
Send email using Amazon Simple Email Service
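
As a hedged sketch of the send_email call whose trailing keyword arguments are reformatted above: the leading mail_from/to/subject/html_content parameters are assumed from the hook's full signature (not shown in this hunk), and all addresses are hypothetical.

    from airflow.providers.amazon.aws.hooks.ses import SESHook

    hook = SESHook(aws_conn_id='aws_default')
    hook.send_email(
        mail_from='noreply@example.com',                   # hypothetical sender, must be verified in SES
        to='team@example.com',                             # hypothetical recipient
        subject='Nightly pipeline finished',
        html_content='<p>The nightly DAG completed.</p>',
        custom_headers={'X-Pipeline': 'nightly'},          # optional, one of the kwargs shown above
    )
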
diff --git a/airflow/providers/amazon/aws/hooks/sns.py b/airflow/providers/amazon/aws/hooks/sns.py
index f0b0d5b..e5045bb 100644
--- a/airflow/providers/amazon/aws/hooks/sns.py
+++ b/airflow/providers/amazon/aws/hooks/sns.py
@@ -33,8 +33,9 @@ def _get_message_attribute(o):
return {'DataType': 'Number', 'StringValue': str(o)}
if hasattr(o, '__iter__'):
return {'DataType': 'String.Array', 'StringValue': json.dumps(o)}
- raise TypeError('Values in MessageAttributes must be one of bytes, str, int, float, or iterable; '
- f'got {type(o)}')
+ raise TypeError(
+ 'Values in MessageAttributes must be one of bytes, str, int, float, or iterable; ' f'got {type(o)}'
+ )
class AwsSnsHook(AwsBaseHook):
@@ -74,9 +75,7 @@ class AwsSnsHook(AwsBaseHook):
publish_kwargs = {
'TargetArn': target_arn,
'MessageStructure': 'json',
- 'Message': json.dumps({
- 'default': message
- }),
+ 'Message': json.dumps({'default': message}),
}
# Construct args this way because boto3 distinguishes from missing args and those set to None
diff --git a/airflow/providers/amazon/aws/hooks/sqs.py b/airflow/providers/amazon/aws/hooks/sqs.py
index 849979b..6c43f7f 100644
--- a/airflow/providers/amazon/aws/hooks/sqs.py
+++ b/airflow/providers/amazon/aws/hooks/sqs.py
@@ -70,7 +70,9 @@ class SQSHook(AwsBaseHook):
For details of the returned value see :py:meth:`botocore.client.SQS.send_message`
:rtype: dict
"""
- return self.get_conn().send_message(QueueUrl=queue_url,
- MessageBody=message_body,
- DelaySeconds=delay_seconds,
- MessageAttributes=message_attributes or {})
+ return self.get_conn().send_message(
+ QueueUrl=queue_url,
+ MessageBody=message_body,
+ DelaySeconds=delay_seconds,
+ MessageAttributes=message_attributes or {},
+ )
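
For orientation, a minimal usage sketch of the reformatted send_message call (not part of this commit); the queue URL and attribute values are hypothetical.

    from airflow.providers.amazon.aws.hooks.sqs import SQSHook

    hook = SQSHook(aws_conn_id='aws_default')
    response = hook.send_message(
        queue_url='https://sqs.us-east-1.amazonaws.com/123456789012/example-queue',  # hypothetical queue
        message_body='{"event": "file_landed"}',
        delay_seconds=0,
        message_attributes={'source': {'DataType': 'String', 'StringValue': 'airflow'}},
    )
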
diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py
index f0e1040..d83d1af 100644
--- a/airflow/providers/amazon/aws/hooks/step_function.py
+++ b/airflow/providers/amazon/aws/hooks/step_function.py
@@ -35,8 +35,12 @@ class StepFunctionHook(AwsBaseHook):
def __init__(self, region_name=None, *args, **kwargs):
super().__init__(client_type='stepfunctions', *args, **kwargs)
- def start_execution(self, state_machine_arn: str, name: Optional[str] = None,
- state_machine_input: Union[dict, str, None] = None) -> str:
+ def start_execution(
+ self,
+ state_machine_arn: str,
+ name: Optional[str] = None,
+ state_machine_input: Union[dict, str, None] = None,
+ ) -> str:
"""
Start Execution of the State Machine.
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.start_execution
@@ -50,9 +54,7 @@ class StepFunctionHook(AwsBaseHook):
:return: Execution ARN
:rtype: str
"""
- execution_args = {
- 'stateMachineArn': state_machine_arn
- }
+ execution_args = {'stateMachineArn': state_machine_arn}
if name is not None:
execution_args['name'] = name
if state_machine_input is not None:
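
A minimal sketch of the start_execution signature reformatted above; the state machine ARN and execution name are hypothetical, and state_machine_input may be a dict or a JSON string per the type hint.

    from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook

    hook = StepFunctionHook(region_name='us-east-1', aws_conn_id='aws_default')
    execution_arn = hook.start_execution(
        state_machine_arn='arn:aws:states:us-east-1:123456789012:stateMachine:example',  # hypothetical ARN
        name='manual-trigger-001',                                                        # hypothetical name
        state_machine_input={'run_date': '2020-08-25'},
    )
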
diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
index 2d22452..7d4e3a0 100644
--- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -38,6 +38,7 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
:param filename_template: template for file name (local storage) or log stream name (remote)
:type filename_template: str
"""
+
def __init__(self, base_log_folder, log_group_arn, filename_template):
super().__init__(base_log_folder, filename_template)
split_arn = log_group_arn.split(':')
@@ -55,12 +56,14 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
try:
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+
return AwsLogsHook(aws_conn_id=remote_conn_id, region_name=self.region_name)
except Exception: # pylint: disable=broad-except
self.log.error(
'Could not create an AwsLogsHook with connection id "%s". '
'Please make sure that airflow[aws] is installed and '
- 'the Cloudwatch logs connection exists.', remote_conn_id
+ 'the Cloudwatch logs connection exists.',
+ remote_conn_id,
)
def _render_filename(self, ti, try_number):
@@ -72,7 +75,7 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
self.handler = watchtower.CloudWatchLogHandler(
log_group=self.log_group,
stream_name=self._render_filename(ti, ti.try_number),
- boto3_session=self.hook.get_session(self.region_name)
+ boto3_session=self.hook.get_session(self.region_name),
)
def close(self):
@@ -93,9 +96,12 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
def _read(self, task_instance, try_number, metadata=None):
stream_name = self._render_filename(task_instance, try_number)
- return '*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}\n'.format(
- self.log_group, stream_name, self.get_cloudwatch_logs(stream_name=stream_name)
- ), {'end_of_log': True}
+ return (
+ '*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}\n'.format(
+ self.log_group, stream_name, self.get_cloudwatch_logs(stream_name=stream_name)
+ ),
+ {'end_of_log': True},
+ )
def get_cloudwatch_logs(self, stream_name):
"""
diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py b/airflow/providers/amazon/aws/log/s3_task_handler.py
index b13b7cd..00f52d1 100644
--- a/airflow/providers/amazon/aws/log/s3_task_handler.py
+++ b/airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -30,6 +30,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
task instance logs. It extends airflow FileTaskHandler and
uploads to and reads from S3 remote storage.
"""
+
def __init__(self, base_log_folder, s3_log_folder, filename_template):
super().__init__(base_log_folder, filename_template)
self.remote_base = s3_log_folder
@@ -46,12 +47,14 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
try:
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+
return S3Hook(remote_conn_id)
except Exception: # pylint: disable=broad-except
self.log.exception(
'Could not create an S3Hook with connection id "%s". '
'Please make sure that airflow[aws] is installed and '
- 'the S3 connection exists.', remote_conn_id
+ 'the S3 connection exists.',
+ remote_conn_id,
)
def set_context(self, ti):
@@ -115,8 +118,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
# local machine even if there are errors reading remote logs, as
# returned remote_log will contain error messages.
remote_log = self.s3_read(remote_loc, return_error=True)
- log = '*** Reading remote log from {}.\n{}\n'.format(
- remote_loc, remote_log)
+ log = '*** Reading remote log from {}.\n{}\n'.format(remote_loc, remote_log)
return log, {'end_of_log': True}
else:
return super()._read(ti, try_number)
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 4d734d0..2039fe3 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -54,11 +54,12 @@ class AWSAthenaOperator(BaseOperator):
ui_color = '#44b5e2'
template_fields = ('query', 'database', 'output_location')
- template_ext = ('.sql', )
+ template_ext = ('.sql',)
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
- self, *,
+ self,
+ *,
query: str,
database: str,
output_location: str,
@@ -69,7 +70,7 @@ class AWSAthenaOperator(BaseOperator):
result_configuration: Optional[Dict[str, Any]] = None,
sleep_time: int = 30,
max_tries: Optional[int] = None,
- **kwargs: Any
+ **kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.query = query
@@ -95,21 +96,29 @@ class AWSAthenaOperator(BaseOperator):
"""
self.query_execution_context['Database'] = self.database
self.result_configuration['OutputLocation'] = self.output_location
- self.query_execution_id = self.hook.run_query(self.query, self.query_execution_context,
- self.result_configuration, self.client_request_token,
- self.workgroup)
+ self.query_execution_id = self.hook.run_query(
+ self.query,
+ self.query_execution_context,
+ self.result_configuration,
+ self.client_request_token,
+ self.workgroup,
+ )
query_status = self.hook.poll_query_status(self.query_execution_id, self.max_tries)
if query_status in AWSAthenaHook.FAILURE_STATES:
error_message = self.hook.get_state_change_reason(self.query_execution_id)
raise Exception(
- 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'
- .format(query_status, self.query_execution_id, error_message))
+ 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format(
+ query_status, self.query_execution_id, error_message
+ )
+ )
elif not query_status or query_status in AWSAthenaHook.INTERMEDIATE_STATES:
raise Exception(
'Final state of Athena job is {}. '
- 'Max tries of poll status exceeded, query_execution_id is {}.'
- .format(query_status, self.query_execution_id))
+ 'Max tries of poll status exceeded, query_execution_id is {}.'.format(
+ query_status, self.query_execution_id
+ )
+ )
return self.query_execution_id
@@ -119,9 +128,7 @@ class AWSAthenaOperator(BaseOperator):
"""
if self.query_execution_id:
self.log.info('⚰️⚰️⚰️ Received a kill Signal. Time to Die')
- self.log.info(
- 'Stopping Query with executionId - %s', self.query_execution_id
- )
+ self.log.info('Stopping Query with executionId - %s', self.query_execution_id)
response = self.hook.stop_query(self.query_execution_id)
http_status_code = None
try:
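
For context, a hedged sketch of instantiating the operator inside a DAG using the templated fields shown above; the query, database and output bucket are hypothetical.

    from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator

    run_query = AWSAthenaOperator(
        task_id='run_athena_query',
        query='SELECT * FROM events LIMIT 10',                   # hypothetical query; .sql files are templated too
        database='analytics',                                    # hypothetical database
        output_location='s3://example-bucket/athena-results/',   # hypothetical results bucket
    )
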
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index c865ade..aabe307 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -99,7 +99,8 @@ class AwsBatchOperator(BaseOperator):
@apply_defaults
def __init__(
- self, *,
+ self,
+ *,
job_name,
job_definition,
job_queue,
@@ -141,9 +142,7 @@ class AwsBatchOperator(BaseOperator):
self.monitor_job(context)
def on_kill(self):
- response = self.hook.client.terminate_job(
- jobId=self.job_id, reason="Task killed by the user"
- )
+ response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user")
self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response)
def submit_job(self, context: Dict): # pylint: disable=unused-argument
@@ -153,9 +152,7 @@ class AwsBatchOperator(BaseOperator):
:raises: AirflowException
"""
self.log.info(
- "Running AWS Batch job - job definition: %s - on queue %s",
- self.job_definition,
- self.job_queue,
+ "Running AWS Batch job - job definition: %s - on queue %s", self.job_definition, self.job_queue,
)
self.log.info("AWS Batch job - container overrides: %s", self.overrides)
diff --git a/airflow/providers/amazon/aws/operators/cloud_formation.py b/airflow/providers/amazon/aws/operators/cloud_formation.py
index f0dc0c4..d6c9bb0 100644
--- a/airflow/providers/amazon/aws/operators/cloud_formation.py
+++ b/airflow/providers/amazon/aws/operators/cloud_formation.py
@@ -39,17 +39,13 @@ class CloudFormationCreateStackOperator(BaseOperator):
:param aws_conn_id: aws connection to use
:type aws_conn_id: str
"""
+
template_fields: List[str] = ['stack_name']
template_ext = ()
ui_color = '#6b9659'
@apply_defaults
- def __init__(
- self, *,
- stack_name,
- params,
- aws_conn_id='aws_default',
- **kwargs):
+ def __init__(self, *, stack_name, params, aws_conn_id='aws_default', **kwargs):
super().__init__(**kwargs)
self.stack_name = stack_name
self.params = params
@@ -76,18 +72,14 @@ class CloudFormationDeleteStackOperator(BaseOperator):
:param aws_conn_id: aws connection to use
:type aws_conn_id: str
"""
+
template_fields: List[str] = ['stack_name']
template_ext = ()
ui_color = '#1d472b'
ui_fgcolor = '#FFF'
@apply_defaults
- def __init__(
- self, *,
- stack_name,
- params=None,
- aws_conn_id='aws_default',
- **kwargs):
+ def __init__(self, *, stack_name, params=None, aws_conn_id='aws_default', **kwargs):
super().__init__(**kwargs)
self.params = params or {}
self.stack_name = stack_name
diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py
index 944b6e5..681eb9c 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -101,13 +101,14 @@ class AWSDataSyncOperator(BaseOperator):
"create_source_location_kwargs",
"create_destination_location_kwargs",
"update_task_kwargs",
- "task_execution_kwargs"
+ "task_execution_kwargs",
)
ui_color = "#44b5e2"
@apply_defaults
def __init__(
- self, *,
+ self,
+ *,
aws_conn_id="aws_default",
wait_interval_seconds=5,
task_arn=None,
@@ -121,7 +122,7 @@ class AWSDataSyncOperator(BaseOperator):
update_task_kwargs=None,
task_execution_kwargs=None,
delete_task_after_execution=False,
- **kwargs
+ **kwargs,
):
super().__init__(**kwargs)
@@ -181,8 +182,7 @@ class AWSDataSyncOperator(BaseOperator):
"""
if not self.hook:
self.hook = AWSDataSyncHook(
- aws_conn_id=self.aws_conn_id,
- wait_interval_seconds=self.wait_interval_seconds,
+ aws_conn_id=self.aws_conn_id, wait_interval_seconds=self.wait_interval_seconds,
)
return self.hook
@@ -194,16 +194,14 @@ class AWSDataSyncOperator(BaseOperator):
# If some were found, identify which one to run
if self.candidate_task_arns:
- self.task_arn = self.choose_task(
- self.candidate_task_arns)
+ self.task_arn = self.choose_task(self.candidate_task_arns)
# If we couldn't find one then try to create one
if not self.task_arn and self.create_task_kwargs:
self._create_datasync_task()
if not self.task_arn:
- raise AirflowException(
- "DataSync TaskArn could not be identified or created.")
+ raise AirflowException("DataSync TaskArn could not be identified or created.")
self.log.info("Using DataSync TaskArn %s", self.task_arn)
@@ -227,13 +225,9 @@ class AWSDataSyncOperator(BaseOperator):
"""Find existing DataSync Task based on source and dest Locations."""
hook = self.get_hook()
- self.candidate_source_location_arns = self._get_location_arns(
- self.source_location_uri
- )
+ self.candidate_source_location_arns = self._get_location_arns(self.source_location_uri)
- self.candidate_destination_location_arns = self._get_location_arns(
- self.destination_location_uri
- )
+ self.candidate_destination_location_arns = self._get_location_arns(self.destination_location_uri)
if not self.candidate_source_location_arns:
self.log.info("No matching source Locations")
@@ -245,11 +239,9 @@ class AWSDataSyncOperator(BaseOperator):
self.log.info("Finding DataSync TaskArns that have these LocationArns")
self.candidate_task_arns = hook.get_task_arns_for_location_arns(
- self.candidate_source_location_arns,
- self.candidate_destination_location_arns,
+ self.candidate_source_location_arns, self.candidate_destination_location_arns,
)
- self.log.info("Found candidate DataSync TaskArns %s",
- self.candidate_task_arns)
+ self.log.info("Found candidate DataSync TaskArns %s", self.candidate_task_arns)
def choose_task(self, task_arn_list):
"""Select 1 DataSync TaskArn from a list"""
@@ -263,8 +255,7 @@ class AWSDataSyncOperator(BaseOperator):
# from AWS and might lead to confusion. Rather explicitly
# choose a random one
return random.choice(task_arn_list)
- raise AirflowException(
- "Unable to choose a Task from {}".format(task_arn_list))
+ raise AirflowException("Unable to choose a Task from {}".format(task_arn_list))
def choose_location(self, location_arn_list):
"""Select 1 DataSync LocationArn from a list"""
@@ -278,16 +269,13 @@ class AWSDataSyncOperator(BaseOperator):
# from AWS and might lead to confusion. Rather explicitly
# choose a random one
return random.choice(location_arn_list)
- raise AirflowException(
- "Unable to choose a Location from {}".format(location_arn_list))
+ raise AirflowException("Unable to choose a Location from {}".format(location_arn_list))
def _create_datasync_task(self):
"""Create a AWS DataSyncTask."""
hook = self.get_hook()
- self.source_location_arn = self.choose_location(
- self.candidate_source_location_arns
- )
+ self.source_location_arn = self.choose_location(self.candidate_source_location_arns)
if not self.source_location_arn and self.create_source_location_kwargs:
self.log.info('Attempting to create source Location')
self.source_location_arn = hook.create_location(
@@ -295,12 +283,10 @@ class AWSDataSyncOperator(BaseOperator):
)
if not self.source_location_arn:
raise AirflowException(
- "Unable to determine source LocationArn."
- " Does a suitable DataSync Location exist?")
+ "Unable to determine source LocationArn." " Does a suitable DataSync Location exist?"
+ )
- self.destination_location_arn = self.choose_location(
- self.candidate_destination_location_arns
- )
+ self.destination_location_arn = self.choose_location(self.candidate_destination_location_arns)
if not self.destination_location_arn and self.create_destination_location_kwargs:
self.log.info('Attempting to create destination Location')
self.destination_location_arn = hook.create_location(
@@ -308,14 +294,12 @@ class AWSDataSyncOperator(BaseOperator):
)
if not self.destination_location_arn:
raise AirflowException(
- "Unable to determine destination LocationArn."
- " Does a suitable DataSync Location exist?")
+ "Unable to determine destination LocationArn." " Does a suitable DataSync Location exist?"
+ )
self.log.info("Creating a Task.")
self.task_arn = hook.create_task(
- self.source_location_arn,
- self.destination_location_arn,
- **self.create_task_kwargs
+ self.source_location_arn, self.destination_location_arn, **self.create_task_kwargs
)
if not self.task_arn:
raise AirflowException("Task could not be created")
@@ -336,20 +320,15 @@ class AWSDataSyncOperator(BaseOperator):
# Create a task execution:
self.log.info("Starting execution for TaskArn %s", self.task_arn)
- self.task_execution_arn = hook.start_task_execution(
- self.task_arn, **self.task_execution_kwargs)
+ self.task_execution_arn = hook.start_task_execution(self.task_arn, **self.task_execution_kwargs)
self.log.info("Started TaskExecutionArn %s", self.task_execution_arn)
# Wait for task execution to complete
- self.log.info("Waiting for TaskExecutionArn %s",
- self.task_execution_arn)
+ self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn)
result = hook.wait_for_task_execution(self.task_execution_arn)
self.log.info("Completed TaskExecutionArn %s", self.task_execution_arn)
- task_execution_description = hook.describe_task_execution(
- task_execution_arn=self.task_execution_arn
- )
- self.log.info("task_execution_description=%s",
- task_execution_description)
+ task_execution_description = hook.describe_task_execution(task_execution_arn=self.task_execution_arn)
+ self.log.info("task_execution_description=%s", task_execution_description)
# Log some meaningful statuses
level = logging.ERROR if not result else logging.INFO
@@ -359,21 +338,16 @@ class AWSDataSyncOperator(BaseOperator):
self.log.log(level, '%s=%s', k, v)
if not result:
- raise AirflowException(
- "Failed TaskExecutionArn %s" % self.task_execution_arn
- )
+ raise AirflowException("Failed TaskExecutionArn %s" % self.task_execution_arn)
return self.task_execution_arn
def on_kill(self):
"""Cancel the submitted DataSync task."""
hook = self.get_hook()
if self.task_execution_arn:
- self.log.info("Cancelling TaskExecutionArn %s",
- self.task_execution_arn)
- hook.cancel_task_execution(
- task_execution_arn=self.task_execution_arn)
- self.log.info("Cancelled TaskExecutionArn %s",
- self.task_execution_arn)
+ self.log.info("Cancelling TaskExecutionArn %s", self.task_execution_arn)
+ hook.cancel_task_execution(task_execution_arn=self.task_execution_arn)
+ self.log.info("Cancelled TaskExecutionArn %s", self.task_execution_arn)
def _delete_datasync_task(self):
"""Deletes an AWS DataSync Task."""
@@ -385,10 +359,6 @@ class AWSDataSyncOperator(BaseOperator):
return self.task_arn
def _get_location_arns(self, location_uri):
- location_arns = self.get_hook().get_location_arns(
- location_uri
- )
- self.log.info(
- "Found LocationArns %s for LocationUri %s", location_arns, location_uri
- )
+ location_arns = self.get_hook().get_location_arns(location_uri)
+ self.log.info("Found LocationArns %s for LocationUri %s", location_arns, location_uri)
return location_arns
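
A hedged usage sketch of the operator above: the source/destination URI parameter names are assumed from the attributes used in execute(), and both locations are hypothetical.

    from airflow.providers.amazon.aws.operators.datasync import AWSDataSyncOperator

    sync = AWSDataSyncOperator(
        task_id='sync_share_to_s3',
        source_location_uri='smb://fileserver/share',            # hypothetical source location
        destination_location_uri='s3://example-bucket/data/',    # hypothetical destination location
    )
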
diff --git a/airflow/providers/amazon/aws/operators/ec2_start_instance.py b/airflow/providers/amazon/aws/operators/ec2_start_instance.py
index dc657bf..e623de9 100644
--- a/airflow/providers/amazon/aws/operators/ec2_start_instance.py
+++ b/airflow/providers/amazon/aws/operators/ec2_start_instance.py
@@ -44,12 +44,15 @@ class EC2StartInstanceOperator(BaseOperator):
ui_fgcolor = "#ffffff"
@apply_defaults
- def __init__(self, *,
- instance_id: str,
- aws_conn_id: str = "aws_default",
- region_name: Optional[str] = None,
- check_interval: float = 15,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ instance_id: str,
+ aws_conn_id: str = "aws_default",
+ region_name: Optional[str] = None,
+ check_interval: float = 15,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.instance_id = instance_id
self.aws_conn_id = aws_conn_id
@@ -57,15 +60,10 @@ class EC2StartInstanceOperator(BaseOperator):
self.check_interval = check_interval
def execute(self, context):
- ec2_hook = EC2Hook(
- aws_conn_id=self.aws_conn_id,
- region_name=self.region_name
- )
+ ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
self.log.info("Starting EC2 instance %s", self.instance_id)
instance = ec2_hook.get_instance(instance_id=self.instance_id)
instance.start()
ec2_hook.wait_for_state(
- instance_id=self.instance_id,
- target_state="running",
- check_interval=self.check_interval,
+ instance_id=self.instance_id, target_state="running", check_interval=self.check_interval,
)
diff --git a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py b/airflow/providers/amazon/aws/operators/ec2_stop_instance.py
index 8082844..0369bdd 100644
--- a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py
+++ b/airflow/providers/amazon/aws/operators/ec2_stop_instance.py
@@ -44,12 +44,15 @@ class EC2StopInstanceOperator(BaseOperator):
ui_fgcolor = "#ffffff"
@apply_defaults
- def __init__(self, *,
- instance_id: str,
- aws_conn_id: str = "aws_default",
- region_name: Optional[str] = None,
- check_interval: float = 15,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ instance_id: str,
+ aws_conn_id: str = "aws_default",
+ region_name: Optional[str] = None,
+ check_interval: float = 15,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.instance_id = instance_id
self.aws_conn_id = aws_conn_id
@@ -57,15 +60,10 @@ class EC2StopInstanceOperator(BaseOperator):
self.check_interval = check_interval
def execute(self, context):
- ec2_hook = EC2Hook(
- aws_conn_id=self.aws_conn_id,
- region_name=self.region_name
- )
+ ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
self.log.info("Stopping EC2 instance %s", self.instance_id)
instance = ec2_hook.get_instance(instance_id=self.instance_id)
instance.stop()
ec2_hook.wait_for_state(
- instance_id=self.instance_id,
- target_state="stopped",
- check_interval=self.check_interval,
+ instance_id=self.instance_id, target_state="stopped", check_interval=self.check_interval,
)
diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index 573b10d..44b72e5 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -113,11 +113,26 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
template_fields = ('overrides',)
@apply_defaults
- def __init__(self, *, task_definition, cluster, overrides, # pylint: disable=too-many-arguments
- aws_conn_id=None, region_name=None, launch_type='EC2',
- group=None, placement_constraints=None, platform_version='LATEST',
- network_configuration=None, tags=None, awslogs_group=None,
- awslogs_region=None, awslogs_stream_prefix=None, propagate_tags=None, **kwargs):
+ def __init__(
+ self,
+ *,
+ task_definition,
+ cluster,
+ overrides, # pylint: disable=too-many-arguments
+ aws_conn_id=None,
+ region_name=None,
+ launch_type='EC2',
+ group=None,
+ placement_constraints=None,
+ platform_version='LATEST',
+ network_configuration=None,
+ tags=None,
+ awslogs_group=None,
+ awslogs_region=None,
+ awslogs_stream_prefix=None,
+ propagate_tags=None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
@@ -144,8 +159,7 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
def execute(self, context):
self.log.info(
- 'Running ECS Task - Task definition: %s - on cluster %s',
- self.task_definition, self.cluster
+ 'Running ECS Task - Task definition: %s - on cluster %s', self.task_definition, self.cluster
)
self.log.info('ECSOperator overrides: %s', self.overrides)
@@ -189,16 +203,10 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
def _wait_for_task_ended(self):
waiter = self.client.get_waiter('tasks_stopped')
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
- waiter.wait(
- cluster=self.cluster,
- tasks=[self.arn]
- )
+ waiter.wait(cluster=self.cluster, tasks=[self.arn])
def _check_success_task(self):
- response = self.client.describe_tasks(
- cluster=self.cluster,
- tasks=[self.arn]
- )
+ response = self.client.describe_tasks(cluster=self.cluster, tasks=[self.arn])
self.log.info('ECS Task stopped, check status: %s', response)
# Get logs from CloudWatch if the awslogs log driver was used
@@ -218,44 +226,39 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
# successfully finished, but there is no other indication of failure
# in the response.
# https://docs.aws.amazon.com/AmazonECS/latest/developerguide/stopped-task-errors.html
- if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.',
- task.get('stoppedReason', '')):
+ if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.', task.get('stoppedReason', '')):
raise AirflowException(
- 'The task was stopped because the host instance terminated: {}'.
- format(task.get('stoppedReason', '')))
+ 'The task was stopped because the host instance terminated: {}'.format(
+ task.get('stoppedReason', '')
+ )
+ )
containers = task['containers']
for container in containers:
- if container.get('lastStatus') == 'STOPPED' and \
- container['exitCode'] != 0:
- raise AirflowException(
- 'This task is not in success state {}'.format(task))
+ if container.get('lastStatus') == 'STOPPED' and container['exitCode'] != 0:
+ raise AirflowException('This task is not in success state {}'.format(task))
elif container.get('lastStatus') == 'PENDING':
raise AirflowException('This task is still pending {}'.format(task))
elif 'error' in container.get('reason', '').lower():
raise AirflowException(
- 'This containers encounter an error during launching : {}'.
- format(container.get('reason', '').lower()))
+ 'This containers encounter an error during launching : {}'.format(
+ container.get('reason', '').lower()
+ )
+ )
def get_hook(self):
"""Create and return an AwsHook."""
if not self.hook:
self.hook = AwsBaseHook(
- aws_conn_id=self.aws_conn_id,
- client_type='ecs',
- region_name=self.region_name
+ aws_conn_id=self.aws_conn_id, client_type='ecs', region_name=self.region_name
)
return self.hook
def get_logs_hook(self):
"""Create and return an AwsLogsHook."""
- return AwsLogsHook(
- aws_conn_id=self.aws_conn_id,
- region_name=self.awslogs_region
- )
+ return AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.awslogs_region)
def on_kill(self):
response = self.client.stop_task(
- cluster=self.cluster,
- task=self.arn,
- reason='Task killed by the user')
+ cluster=self.cluster, task=self.arn, reason='Task killed by the user'
+ )
self.log.info(response)
diff --git a/airflow/providers/amazon/aws/operators/emr_add_steps.py b/airflow/providers/amazon/aws/operators/emr_add_steps.py
index 3c1078e..d046d2e 100644
--- a/airflow/providers/amazon/aws/operators/emr_add_steps.py
+++ b/airflow/providers/amazon/aws/operators/emr_add_steps.py
@@ -44,19 +44,22 @@ class EmrAddStepsOperator(BaseOperator):
:param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id.
:type do_xcom_push: bool
"""
+
template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps']
template_ext = ('.json',)
ui_color = '#f9c915'
@apply_defaults
def __init__(
- self, *,
- job_flow_id=None,
- job_flow_name=None,
- cluster_states=None,
- aws_conn_id='aws_default',
- steps=None,
- **kwargs):
+ self,
+ *,
+ job_flow_id=None,
+ job_flow_name=None,
+ cluster_states=None,
+ aws_conn_id='aws_default',
+ steps=None,
+ **kwargs,
+ ):
if kwargs.get('xcom_push') is not None:
raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
if not (job_flow_id is None) ^ (job_flow_name is None):
@@ -74,8 +77,9 @@ class EmrAddStepsOperator(BaseOperator):
emr = emr_hook.get_conn()
- job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(self.job_flow_name,
- self.cluster_states)
+ job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(
+ self.job_flow_name, self.cluster_states
+ )
if not job_flow_id:
raise AirflowException(f'No cluster found for name: {self.job_flow_name}')
diff --git a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py b/airflow/providers/amazon/aws/operators/emr_create_job_flow.py
index b02abf9..71e5e09 100644
--- a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py
+++ b/airflow/providers/amazon/aws/operators/emr_create_job_flow.py
@@ -37,18 +37,21 @@ class EmrCreateJobFlowOperator(BaseOperator):
(must be '.json') to override emr_connection extra. (templated)
:type job_flow_overrides: dict|str
"""
+
template_fields = ['job_flow_overrides']
template_ext = ('.json',)
ui_color = '#f9c915'
@apply_defaults
def __init__(
- self, *,
- aws_conn_id='aws_default',
- emr_conn_id='emr_default',
- job_flow_overrides=None,
- region_name=None,
- **kwargs):
+ self,
+ *,
+ aws_conn_id='aws_default',
+ emr_conn_id='emr_default',
+ job_flow_overrides=None,
+ region_name=None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.emr_conn_id = emr_conn_id
@@ -58,13 +61,12 @@ class EmrCreateJobFlowOperator(BaseOperator):
self.region_name = region_name
def execute(self, context):
- emr = EmrHook(aws_conn_id=self.aws_conn_id,
- emr_conn_id=self.emr_conn_id,
- region_name=self.region_name)
+ emr = EmrHook(
+ aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id, region_name=self.region_name
+ )
self.log.info(
- 'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s',
- self.aws_conn_id, self.emr_conn_id
+ 'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s', self.aws_conn_id, self.emr_conn_id
)
if isinstance(self.job_flow_overrides, str):
diff --git a/airflow/providers/amazon/aws/operators/emr_modify_cluster.py b/airflow/providers/amazon/aws/operators/emr_modify_cluster.py
index 87c2296..48692e3 100644
--- a/airflow/providers/amazon/aws/operators/emr_modify_cluster.py
+++ b/airflow/providers/amazon/aws/operators/emr_modify_cluster.py
@@ -33,17 +33,15 @@ class EmrModifyClusterOperator(BaseOperator):
:param do_xcom_push: if True, cluster_id is pushed to XCom with key cluster_id.
:type do_xcom_push: bool
"""
+
template_fields = ['cluster_id', 'step_concurrency_level']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
- self, *,
- cluster_id: str,
- step_concurrency_level: int,
- aws_conn_id: str = 'aws_default',
- **kwargs):
+ self, *, cluster_id: str, step_concurrency_level: int, aws_conn_id: str = 'aws_default', **kwargs
+ ):
if kwargs.get('xcom_push') is not None:
raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
super().__init__(**kwargs)
@@ -60,8 +58,9 @@ class EmrModifyClusterOperator(BaseOperator):
context['ti'].xcom_push(key='cluster_id', value=self.cluster_id)
self.log.info('Modifying cluster %s', self.cluster_id)
- response = emr.modify_cluster(ClusterId=self.cluster_id,
- StepConcurrencyLevel=self.step_concurrency_level)
+ response = emr.modify_cluster(
+ ClusterId=self.cluster_id, StepConcurrencyLevel=self.step_concurrency_level
+ )
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Modify cluster failed: %s' % response)
diff --git a/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py b/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py
index c22920e..19cbddb 100644
--- a/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py
+++ b/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py
@@ -30,16 +30,13 @@ class EmrTerminateJobFlowOperator(BaseOperator):
:param aws_conn_id: aws connection to use
:type aws_conn_id: str
"""
+
template_fields = ['job_flow_id']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
- def __init__(
- self, *,
- job_flow_id,
- aws_conn_id='aws_default',
- **kwargs):
+ def __init__(self, *, job_flow_id, aws_conn_id='aws_default', **kwargs):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py
index a945f4e..991135f 100644
--- a/airflow/providers/amazon/aws/operators/glue.py
+++ b/airflow/providers/amazon/aws/operators/glue.py
@@ -52,25 +52,28 @@ class AwsGlueJobOperator(BaseOperator):
:param iam_role_name: AWS IAM Role for Glue Job Execution
:type iam_role_name: Optional[str]
"""
+
template_fields = ()
template_ext = ()
ui_color = '#ededed'
@apply_defaults
- def __init__(self, *,
- job_name='aws_glue_default_job',
- job_desc='AWS Glue Job with Airflow',
- script_location=None,
- concurrent_run_limit=None,
- script_args=None,
- retry_limit=None,
- num_of_dpus=6,
- aws_conn_id='aws_default',
- region_name=None,
- s3_bucket=None,
- iam_role_name=None,
- **kwargs
- ): # pylint: disable=too-many-arguments
+ def __init__(
+ self,
+ *,
+ job_name='aws_glue_default_job',
+ job_desc='AWS Glue Job with Airflow',
+ script_location=None,
+ concurrent_run_limit=None,
+ script_args=None,
+ retry_limit=None,
+ num_of_dpus=6,
+ aws_conn_id='aws_default',
+ region_name=None,
+ s3_bucket=None,
+ iam_role_name=None,
+ **kwargs,
+ ): # pylint: disable=too-many-arguments
super(AwsGlueJobOperator, self).__init__(**kwargs)
self.job_name = job_name
self.job_desc = job_desc
@@ -96,20 +99,25 @@ class AwsGlueJobOperator(BaseOperator):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
script_name = os.path.basename(self.script_location)
s3_hook.load_file(self.script_location, self.s3_bucket, self.s3_artifcats_prefix + script_name)
- glue_job = AwsGlueJobHook(job_name=self.job_name,
- desc=self.job_desc,
- concurrent_run_limit=self.concurrent_run_limit,
- script_location=self.script_location,
- retry_limit=self.retry_limit,
- num_of_dpus=self.num_of_dpus,
- aws_conn_id=self.aws_conn_id,
- region_name=self.region_name,
- s3_bucket=self.s3_bucket,
- iam_role_name=self.iam_role_name)
+ glue_job = AwsGlueJobHook(
+ job_name=self.job_name,
+ desc=self.job_desc,
+ concurrent_run_limit=self.concurrent_run_limit,
+ script_location=self.script_location,
+ retry_limit=self.retry_limit,
+ num_of_dpus=self.num_of_dpus,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ s3_bucket=self.s3_bucket,
+ iam_role_name=self.iam_role_name,
+ )
self.log.info("Initializing AWS Glue Job: %s", self.job_name)
glue_job_run = glue_job.initialize_job(self.script_args)
glue_job_run = glue_job.job_completion(self.job_name, glue_job_run['JobRunId'])
self.log.info(
"AWS Glue Job: %s status: %s. Run Id: %s",
- self.job_name, glue_job_run['JobRunState'], glue_job_run['JobRunId'])
+ self.job_name,
+ glue_job_run['JobRunState'],
+ glue_job_run['JobRunId'],
+ )
return glue_job_run['JobRunId']
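
A minimal sketch using the constructor arguments shown in the hunk above; job name, script path, bucket and role are hypothetical.

    from airflow.providers.amazon.aws.operators.glue import AwsGlueJobOperator

    glue_run = AwsGlueJobOperator(
        task_id='run_glue_etl',
        job_name='example_glue_job',              # hypothetical job name
        script_location='/tmp/etl_script.py',     # hypothetical local script, uploaded under s3_bucket
        s3_bucket='example-glue-artifacts',       # hypothetical artifacts bucket
        iam_role_name='GlueServiceRole',          # hypothetical execution role
        num_of_dpus=6,
    )
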
diff --git a/airflow/providers/amazon/aws/operators/s3_bucket.py b/airflow/providers/amazon/aws/operators/s3_bucket.py
index f7d9822..a2aa06b 100644
--- a/airflow/providers/amazon/aws/operators/s3_bucket.py
+++ b/airflow/providers/amazon/aws/operators/s3_bucket.py
@@ -40,12 +40,16 @@ class S3CreateBucketOperator(BaseOperator):
:param region_name: AWS region_name. If not specified fetched from connection.
:type region_name: Optional[str]
"""
+
@apply_defaults
- def __init__(self, *,
- bucket_name,
- aws_conn_id: Optional[str] = "aws_default",
- region_name: Optional[str] = None,
- **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ bucket_name,
+ aws_conn_id: Optional[str] = "aws_default",
+ region_name: Optional[str] = None,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.bucket_name = bucket_name
self.region_name = region_name
@@ -76,11 +80,14 @@ class S3DeleteBucketOperator(BaseOperator):
maintained on each worker node).
:type aws_conn_id: Optional[str]
"""
- def __init__(self,
- bucket_name,
- force_delete: Optional[bool] = False,
- aws_conn_id: Optional[str] = "aws_default",
- **kwargs) -> None:
+
+ def __init__(
+ self,
+ bucket_name,
+ force_delete: Optional[bool] = False,
+ aws_conn_id: Optional[str] = "aws_default",
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.bucket_name = bucket_name
self.force_delete = force_delete
diff --git a/airflow/providers/amazon/aws/operators/s3_copy_object.py b/airflow/providers/amazon/aws/operators/s3_copy_object.py
index 8d1dd9c..4b2d290 100644
--- a/airflow/providers/amazon/aws/operators/s3_copy_object.py
+++ b/airflow/providers/amazon/aws/operators/s3_copy_object.py
@@ -64,20 +64,21 @@ class S3CopyObjectOperator(BaseOperator):
:type verify: bool or str
"""
- template_fields = ('source_bucket_key', 'dest_bucket_key',
- 'source_bucket_name', 'dest_bucket_name')
+ template_fields = ('source_bucket_key', 'dest_bucket_key', 'source_bucket_name', 'dest_bucket_name')
@apply_defaults
def __init__(
- self, *,
- source_bucket_key,
- dest_bucket_key,
- source_bucket_name=None,
- dest_bucket_name=None,
- source_version_id=None,
- aws_conn_id='aws_default',
- verify=None,
- **kwargs):
+ self,
+ *,
+ source_bucket_key,
+ dest_bucket_key,
+ source_bucket_name=None,
+ dest_bucket_name=None,
+ source_version_id=None,
+ aws_conn_id='aws_default',
+ verify=None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.source_bucket_key = source_bucket_key
@@ -90,6 +91,10 @@ class S3CopyObjectOperator(BaseOperator):
def execute(self, context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
- s3_hook.copy_object(self.source_bucket_key, self.dest_bucket_key,
- self.source_bucket_name, self.dest_bucket_name,
- self.source_version_id)
+ s3_hook.copy_object(
+ self.source_bucket_key,
+ self.dest_bucket_key,
+ self.source_bucket_name,
+ self.dest_bucket_name,
+ self.source_version_id,
+ )
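
For orientation, a sketch of the operator with the arguments reformatted above; bucket and key names are hypothetical.

    from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator

    copy_report = S3CopyObjectOperator(
        task_id='archive_report',
        source_bucket_name='example-source-bucket',    # hypothetical
        source_bucket_key='reports/2020-08-25.csv',    # hypothetical
        dest_bucket_name='example-archive-bucket',     # hypothetical
        dest_bucket_key='archive/2020-08-25.csv',      # hypothetical
    )
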
diff --git a/airflow/providers/amazon/aws/operators/s3_delete_objects.py b/airflow/providers/amazon/aws/operators/s3_delete_objects.py
index d8c4683..b6d267b 100644
--- a/airflow/providers/amazon/aws/operators/s3_delete_objects.py
+++ b/airflow/providers/amazon/aws/operators/s3_delete_objects.py
@@ -62,14 +62,7 @@ class S3DeleteObjectsOperator(BaseOperator):
template_fields = ('keys', 'bucket', 'prefix')
@apply_defaults
- def __init__(
- self, *,
- bucket,
- keys=None,
- prefix=None,
- aws_conn_id='aws_default',
- verify=None,
- **kwargs):
+ def __init__(self, *, bucket, keys=None, prefix=None, aws_conn_id='aws_default', verify=None, **kwargs):
if not bool(keys) ^ bool(prefix):
raise ValueError("Either keys or prefix should be set.")
diff --git a/airflow/providers/amazon/aws/operators/s3_file_transform.py b/airflow/providers/amazon/aws/operators/s3_file_transform.py
index 4324d20..e2aa822 100644
--- a/airflow/providers/amazon/aws/operators/s3_file_transform.py
+++ b/airflow/providers/amazon/aws/operators/s3_file_transform.py
@@ -84,18 +84,20 @@ class S3FileTransformOperator(BaseOperator):
@apply_defaults
def __init__(
- self, *,
- source_s3_key: str,
- dest_s3_key: str,
- transform_script: Optional[str] = None,
- select_expression=None,
- script_args: Optional[Sequence[str]] = None,
- source_aws_conn_id: str = 'aws_default',
- source_verify: Optional[Union[bool, str]] = None,
- dest_aws_conn_id: str = 'aws_default',
- dest_verify: Optional[Union[bool, str]] = None,
- replace: bool = False,
- **kwargs) -> None:
+ self,
+ *,
+ source_s3_key: str,
+ dest_s3_key: str,
+ transform_script: Optional[str] = None,
+ select_expression=None,
+ script_args: Optional[Sequence[str]] = None,
+ source_aws_conn_id: str = 'aws_default',
+ source_verify: Optional[Union[bool, str]] = None,
+ dest_aws_conn_id: str = 'aws_default',
+ dest_verify: Optional[Union[bool, str]] = None,
+ replace: bool = False,
+ **kwargs,
+ ) -> None:
# pylint: disable=too-many-arguments
super().__init__(**kwargs)
self.source_s3_key = source_s3_key
@@ -112,29 +114,21 @@ class S3FileTransformOperator(BaseOperator):
def execute(self, context):
if self.transform_script is None and self.select_expression is None:
- raise AirflowException(
- "Either transform_script or select_expression must be specified")
+ raise AirflowException("Either transform_script or select_expression must be specified")
source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id, verify=self.source_verify)
dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
self.log.info("Downloading source S3 file %s", self.source_s3_key)
if not source_s3.check_for_key(self.source_s3_key):
- raise AirflowException(
- "The source key {0} does not exist".format(self.source_s3_key))
+ raise AirflowException("The source key {0} does not exist".format(self.source_s3_key))
source_s3_key_object = source_s3.get_key(self.source_s3_key)
with NamedTemporaryFile("wb") as f_source, NamedTemporaryFile("wb") as f_dest:
- self.log.info(
- "Dumping S3 file %s contents to local file %s",
- self.source_s3_key, f_source.name
- )
+ self.log.info("Dumping S3 file %s contents to local file %s", self.source_s3_key, f_source.name)
if self.select_expression is not None:
- content = source_s3.select_key(
- key=self.source_s3_key,
- expression=self.select_expression
- )
+ content = source_s3.select_key(key=self.source_s3_key, expression=self.select_expression)
f_source.write(content.encode("utf-8"))
else:
source_s3_key_object.download_fileobj(Fileobj=f_source)
@@ -145,7 +139,7 @@ class S3FileTransformOperator(BaseOperator):
[self.transform_script, f_source.name, f_dest.name, *self.script_args],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
- close_fds=True
+ close_fds=True,
)
self.log.info("Output:")
@@ -155,13 +149,10 @@ class S3FileTransformOperator(BaseOperator):
process.wait()
if process.returncode:
- raise AirflowException(
- "Transform script failed: {0}".format(process.returncode)
- )
+ raise AirflowException("Transform script failed: {0}".format(process.returncode))
else:
self.log.info(
- "Transform script successful. Output temporarily located at %s",
- f_dest.name
+ "Transform script successful. Output temporarily located at %s", f_dest.name
)
self.log.info("Uploading transformed file to S3")
@@ -169,6 +160,6 @@ class S3FileTransformOperator(BaseOperator):
dest_s3.load_file(
filename=f_dest.name if self.transform_script else f_source.name,
key=self.dest_s3_key,
- replace=self.replace
+ replace=self.replace,
)
self.log.info("Upload successful")
diff --git a/airflow/providers/amazon/aws/operators/s3_list.py b/airflow/providers/amazon/aws/operators/s3_list.py
index 427ff3f..4c25e99 100644
--- a/airflow/providers/amazon/aws/operators/s3_list.py
+++ b/airflow/providers/amazon/aws/operators/s3_list.py
@@ -65,17 +65,12 @@ class S3ListOperator(BaseOperator):
aws_conn_id='aws_customers_conn'
)
"""
+
template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter')
ui_color = '#ffd700'
@apply_defaults
- def __init__(self, *,
- bucket,
- prefix='',
- delimiter='',
- aws_conn_id='aws_default',
- verify=None,
- **kwargs):
+ def __init__(self, *, bucket, prefix='', delimiter='', aws_conn_id='aws_default', verify=None, **kwargs):
super().__init__(**kwargs)
self.bucket = bucket
self.prefix = prefix
@@ -88,10 +83,9 @@ class S3ListOperator(BaseOperator):
self.log.info(
'Getting the list of files from bucket: %s in prefix: %s (Delimiter {%s)',
- self.bucket, self.prefix, self.delimiter
+ self.bucket,
+ self.prefix,
+ self.delimiter,
)
- return hook.list_keys(
- bucket_name=self.bucket,
- prefix=self.prefix,
- delimiter=self.delimiter)
+ return hook.list_keys(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_base.py b/airflow/providers/amazon/aws/operators/sagemaker_base.py
index e5c42ac..19fb921 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_base.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_base.py
@@ -41,10 +41,7 @@ class SageMakerBaseOperator(BaseOperator):
integer_fields = [] # type: Iterable[Iterable[str]]
@apply_defaults
- def __init__(self, *,
- config,
- aws_conn_id='aws_default',
- **kwargs):
+ def __init__(self, *, config, aws_conn_id='aws_default', **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
@@ -81,14 +78,12 @@ class SageMakerBaseOperator(BaseOperator):
for field in self.integer_fields:
self.parse_integer(self.config, field)
- def expand_role(self): # noqa: D402
+ def expand_role(self): # noqa: D402
"""Placeholder for calling boto3's expand_role(), which expands an IAM role name into an ARN."""
def preprocess_config(self):
"""Process the config into a usable form."""
- self.log.info(
- 'Preprocessing the config and doing required s3_operations'
- )
+ self.log.info('Preprocessing the config and doing required s3_operations')
self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.hook.configure_s3_resources(self.config)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
index aa444fa..c7a89f2 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
@@ -71,15 +71,17 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
"""
@apply_defaults
- def __init__(self, *,
- config,
- wait_for_completion=True,
- check_interval=30,
- max_ingestion_time=None,
- operation='create',
- **kwargs):
- super().__init__(config=config,
- **kwargs)
+ def __init__(
+ self,
+ *,
+ config,
+ wait_for_completion=True,
+ check_interval=30,
+ max_ingestion_time=None,
+ operation='create',
+ **kwargs,
+ ):
+ super().__init__(config=config, **kwargs)
self.config = config
self.wait_for_completion = wait_for_completion
@@ -93,9 +95,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
def create_integer_fields(self):
"""Set fields which should be casted to integers."""
if 'EndpointConfig' in self.config:
- self.integer_fields = [
- ['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']
- ]
+ self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]
def expand_role(self):
if 'Model' not in self.config:
@@ -135,7 +135,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
endpoint_info,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time
+ max_ingestion_time=self.max_ingestion_time,
)
except ClientError: # Botocore throws a ClientError if the endpoint is already created
self.operation = 'update'
@@ -145,18 +145,13 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
endpoint_info,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time
+ max_ingestion_time=self.max_ingestion_time,
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(
- 'Sagemaker endpoint creation failed: %s' % response)
+ raise AirflowException('Sagemaker endpoint creation failed: %s' % response)
else:
return {
- 'EndpointConfig': self.hook.describe_endpoint_config(
- endpoint_info['EndpointConfigName']
- ),
- 'Endpoint': self.hook.describe_endpoint(
- endpoint_info['EndpointName']
- )
+ 'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']),
+ 'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']),
}
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
index f1d38bf..9bde451 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
@@ -35,16 +35,11 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
:type aws_conn_id: str
"""
- integer_fields = [
- ['ProductionVariants', 'InitialInstanceCount']
- ]
+ integer_fields = [['ProductionVariants', 'InitialInstanceCount']]
@apply_defaults
- def __init__(self, *,
- config,
- **kwargs):
- super().__init__(config=config,
- **kwargs)
+ def __init__(self, *, config, **kwargs):
+ super().__init__(config=config, **kwargs)
self.config = config
@@ -54,11 +49,6 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
response = self.hook.create_endpoint_config(self.config)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(
- 'Sagemaker endpoint config creation failed: %s' % response)
+ raise AirflowException('Sagemaker endpoint config creation failed: %s' % response)
else:
- return {
- 'EndpointConfig': self.hook.describe_endpoint_config(
- self.config['EndpointConfigName']
- )
- }
+ return {'EndpointConfig': self.hook.describe_endpoint_config(self.config['EndpointConfigName'])}
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_model.py b/airflow/providers/amazon/aws/operators/sagemaker_model.py
index 31e2fbd..122ceee 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_model.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_model.py
@@ -37,11 +37,8 @@ class SageMakerModelOperator(SageMakerBaseOperator):
"""
@apply_defaults
- def __init__(self, *,
- config,
- **kwargs):
- super().__init__(config=config,
- **kwargs)
+ def __init__(self, *, config, **kwargs):
+ super().__init__(config=config, **kwargs)
self.config = config
@@ -58,8 +55,4 @@ class SageMakerModelOperator(SageMakerBaseOperator):
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker model creation failed: %s' % response)
else:
- return {
- 'Model': self.hook.describe_model(
- self.config['ModelName']
- )
- }
+ return {'Model': self.hook.describe_model(self.config['ModelName'])}
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_processing.py b/airflow/providers/amazon/aws/operators/sagemaker_processing.py
index ef2fd69..c1bcac7 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_processing.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_processing.py
@@ -52,15 +52,18 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
"""
@apply_defaults
- def __init__(self, *,
- config,
- aws_conn_id,
- wait_for_completion=True,
- print_log=True,
- check_interval=30,
- max_ingestion_time=None,
- action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8
- **kwargs):
+ def __init__(
+ self,
+ *,
+ config,
+ aws_conn_id,
+ wait_for_completion=True,
+ print_log=True,
+ check_interval=30,
+ max_ingestion_time=None,
+ action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8
+ **kwargs,
+ ):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
if action_if_job_exists not in ("increment", "fail"):
@@ -79,12 +82,10 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
"""Set fields which should be casted to integers."""
self.integer_fields = [
['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
- ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB']
+ ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
]
if 'StoppingCondition' in self.config:
- self.integer_fields += [
- ['StoppingCondition', 'MaxRuntimeInSeconds']
- ]
+ self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']]
def expand_role(self):
if 'RoleArn' in self.config:
@@ -114,12 +115,8 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
self.config,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time
+ max_ingestion_time=self.max_ingestion_time,
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker Processing Job creation failed: %s' % response)
- return {
- 'Processing': self.hook.describe_processing_job(
- self.config['ProcessingJobName']
- )
- }
+ return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])}
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_training.py b/airflow/providers/amazon/aws/operators/sagemaker_training.py
index 9bdbe56..6175a61 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_training.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_training.py
@@ -54,18 +54,21 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
integer_fields = [
['ResourceConfig', 'InstanceCount'],
['ResourceConfig', 'VolumeSizeInGB'],
- ['StoppingCondition', 'MaxRuntimeInSeconds']
+ ['StoppingCondition', 'MaxRuntimeInSeconds'],
]
@apply_defaults
- def __init__(self, *,
- config,
- wait_for_completion=True,
- print_log=True,
- check_interval=30,
- max_ingestion_time=None,
- action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8
- **kwargs):
+ def __init__(
+ self,
+ *,
+ config,
+ wait_for_completion=True,
+ print_log=True,
+ check_interval=30,
+ max_ingestion_time=None,
+ action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8
+ **kwargs,
+ ):
super().__init__(config=config, **kwargs)
self.wait_for_completion = wait_for_completion
@@ -110,13 +113,9 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
wait_for_completion=self.wait_for_completion,
print_log=self.print_log,
check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time
+ max_ingestion_time=self.max_ingestion_time,
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker Training Job creation failed: %s' % response)
else:
- return {
- 'Training': self.hook.describe_training_job(
- self.config['TrainingJobName']
- )
- }
+ return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])}
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_transform.py b/airflow/providers/amazon/aws/operators/sagemaker_transform.py
index 221bf82..7ae8f3a 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_transform.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_transform.py
@@ -62,14 +62,10 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
"""
@apply_defaults
- def __init__(self, *,
- config,
- wait_for_completion=True,
- check_interval=30,
- max_ingestion_time=None,
- **kwargs):
- super().__init__(config=config,
- **kwargs)
+ def __init__(
+ self, *, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None, **kwargs
+ ):
+ super().__init__(config=config, **kwargs)
self.config = config
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
@@ -81,7 +77,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
self.integer_fields = [
['Transform', 'TransformResources', 'InstanceCount'],
['Transform', 'MaxConcurrentTransforms'],
- ['Transform', 'MaxPayloadInMB']
+ ['Transform', 'MaxPayloadInMB'],
]
if 'Transform' not in self.config:
for field in self.integer_fields:
@@ -110,15 +106,12 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
transform_config,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time)
+ max_ingestion_time=self.max_ingestion_time,
+ )
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker transform Job creation failed: %s' % response)
else:
return {
- 'Model': self.hook.describe_model(
- transform_config['ModelName']
- ),
- 'Transform': self.hook.describe_transform_job(
- transform_config['TransformJobName']
- )
+ 'Model': self.hook.describe_model(transform_config['ModelName']),
+ 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']),
}
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
index 1626886..483e541 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
@@ -51,18 +51,14 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
- ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds']
+ ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
]
@apply_defaults
- def __init__(self, *,
- config,
- wait_for_completion=True,
- check_interval=30,
- max_ingestion_time=None,
- **kwargs):
- super().__init__(config=config,
- **kwargs)
+ def __init__(
+ self, *, config, wait_for_completion=True, check_interval=30, max_ingestion_time=None, **kwargs
+ ):
+ super().__init__(config=config, **kwargs)
self.config = config
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
@@ -86,13 +82,9 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
self.config,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
- max_ingestion_time=self.max_ingestion_time
+ max_ingestion_time=self.max_ingestion_time,
)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException('Sagemaker Tuning Job creation failed: %s' % response)
else:
- return {
- 'Tuning': self.hook.describe_tuning_job(
- self.config['HyperParameterTuningJobName']
- )
- }
+ return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])}
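The collapsed single-line returns above are simply what Black emits once an expression fits inside the configured line length. A rough way to reproduce the collapse locally via Black's Python API, assuming a line length of 110 and preserved single quotes (both are assumptions about the project's Black settings, not taken from this diff):

    import black

    src = (
        "value = {\n"
        "    'Tuning': hook.describe_tuning_job(\n"
        "        config['HyperParameterTuningJobName']\n"
        "    )\n"
        "}\n"
    )
    # Collapses the dict literal onto one line because it fits within 110 columns.
    print(black.format_str(src, mode=black.FileMode(line_length=110, string_normalization=False)))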
diff --git a/airflow/providers/amazon/aws/operators/sns.py b/airflow/providers/amazon/aws/operators/sns.py
index 3f24813..8917dfe 100644
--- a/airflow/providers/amazon/aws/operators/sns.py
+++ b/airflow/providers/amazon/aws/operators/sns.py
@@ -39,18 +39,21 @@ class SnsPublishOperator(BaseOperator):
determined automatically)
:type message_attributes: dict
"""
+
template_fields = ['message', 'subject', 'message_attributes']
template_ext = ()
@apply_defaults
def __init__(
- self, *,
- target_arn,
- message,
- aws_conn_id='aws_default',
- subject=None,
- message_attributes=None,
- **kwargs):
+ self,
+ *,
+ target_arn,
+ message,
+ aws_conn_id='aws_default',
+ subject=None,
+ message_attributes=None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.target_arn = target_arn
self.message = message
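As a usage reference for the keyword-only signature above, a minimal DAG sketch; the topic ARN and DAG settings are placeholders, not part of this commit:

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator
    from airflow.utils.dates import days_ago

    with DAG(dag_id='example_sns_publish', schedule_interval=None, start_date=days_ago(1)) as dag:
        notify = SnsPublishOperator(
            task_id='notify',
            target_arn='arn:aws:sns:us-east-1:123456789012:example-topic',  # placeholder ARN
            message='Pipeline finished',
            subject='Airflow notification',
            aws_conn_id='aws_default',
        )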
diff --git a/airflow/providers/amazon/aws/operators/sqs.py b/airflow/providers/amazon/aws/operators/sqs.py
index e0edc3f..00b29db 100644
--- a/airflow/providers/amazon/aws/operators/sqs.py
+++ b/airflow/providers/amazon/aws/operators/sqs.py
@@ -38,17 +38,21 @@ class SQSPublishOperator(BaseOperator):
:param aws_conn_id: AWS connection id (default: aws_default)
:type aws_conn_id: str
"""
+
template_fields = ('sqs_queue', 'message_content', 'delay_seconds')
ui_color = '#6ad3fa'
@apply_defaults
- def __init__(self, *,
- sqs_queue,
- message_content,
- message_attributes=None,
- delay_seconds=0,
- aws_conn_id='aws_default',
- **kwargs):
+ def __init__(
+ self,
+ *,
+ sqs_queue,
+ message_content,
+ message_attributes=None,
+ delay_seconds=0,
+ aws_conn_id='aws_default',
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
@@ -69,10 +73,12 @@ class SQSPublishOperator(BaseOperator):
hook = SQSHook(aws_conn_id=self.aws_conn_id)
- result = hook.send_message(queue_url=self.sqs_queue,
- message_body=self.message_content,
- delay_seconds=self.delay_seconds,
- message_attributes=self.message_attributes)
+ result = hook.send_message(
+ queue_url=self.sqs_queue,
+ message_body=self.message_content,
+ delay_seconds=self.delay_seconds,
+ message_attributes=self.message_attributes,
+ )
self.log.info('result is send_message is %s', result)
diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
index 404ce24..2eaa2c4 100644
--- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
+++ b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
@@ -36,6 +36,7 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator):
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:type aws_conn_id: str
"""
+
template_fields = ['execution_arn']
template_ext = ()
ui_color = '#f9c915'
diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py
index 0b22c88..0d8f446 100644
--- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py
+++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py
@@ -43,15 +43,22 @@ class StepFunctionStartExecutionOperator(BaseOperator):
:param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
:type do_xcom_push: bool
"""
+
template_fields = ['state_machine_arn', 'name', 'input']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
- def __init__(self, *, state_machine_arn: str, name: Optional[str] = None,
- state_machine_input: Union[dict, str, None] = None,
- aws_conn_id='aws_default', region_name=None,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ state_machine_arn: str,
+ name: Optional[str] = None,
+ state_machine_input: Union[dict, str, None] = None,
+ aws_conn_id='aws_default',
+ region_name=None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.state_machine_arn = state_machine_arn
self.name = name
diff --git a/airflow/providers/amazon/aws/secrets/secrets_manager.py b/airflow/providers/amazon/aws/secrets/secrets_manager.py
index 39dd8a7..47a07a9 100644
--- a/airflow/providers/amazon/aws/secrets/secrets_manager.py
+++ b/airflow/providers/amazon/aws/secrets/secrets_manager.py
@@ -70,7 +70,7 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
config_prefix: str = 'airflow/config',
profile_name: Optional[str] = None,
sep: str = "/",
- **kwargs
+ **kwargs,
):
super().__init__()
self.connections_prefix = connections_prefix.rstrip("/")
@@ -85,9 +85,7 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
"""
Create a Secrets Manager client
"""
- session = boto3.session.Session(
- profile_name=self.profile_name,
- )
+ session = boto3.session.Session(profile_name=self.profile_name,)
return session.client(service_name="secretsmanager", **self.kwargs)
def get_conn_uri(self, conn_id: str) -> Optional[str]:
@@ -128,14 +126,13 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
"""
secrets_path = self.build_path(path_prefix, secret_id, self.sep)
try:
- response = self.client.get_secret_value(
- SecretId=secrets_path,
- )
+ response = self.client.get_secret_value(SecretId=secrets_path,)
return response.get('SecretString')
except self.client.exceptions.ResourceNotFoundException:
self.log.debug(
"An error occurred (ResourceNotFoundException) when calling the "
"get_secret_value operation: "
- "Secret %s not found.", secrets_path
+ "Secret %s not found.",
+ secrets_path,
)
return None
diff --git a/airflow/providers/amazon/aws/secrets/systems_manager.py b/airflow/providers/amazon/aws/secrets/systems_manager.py
index 203be35..5e67362 100644
--- a/airflow/providers/amazon/aws/secrets/systems_manager.py
+++ b/airflow/providers/amazon/aws/secrets/systems_manager.py
@@ -57,7 +57,7 @@ class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
connections_prefix: str = '/airflow/connections',
variables_prefix: str = '/airflow/variables',
profile_name: Optional[str] = None,
- **kwargs
+ **kwargs,
):
super().__init__()
self.connections_prefix = connections_prefix.rstrip("/")
@@ -102,14 +102,13 @@ class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
"""
ssm_path = self.build_path(path_prefix, secret_id)
try:
- response = self.client.get_parameter(
- Name=ssm_path, WithDecryption=True
- )
+ response = self.client.get_parameter(Name=ssm_path, WithDecryption=True)
value = response["Parameter"]["Value"]
return value
except self.client.exceptions.ParameterNotFound:
self.log.info(
"An error occurred (ParameterNotFound) when calling the GetParameter operation: "
- "Parameter %s not found.", ssm_path
+ "Parameter %s not found.",
+ ssm_path,
)
return None
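For orientation, the backend above resolves a lookup by building an SSM parameter path from the configured prefix and the secret id, then calling get_parameter with decryption. A hedged sketch of exercising it directly; get_conn_uri is assumed here to be the connection-lookup entry point, mirroring the Secrets Manager backend shown earlier:

    from airflow.providers.amazon.aws.secrets.systems_manager import SystemsManagerParameterStoreBackend

    # Uses the default prefixes from the constructor above; the SSM parameter
    # /airflow/connections/my_aws_conn is assumed to exist.
    backend = SystemsManagerParameterStoreBackend(profile_name=None)
    conn_uri = backend.get_conn_uri('my_aws_conn')   # returns None if the parameter is missing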
diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py
index 50edc8d..40c028a 100644
--- a/airflow/providers/amazon/aws/sensors/athena.py
+++ b/airflow/providers/amazon/aws/sensors/athena.py
@@ -42,8 +42,14 @@ class AthenaSensor(BaseSensorOperator):
:type sleep_time: int
"""
- INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
- FAILURE_STATES = ('FAILED', 'CANCELLED',)
+ INTERMEDIATE_STATES = (
+ 'QUEUED',
+ 'RUNNING',
+ )
+ FAILURE_STATES = (
+ 'FAILED',
+ 'CANCELLED',
+ )
SUCCESS_STATES = ('SUCCEEDED',)
template_fields = ['query_execution_id']
@@ -51,12 +57,15 @@ class AthenaSensor(BaseSensorOperator):
ui_color = '#66c3ff'
@apply_defaults
- def __init__(self, *,
- query_execution_id: str,
- max_retries: Optional[int] = None,
- aws_conn_id: str = 'aws_default',
- sleep_time: int = 10,
- **kwargs: Any) -> None:
+ def __init__(
+ self,
+ *,
+ query_execution_id: str,
+ max_retries: Optional[int] = None,
+ aws_conn_id: str = 'aws_default',
+ sleep_time: int = 10,
+ **kwargs: Any,
+ ) -> None:
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.query_execution_id = query_execution_id
diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py b/airflow/providers/amazon/aws/sensors/cloud_formation.py
index 05f15a3..739a133 100644
--- a/airflow/providers/amazon/aws/sensors/cloud_formation.py
+++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py
@@ -40,11 +40,7 @@ class CloudFormationCreateStackSensor(BaseSensorOperator):
ui_color = '#C5CAE9'
@apply_defaults
- def __init__(self, *,
- stack_name,
- aws_conn_id='aws_default',
- region_name=None,
- **kwargs):
+ def __init__(self, *, stack_name, aws_conn_id='aws_default', region_name=None, **kwargs):
super().__init__(**kwargs)
self.stack_name = stack_name
self.hook = AWSCloudFormationHook(aws_conn_id=aws_conn_id, region_name=region_name)
@@ -75,11 +71,7 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
ui_color = '#C5CAE9'
@apply_defaults
- def __init__(self, *,
- stack_name,
- aws_conn_id='aws_default',
- region_name=None,
- **kwargs):
+ def __init__(self, *, stack_name, aws_conn_id='aws_default', region_name=None, **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
@@ -97,7 +89,5 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
def get_hook(self):
"""Create and return an AWSCloudFormationHook"""
if not self.hook:
- self.hook = AWSCloudFormationHook(
- aws_conn_id=self.aws_conn_id,
- region_name=self.region_name)
+ self.hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
return self.hook
diff --git a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py b/airflow/providers/amazon/aws/sensors/ec2_instance_state.py
index c2a53c8..7e55d7d 100644
--- a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py
+++ b/airflow/providers/amazon/aws/sensors/ec2_instance_state.py
@@ -43,12 +43,15 @@ class EC2InstanceStateSensor(BaseSensorOperator):
valid_states = ["running", "stopped", "terminated"]
@apply_defaults
- def __init__(self, *,
- target_state: str,
- instance_id: str,
- aws_conn_id: str = "aws_default",
- region_name: Optional[str] = None,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ target_state: str,
+ instance_id: str,
+ aws_conn_id: str = "aws_default",
+ region_name: Optional[str] = None,
+ **kwargs,
+ ):
if target_state not in self.valid_states:
raise ValueError(f"Invalid target_state: {target_state}")
super().__init__(**kwargs)
@@ -58,12 +61,7 @@ class EC2InstanceStateSensor(BaseSensorOperator):
self.region_name = region_name
def poke(self, context):
- ec2_hook = EC2Hook(
- aws_conn_id=self.aws_conn_id,
- region_name=self.region_name
- )
- instance_state = ec2_hook.get_instance_state(
- instance_id=self.instance_id
- )
+ ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+ instance_state = ec2_hook.get_instance_state(instance_id=self.instance_id)
self.log.info("instance state: %s", instance_state)
return instance_state == self.target_state
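A brief usage sketch for the sensor signature above, meant to sit inside a DAG definition; the instance id is a placeholder, and target_state must be one of the valid_states listed in the class:

    from airflow.providers.amazon.aws.sensors.ec2_instance_state import EC2InstanceStateSensor

    # inside a `with DAG(...)` block
    wait_for_running = EC2InstanceStateSensor(
        task_id='wait_for_instance_running',
        instance_id='i-0123456789abcdef0',   # placeholder instance id
        target_state='running',              # must be in valid_states
        aws_conn_id='aws_default',
        poke_interval=30,
    )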
diff --git a/airflow/providers/amazon/aws/sensors/emr_base.py b/airflow/providers/amazon/aws/sensors/emr_base.py
index d487af2..f05197b 100644
--- a/airflow/providers/amazon/aws/sensors/emr_base.py
+++ b/airflow/providers/amazon/aws/sensors/emr_base.py
@@ -38,13 +38,11 @@ class EmrBaseSensor(BaseSensorOperator):
:param aws_conn_id: aws connection to uses
:type aws_conn_id: str
"""
+
ui_color = '#66c3ff'
@apply_defaults
- def __init__(
- self, *,
- aws_conn_id='aws_default',
- **kwargs):
+ def __init__(self, *, aws_conn_id='aws_default', **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.target_states = None # will be set in subclasses
@@ -86,8 +84,7 @@ class EmrBaseSensor(BaseSensorOperator):
:return: response
:rtype: dict[str, Any]
"""
- raise NotImplementedError(
- 'Please implement get_emr_response() in subclass')
+ raise NotImplementedError('Please implement get_emr_response() in subclass')
@staticmethod
def state_from_response(response: Dict[str, Any]) -> str:
@@ -99,8 +96,7 @@ class EmrBaseSensor(BaseSensorOperator):
:return: state
:rtype: str
"""
- raise NotImplementedError(
- 'Please implement state_from_response() in subclass')
+ raise NotImplementedError('Please implement state_from_response() in subclass')
@staticmethod
def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]:
@@ -112,5 +108,4 @@ class EmrBaseSensor(BaseSensorOperator):
:return: failure message
:rtype: Optional[str]
"""
- raise NotImplementedError(
- 'Please implement failure_message_from_response() in subclass')
+ raise NotImplementedError('Please implement failure_message_from_response() in subclass')
diff --git a/airflow/providers/amazon/aws/sensors/emr_job_flow.py b/airflow/providers/amazon/aws/sensors/emr_job_flow.py
index 004b8b8..c08e9db 100644
--- a/airflow/providers/amazon/aws/sensors/emr_job_flow.py
+++ b/airflow/providers/amazon/aws/sensors/emr_job_flow.py
@@ -46,11 +46,14 @@ class EmrJobFlowSensor(EmrBaseSensor):
template_ext = ()
@apply_defaults
- def __init__(self, *,
- job_flow_id: str,
- target_states: Optional[Iterable[str]] = None,
- failed_states: Optional[Iterable[str]] = None,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ job_flow_id: str,
+ target_states: Optional[Iterable[str]] = None,
+ failed_states: Optional[Iterable[str]] = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.target_states = target_states or ['TERMINATED']
@@ -97,6 +100,6 @@ class EmrJobFlowSensor(EmrBaseSensor):
state_change_reason = cluster_status.get('StateChangeReason')
if state_change_reason:
return 'for code: {} with message {}'.format(
- state_change_reason.get('Code', 'No code'),
- state_change_reason.get('Message', 'Unknown'))
+ state_change_reason.get('Code', 'No code'), state_change_reason.get('Message', 'Unknown')
+ )
return None
diff --git a/airflow/providers/amazon/aws/sensors/emr_step.py b/airflow/providers/amazon/aws/sensors/emr_step.py
index 65394c8..f3c3d59 100644
--- a/airflow/providers/amazon/aws/sensors/emr_step.py
+++ b/airflow/providers/amazon/aws/sensors/emr_step.py
@@ -41,23 +41,24 @@ class EmrStepSensor(EmrBaseSensor):
:type failed_states: list[str]
"""
- template_fields = ['job_flow_id', 'step_id',
- 'target_states', 'failed_states']
+ template_fields = ['job_flow_id', 'step_id', 'target_states', 'failed_states']
template_ext = ()
@apply_defaults
- def __init__(self, *,
- job_flow_id: str,
- step_id: str,
- target_states: Optional[Iterable[str]] = None,
- failed_states: Optional[Iterable[str]] = None,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ job_flow_id: str,
+ step_id: str,
+ target_states: Optional[Iterable[str]] = None,
+ failed_states: Optional[Iterable[str]] = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.step_id = step_id
self.target_states = target_states or ['COMPLETED']
- self.failed_states = failed_states or ['CANCELLED', 'FAILED',
- 'INTERRUPTED']
+ self.failed_states = failed_states or ['CANCELLED', 'FAILED', 'INTERRUPTED']
def get_emr_response(self) -> Dict[str, Any]:
"""
@@ -71,12 +72,8 @@ class EmrStepSensor(EmrBaseSensor):
"""
emr_client = self.get_hook().get_conn()
- self.log.info('Poking step %s on cluster %s',
- self.step_id,
- self.job_flow_id)
- return emr_client.describe_step(
- ClusterId=self.job_flow_id,
- StepId=self.step_id)
+ self.log.info('Poking step %s on cluster %s', self.step_id, self.job_flow_id)
+ return emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)
@staticmethod
def state_from_response(response: Dict[str, Any]) -> str:
@@ -103,7 +100,6 @@ class EmrStepSensor(EmrBaseSensor):
fail_details = response['Step']['Status'].get('FailureDetails')
if fail_details:
return 'for reason {} with message {} and log file {}'.format(
- fail_details.get('Reason'),
- fail_details.get('Message'),
- fail_details.get('LogFile'))
+ fail_details.get('Reason'), fail_details.get('Message'), fail_details.get('LogFile')
+ )
return None
diff --git a/airflow/providers/amazon/aws/sensors/glue.py b/airflow/providers/amazon/aws/sensors/glue.py
index 9539761..7b2ce30 100644
--- a/airflow/providers/amazon/aws/sensors/glue.py
+++ b/airflow/providers/amazon/aws/sensors/glue.py
@@ -32,14 +32,11 @@ class AwsGlueJobSensor(BaseSensorOperator):
:param run_id: The AWS Glue current running job identifier
:type run_id: str
"""
+
template_fields = ('job_name', 'run_id')
@apply_defaults
- def __init__(self, *,
- job_name,
- run_id,
- aws_conn_id='aws_default',
- **kwargs):
+ def __init__(self, *, job_name, run_id, aws_conn_id='aws_default', **kwargs):
super().__init__(**kwargs)
self.job_name = job_name
self.run_id = run_id
@@ -49,9 +46,7 @@ class AwsGlueJobSensor(BaseSensorOperator):
def poke(self, context):
hook = AwsGlueJobHook(aws_conn_id=self.aws_conn_id)
- self.log.info(
- "Poking for job run status :"
- "for Glue Job %s and ID %s", self.job_name, self.run_id)
+ self.log.info("Poking for job run status :" "for Glue Job %s and ID %s", self.job_name, self.run_id)
job_state = hook.get_job_state(job_name=self.job_name, run_id=self.run_id)
if job_state in self.success_states:
self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state)
diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
index 5d900ab..f1df94d 100644
--- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
+++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -47,19 +47,27 @@ class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
between each tries
:type poke_interval: int
"""
- template_fields = ('database_name', 'table_name', 'expression',)
+
+ template_fields = (
+ 'database_name',
+ 'table_name',
+ 'expression',
+ )
ui_color = '#C5CAE9'
@apply_defaults
- def __init__(self, *,
- table_name, expression="ds='{{ ds }}'",
- aws_conn_id='aws_default',
- region_name=None,
- database_name='default',
- poke_interval=60 * 3,
- **kwargs):
- super().__init__(
- poke_interval=poke_interval, **kwargs)
+ def __init__(
+ self,
+ *,
+ table_name,
+ expression="ds='{{ ds }}'",
+ aws_conn_id='aws_default',
+ region_name=None,
+ database_name='default',
+ poke_interval=60 * 3,
+ **kwargs,
+ ):
+ super().__init__(poke_interval=poke_interval, **kwargs)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.table_name = table_name
@@ -77,15 +85,12 @@ class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
'Poking for table %s. %s, expression %s', self.database_name, self.table_name, self.expression
)
- return self.get_hook().check_for_partition(
- self.database_name, self.table_name, self.expression)
+ return self.get_hook().check_for_partition(self.database_name, self.table_name, self.expression)
def get_hook(self):
"""
Gets the AwsGlueCatalogHook
"""
if not self.hook:
- self.hook = AwsGlueCatalogHook(
- aws_conn_id=self.aws_conn_id,
- region_name=self.region_name)
+ self.hook = AwsGlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
return self.hook
diff --git a/airflow/providers/amazon/aws/sensors/redshift.py b/airflow/providers/amazon/aws/sensors/redshift.py
index 0c893ca..37f3521 100644
--- a/airflow/providers/amazon/aws/sensors/redshift.py
+++ b/airflow/providers/amazon/aws/sensors/redshift.py
@@ -30,14 +30,11 @@ class AwsRedshiftClusterSensor(BaseSensorOperator):
:param target_status: The cluster status desired.
:type target_status: str
"""
+
template_fields = ('cluster_identifier', 'target_status')
@apply_defaults
- def __init__(self, *,
- cluster_identifier,
- target_status='available',
- aws_conn_id='aws_default',
- **kwargs):
+ def __init__(self, *, cluster_identifier, target_status='available', aws_conn_id='aws_default', **kwargs):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.target_status = target_status
@@ -45,8 +42,7 @@ class AwsRedshiftClusterSensor(BaseSensorOperator):
self.hook = None
def poke(self, context):
- self.log.info('Poking for status : %s\nfor cluster %s',
- self.target_status, self.cluster_identifier)
+ self.log.info('Poking for status : %s\nfor cluster %s', self.target_status, self.cluster_identifier)
return self.get_hook().cluster_status(self.cluster_identifier) == self.target_status
def get_hook(self):
diff --git a/airflow/providers/amazon/aws/sensors/s3_key.py b/airflow/providers/amazon/aws/sensors/s3_key.py
index 2661daa..0c0f6e3 100644
--- a/airflow/providers/amazon/aws/sensors/s3_key.py
+++ b/airflow/providers/amazon/aws/sensors/s3_key.py
@@ -55,16 +55,20 @@ class S3KeySensor(BaseSensorOperator):
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
+
template_fields = ('bucket_key', 'bucket_name')
@apply_defaults
- def __init__(self, *,
- bucket_key,
- bucket_name=None,
- wildcard_match=False,
- aws_conn_id='aws_default',
- verify=None,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ bucket_key,
+ bucket_name=None,
+ wildcard_match=False,
+ aws_conn_id='aws_default',
+ verify=None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
# Parse
if bucket_name is None:
@@ -77,9 +81,11 @@ class S3KeySensor(BaseSensorOperator):
else:
parsed_url = urlparse(bucket_key)
if parsed_url.scheme != '' or parsed_url.netloc != '':
- raise AirflowException('If bucket_name is provided, bucket_key' +
- ' should be relative path from root' +
- ' level, rather than a full s3:// url')
+ raise AirflowException(
+ 'If bucket_name is provided, bucket_key'
+ + ' should be relative path from root'
+ + ' level, rather than a full s3:// url'
+ )
self.bucket_name = bucket_name
self.bucket_key = bucket_key
self.wildcard_match = wildcard_match
@@ -90,9 +96,7 @@ class S3KeySensor(BaseSensorOperator):
def poke(self, context):
self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key)
if self.wildcard_match:
- return self.get_hook().check_for_wildcard_key(
- self.bucket_key,
- self.bucket_name)
+ return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name)
return self.get_hook().check_for_key(self.bucket_key, self.bucket_name)
def get_hook(self):
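The parsing logic in the hunk above accepts either a full s3:// URL in bucket_key (when bucket_name is omitted) or a bucket_name plus a key relative to the bucket root. A short sketch of both forms, with placeholder bucket and key names, intended for use inside a DAG definition:

    from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor

    # Full URL form: bucket_name is parsed out of bucket_key.
    wait_full_url = S3KeySensor(
        task_id='wait_for_key_url',
        bucket_key='s3://my-bucket/data/{{ ds }}/part-*.csv',
        wildcard_match=True,
    )

    # Split form: bucket_key must be relative, not another s3:// URL.
    wait_split = S3KeySensor(
        task_id='wait_for_key_split',
        bucket_name='my-bucket',
        bucket_key='data/{{ ds }}/report.csv',
    )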
diff --git a/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py b/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py
index 95a2148..f1f3d4e 100644
--- a/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py
+++ b/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py
@@ -72,16 +72,19 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
template_fields = ('bucket_name', 'prefix')
@apply_defaults
- def __init__(self, *,
- bucket_name: str,
- prefix: str,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
- inactivity_period: float = 60 * 60,
- min_objects: int = 1,
- previous_objects: Optional[Set[str]] = None,
- allow_delete: bool = True,
- **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ bucket_name: str,
+ prefix: str,
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[bool, str]] = None,
+ inactivity_period: float = 60 * 60,
+ min_objects: int = 1,
+ previous_objects: Optional[Set[str]] = None,
+ allow_delete: bool = True,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
@@ -117,8 +120,10 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
if current_objects > self.previous_objects:
# When new objects arrived, reset the inactivity_seconds
# and update previous_objects for the next poke.
- self.log.info("New objects found at %s, resetting last_activity_time.",
- os.path.join(self.bucket, self.prefix))
+ self.log.info(
+ "New objects found at %s, resetting last_activity_time.",
+ os.path.join(self.bucket, self.prefix),
+ )
self.log.debug("New objects: %s", current_objects - self.previous_objects)
self.last_activity_time = datetime.now()
self.inactivity_seconds = 0
@@ -131,12 +136,17 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
deleted_objects = self.previous_objects - current_objects
self.previous_objects = current_objects
self.last_activity_time = datetime.now()
- self.log.info("Objects were deleted during the last poke interval. Updating the "
- "file counter and resetting last_activity_time:\n%s", deleted_objects)
+ self.log.info(
+ "Objects were deleted during the last poke interval. Updating the "
+ "file counter and resetting last_activity_time:\n%s",
+ deleted_objects,
+ )
return False
- raise AirflowException("Illegal behavior: objects were deleted in %s between pokes."
- % os.path.join(self.bucket, self.prefix))
+ raise AirflowException(
+ "Illegal behavior: objects were deleted in %s between pokes."
+ % os.path.join(self.bucket, self.prefix)
+ )
if self.last_activity_time:
self.inactivity_seconds = int((datetime.now() - self.last_activity_time).total_seconds())
@@ -149,9 +159,13 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
path = os.path.join(self.bucket, self.prefix)
if current_num_objects >= self.min_objects:
- self.log.info("SUCCESS: \nSensor found %s objects at %s.\n"
- "Waited at least %s seconds, with no new objects uploaded.",
- current_num_objects, path, self.inactivity_period)
+ self.log.info(
+ "SUCCESS: \nSensor found %s objects at %s.\n"
+ "Waited at least %s seconds, with no new objects uploaded.",
+ current_num_objects,
+ path,
+ self.inactivity_period,
+ )
return True
self.log.error("FAILURE: Inactivity Period passed, not enough objects found in %s", path)
diff --git a/airflow/providers/amazon/aws/sensors/s3_prefix.py b/airflow/providers/amazon/aws/sensors/s3_prefix.py
index acaf961..4dc4900 100644
--- a/airflow/providers/amazon/aws/sensors/s3_prefix.py
+++ b/airflow/providers/amazon/aws/sensors/s3_prefix.py
@@ -51,16 +51,13 @@ class S3PrefixSensor(BaseSensorOperator):
CA cert bundle than the one used by botocore.
:type verify: bool or str
"""
+
template_fields = ('prefix', 'bucket_name')
@apply_defaults
- def __init__(self, *,
- bucket_name,
- prefix,
- delimiter='/',
- aws_conn_id='aws_default',
- verify=None,
- **kwargs):
+ def __init__(
+ self, *, bucket_name, prefix, delimiter='/', aws_conn_id='aws_default', verify=None, **kwargs
+ ):
super().__init__(**kwargs)
# Parse
self.bucket_name = bucket_name
@@ -74,9 +71,8 @@ class S3PrefixSensor(BaseSensorOperator):
def poke(self, context):
self.log.info('Poking for prefix : %s in bucket s3://%s', self.prefix, self.bucket_name)
return self.get_hook().check_for_prefix(
- prefix=self.prefix,
- delimiter=self.delimiter,
- bucket_name=self.bucket_name)
+ prefix=self.prefix, delimiter=self.delimiter, bucket_name=self.bucket_name
+ )
def get_hook(self):
"""Create and return an S3Hook"""
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_base.py b/airflow/providers/amazon/aws/sensors/sagemaker_base.py
index b3468df..6704b1a 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_base.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_base.py
@@ -28,13 +28,11 @@ class SageMakerBaseSensor(BaseSensorOperator):
and state_from_response() methods.
Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods.
"""
+
ui_color = '#ededed'
@apply_defaults
- def __init__(
- self, *,
- aws_conn_id='aws_default',
- **kwargs):
+ def __init__(self, *, aws_conn_id='aws_default', **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.hook = None
@@ -61,8 +59,7 @@ class SageMakerBaseSensor(BaseSensorOperator):
if state in self.failed_states():
failed_reason = self.get_failed_reason_from_response(response)
- raise AirflowException('Sagemaker job failed for the following reason: %s'
- % failed_reason)
+ raise AirflowException('Sagemaker job failed for the following reason: %s' % failed_reason)
return True
def non_terminal_states(self):
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py b/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
index b8df5bf..1a1b6f7 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
@@ -34,9 +34,7 @@ class SageMakerEndpointSensor(SageMakerBaseSensor):
template_ext = ()
@apply_defaults
- def __init__(self, *,
- endpoint_name,
- **kwargs):
+ def __init__(self, *, endpoint_name, **kwargs):
super().__init__(**kwargs)
self.endpoint_name = endpoint_name
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_training.py b/airflow/providers/amazon/aws/sensors/sagemaker_training.py
index 1695d95..36403b8 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_training.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_training.py
@@ -38,10 +38,7 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
template_ext = ()
@apply_defaults
- def __init__(self, *,
- job_name,
- print_log=True,
- **kwargs):
+ def __init__(self, *, job_name, print_log=True, **kwargs):
super().__init__(**kwargs)
self.job_name = job_name
self.print_log = print_log
@@ -75,20 +72,27 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
if self.print_log:
if not self.log_resource_inited:
self.init_log_resource(self.get_hook())
- self.state, self.last_description, self.last_describe_job_call = \
- self.get_hook().describe_training_job_with_log(self.job_name,
- self.positions, self.stream_names,
- self.instance_count, self.state,
- self.last_description,
- self.last_describe_job_call)
+ (
+ self.state,
+ self.last_description,
+ self.last_describe_job_call,
+ ) = self.get_hook().describe_training_job_with_log(
+ self.job_name,
+ self.positions,
+ self.stream_names,
+ self.instance_count,
+ self.state,
+ self.last_description,
+ self.last_describe_job_call,
+ )
else:
self.last_description = self.get_hook().describe_training_job(self.job_name)
status = self.state_from_response(self.last_description)
if status not in self.non_terminal_states() and status not in self.failed_states():
- billable_time = \
- (self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']) * \
- self.last_description['ResourceConfig']['InstanceCount']
+ billable_time = (
+ self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']
+ ) * self.last_description['ResourceConfig']['InstanceCount']
self.log.info('Billable seconds: %s', int(billable_time.total_seconds()) + 1)
return self.last_description
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
index 5a9ffdc..4108c98 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
@@ -35,9 +35,7 @@ class SageMakerTransformSensor(SageMakerBaseSensor):
template_ext = ()
@apply_defaults
- def __init__(self, *,
- job_name,
- **kwargs):
+ def __init__(self, *, job_name, **kwargs):
super().__init__(**kwargs)
self.job_name = job_name
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
index 6b97807..794695b 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
@@ -35,9 +35,7 @@ class SageMakerTuningSensor(SageMakerBaseSensor):
template_ext = ()
@apply_defaults
- def __init__(self, *,
- job_name,
- **kwargs):
+ def __init__(self, *, job_name, **kwargs):
super().__init__(**kwargs)
self.job_name = job_name
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index 573981b..2d1ab54 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -44,12 +44,9 @@ class SQSSensor(BaseSensorOperator):
template_fields = ('sqs_queue', 'max_messages')
@apply_defaults
- def __init__(self, *,
- sqs_queue,
- aws_conn_id='aws_default',
- max_messages=5,
- wait_time_seconds=1,
- **kwargs):
+ def __init__(
+ self, *, sqs_queue, aws_conn_id='aws_default', max_messages=5, wait_time_seconds=1, **kwargs
+ ):
super().__init__(**kwargs)
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
@@ -69,25 +66,29 @@ class SQSSensor(BaseSensorOperator):
self.log.info('SQSSensor checking for message on queue: %s', self.sqs_queue)
- messages = sqs_conn.receive_message(QueueUrl=self.sqs_queue,
- MaxNumberOfMessages=self.max_messages,
- WaitTimeSeconds=self.wait_time_seconds)
+ messages = sqs_conn.receive_message(
+ QueueUrl=self.sqs_queue,
+ MaxNumberOfMessages=self.max_messages,
+ WaitTimeSeconds=self.wait_time_seconds,
+ )
self.log.info("received message %s", str(messages))
if 'Messages' in messages and messages['Messages']:
- entries = [{'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']}
- for message in messages['Messages']]
+ entries = [
+ {'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']}
+ for message in messages['Messages']
+ ]
- result = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue,
- Entries=entries)
+ result = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
if 'Successful' in result:
context['ti'].xcom_push(key='messages', value=messages)
return True
else:
raise AirflowException(
- 'Delete SQS Messages failed ' + str(result) + ' for messages ' + str(messages))
+ 'Delete SQS Messages failed ' + str(result) + ' for messages ' + str(messages)
+ )
return False
diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py
index a0e640e..6126670 100644
--- a/airflow/providers/amazon/aws/sensors/step_function_execution.py
+++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py
@@ -39,7 +39,11 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
"""
INTERMEDIATE_STATES = ('RUNNING',)
- FAILURE_STATES = ('FAILED', 'TIMED_OUT', 'ABORTED',)
+ FAILURE_STATES = (
+ 'FAILED',
+ 'TIMED_OUT',
+ 'ABORTED',
+ )
SUCCESS_STATES = ('SUCCEEDED',)
template_fields = ['execution_arn']
@@ -47,8 +51,7 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
ui_color = '#66c3ff'
@apply_defaults
- def __init__(self, *, execution_arn: str, aws_conn_id='aws_default', region_name=None,
- **kwargs):
+ def __init__(self, *, execution_arn: str, aws_conn_id='aws_default', region_name=None, **kwargs):
super().__init__(**kwargs)
self.execution_arn = execution_arn
self.aws_conn_id = aws_conn_id
diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
index 7f71a54..40bb026 100644
--- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -42,9 +42,7 @@ def _upload_file_to_s3(file_obj, bucket_name, s3_key_prefix):
s3_client = S3Hook().get_conn()
file_obj.seek(0)
s3_client.upload_file(
- Filename=file_obj.name,
- Bucket=bucket_name,
- Key=s3_key_prefix + str(uuid4()),
+ Filename=file_obj.name, Bucket=bucket_name, Key=s3_key_prefix + str(uuid4()),
)
@@ -92,14 +90,17 @@ class DynamoDBToS3Operator(BaseOperator):
"""
@apply_defaults
- def __init__(self, *,
- dynamodb_table_name: str,
- s3_bucket_name: str,
- file_size: int,
- dynamodb_scan_kwargs: Optional[Dict[str, Any]] = None,
- s3_key_prefix: str = '',
- process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ dynamodb_table_name: str,
+ s3_bucket_name: str,
+ file_size: int,
+ dynamodb_scan_kwargs: Optional[Dict[str, Any]] = None,
+ s3_key_prefix: str = '',
+ process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.file_size = file_size
self.process_func = process_func
@@ -139,8 +140,7 @@ class DynamoDBToS3Operator(BaseOperator):
# Upload the file to S3 if reach file size limit
if getsize(temp_file.name) >= self.file_size:
- _upload_file_to_s3(temp_file, self.s3_bucket_name,
- self.s3_key_prefix)
+ _upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix)
temp_file.close()
temp_file = NamedTemporaryFile()
return temp_file
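A usage sketch for the transfer operator above, to be placed inside a DAG definition; the table and bucket names are placeholders, and file_size is the byte threshold at which the scanned rows are flushed to S3 (compare the getsize check in the hunk above):

    from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator

    # inside a `with DAG(...)` block
    backup_table = DynamoDBToS3Operator(
        task_id='backup_users_table',
        dynamodb_table_name='users',            # placeholder table
        s3_bucket_name='my-backup-bucket',      # placeholder bucket
        s3_key_prefix='dynamodb/users/',
        file_size=20 * 1024 * 1024,             # flush to S3 roughly every 20 MB
    )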
diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
index 29b04dc..212c978 100644
--- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -83,30 +83,42 @@ class GCSToS3Operator(BaseOperator):
account from the list granting this role to the originating account (templated).
:type google_impersonation_chain: Union[str, Sequence[str]]
"""
- template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter', 'dest_s3_key',
- 'google_impersonation_chain',)
+
+ template_fields: Iterable[str] = (
+ 'bucket',
+ 'prefix',
+ 'delimiter',
+ 'dest_s3_key',
+ 'google_impersonation_chain',
+ )
ui_color = '#f0eee4'
@apply_defaults
- def __init__(self, *, # pylint: disable=too-many-arguments
- bucket,
- prefix=None,
- delimiter=None,
- gcp_conn_id='google_cloud_default',
- google_cloud_storage_conn_id=None,
- delegate_to=None,
- dest_aws_conn_id=None,
- dest_s3_key=None,
- dest_verify=None,
- replace=False,
- google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs):
+ def __init__(
+ self,
+ *, # pylint: disable=too-many-arguments
+ bucket,
+ prefix=None,
+ delimiter=None,
+ gcp_conn_id='google_cloud_default',
+ google_cloud_storage_conn_id=None,
+ delegate_to=None,
+ dest_aws_conn_id=None,
+ dest_s3_key=None,
+ dest_verify=None,
+ replace=False,
+ google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
if google_cloud_storage_conn_id:
warnings.warn(
"The google_cloud_storage_conn_id parameter has been deprecated. You should pass "
- "the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3)
+ "the gcp_conn_id parameter.",
+ DeprecationWarning,
+ stacklevel=3,
+ )
gcp_conn_id = google_cloud_storage_conn_id
self.bucket = bucket
@@ -128,12 +140,14 @@ class GCSToS3Operator(BaseOperator):
impersonation_chain=self.google_impersonation_chain,
)
- self.log.info('Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s',
- self.bucket, self.delimiter, self.prefix)
+ self.log.info(
+ 'Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s',
+ self.bucket,
+ self.delimiter,
+ self.prefix,
+ )
- files = hook.list(bucket_name=self.bucket,
- prefix=self.prefix,
- delimiter=self.delimiter)
+ files = hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
@@ -159,9 +173,7 @@ class GCSToS3Operator(BaseOperator):
dest_key = self.dest_s3_key + file
self.log.info("Saving file to %s", dest_key)
- s3_hook.load_bytes(file_bytes,
- key=dest_key,
- replace=self.replace)
+ s3_hook.load_bytes(file_bytes, key=dest_key, replace=self.replace)
self.log.info("All done, uploaded %d files to S3", len(files))
else:
diff --git a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
index 85695de..ca17bed 100644
--- a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -96,7 +96,8 @@ class GoogleApiToS3Operator(BaseOperator):
@apply_defaults
def __init__(
- self, *,
+ self,
+ *,
google_api_service_name,
google_api_service_version,
google_api_endpoint_path,
@@ -112,7 +113,7 @@ class GoogleApiToS3Operator(BaseOperator):
delegate_to=None,
aws_conn_id='aws_default',
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- **kwargs
+ **kwargs,
):
super().__init__(**kwargs)
self.google_api_service_name = google_api_service_name
@@ -162,22 +163,20 @@ class GoogleApiToS3Operator(BaseOperator):
endpoint=self.google_api_endpoint_path,
data=self.google_api_endpoint_params,
paginate=self.google_api_pagination,
- num_retries=self.google_api_num_retries
+ num_retries=self.google_api_num_retries,
)
return google_api_response
def _load_data_to_s3(self, data):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
s3_hook.load_string(
- string_data=json.dumps(data),
- key=self.s3_destination_key,
- replace=self.s3_overwrite
+ string_data=json.dumps(data), key=self.s3_destination_key, replace=self.s3_overwrite
)
def _update_google_api_endpoint_params_via_xcom(self, task_instance):
google_api_endpoint_params = task_instance.xcom_pull(
task_ids=self.google_api_endpoint_params_via_xcom_task_ids,
- key=self.google_api_endpoint_params_via_xcom
+ key=self.google_api_endpoint_params_via_xcom,
)
self.google_api_endpoint_params.update(google_api_endpoint_params)
diff --git a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
index 3eecaca..0ca7218 100644
--- a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
+++ b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
@@ -62,18 +62,20 @@ class HiveToDynamoDBOperator(BaseOperator):
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
- self, *,
- sql,
- table_name,
- table_keys,
- pre_process=None,
- pre_process_args=None,
- pre_process_kwargs=None,
- region_name=None,
- schema='default',
- hiveserver2_conn_id='hiveserver2_default',
- aws_conn_id='aws_default',
- **kwargs):
+ self,
+ *,
+ sql,
+ table_name,
+ table_keys,
+ pre_process=None,
+ pre_process_args=None,
+ pre_process_kwargs=None,
+ region_name=None,
+ schema='default',
+ hiveserver2_conn_id='hiveserver2_default',
+ aws_conn_id='aws_default',
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.sql = sql
self.table_name = table_name
@@ -93,20 +95,20 @@ class HiveToDynamoDBOperator(BaseOperator):
self.log.info(self.sql)
data = hive.get_pandas_df(self.sql, schema=self.schema)
- dynamodb = AwsDynamoDBHook(aws_conn_id=self.aws_conn_id,
- table_name=self.table_name,
- table_keys=self.table_keys,
- region_name=self.region_name)
+ dynamodb = AwsDynamoDBHook(
+ aws_conn_id=self.aws_conn_id,
+ table_name=self.table_name,
+ table_keys=self.table_keys,
+ region_name=self.region_name,
+ )
self.log.info('Inserting rows into dynamodb')
if self.pre_process is None:
- dynamodb.write_batch_data(
- json.loads(data.to_json(orient='records')))
+ dynamodb.write_batch_data(json.loads(data.to_json(orient='records')))
else:
dynamodb.write_batch_data(
- self.pre_process(data=data,
- args=self.pre_process_args,
- kwargs=self.pre_process_kwargs))
+ self.pre_process(data=data, args=self.pre_process_args, kwargs=self.pre_process_kwargs)
+ )
self.log.info('Done.')
diff --git a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
index 79505f1..bf65b8f 100644
--- a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
@@ -50,19 +50,23 @@ class ImapAttachmentToS3Operator(BaseOperator):
:param s3_conn_id: The reference to the s3 connection details.
:type s3_conn_id: str
"""
+
template_fields = ('imap_attachment_name', 's3_key', 'imap_mail_filter')
@apply_defaults
- def __init__(self, *,
- imap_attachment_name,
- s3_key,
- imap_check_regex=False,
- imap_mail_folder='INBOX',
- imap_mail_filter='All',
- s3_overwrite=False,
- imap_conn_id='imap_default',
- s3_conn_id='aws_default',
- **kwargs):
+ def __init__(
+ self,
+ *,
+ imap_attachment_name,
+ s3_key,
+ imap_check_regex=False,
+ imap_mail_folder='INBOX',
+ imap_mail_filter='All',
+ s3_overwrite=False,
+ imap_conn_id='imap_default',
+ s3_conn_id='aws_default',
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.imap_attachment_name = imap_attachment_name
self.s3_key = s3_key
@@ -82,7 +86,8 @@ class ImapAttachmentToS3Operator(BaseOperator):
"""
self.log.info(
'Transferring mail attachment %s from mail server via imap to s3 key %s...',
- self.imap_attachment_name, self.s3_key
+ self.imap_attachment_name,
+ self.s3_key,
)
with ImapHook(imap_conn_id=self.imap_conn_id) as imap_hook:
@@ -95,6 +100,4 @@ class ImapAttachmentToS3Operator(BaseOperator):
)
s3_hook = S3Hook(aws_conn_id=self.s3_conn_id)
- s3_hook.load_bytes(bytes_data=imap_mail_attachments[0][1],
- key=self.s3_key,
- replace=self.s3_overwrite)
+ s3_hook.load_bytes(bytes_data=imap_mail_attachments[0][1], key=self.s3_key, replace=self.s3_overwrite)
diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
index 214689c..b996e10 100644
--- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -41,16 +41,19 @@ class MongoToS3Operator(BaseOperator):
# pylint: disable=too-many-instance-attributes
@apply_defaults
- def __init__(self, *,
- mongo_conn_id,
- s3_conn_id,
- mongo_collection,
- mongo_query,
- s3_bucket,
- s3_key,
- mongo_db=None,
- replace=False,
- **kwargs):
+ def __init__(
+ self,
+ *,
+ mongo_conn_id,
+ s3_conn_id,
+ mongo_collection,
+ mongo_query,
+ s3_bucket,
+ s3_key,
+ mongo_db=None,
+ replace=False,
+ **kwargs,
+ ):
super().__init__(**kwargs)
# Conn Ids
self.mongo_conn_id = mongo_conn_id
@@ -78,14 +81,12 @@ class MongoToS3Operator(BaseOperator):
results = MongoHook(self.mongo_conn_id).aggregate(
mongo_collection=self.mongo_collection,
aggregate_query=self.mongo_query,
- mongo_db=self.mongo_db
+ mongo_db=self.mongo_db,
)
else:
results = MongoHook(self.mongo_conn_id).find(
- mongo_collection=self.mongo_collection,
- query=self.mongo_query,
- mongo_db=self.mongo_db
+ mongo_collection=self.mongo_collection, query=self.mongo_query, mongo_db=self.mongo_db
)
# Performs transform then stringifies the docs results into json format
@@ -93,10 +94,7 @@ class MongoToS3Operator(BaseOperator):
# Load Into S3
s3_conn.load_string(
- string_data=docs_str,
- key=self.s3_key,
- bucket_name=self.s3_bucket,
- replace=self.replace
+ string_data=docs_str, key=self.s3_key, bucket_name=self.s3_bucket, replace=self.replace
)
return True
@@ -107,9 +105,7 @@ class MongoToS3Operator(BaseOperator):
Takes an iterable (pymongo Cursor or Array) containing dictionaries and
returns a stringified version using python join
"""
- return joinable.join(
- [json.dumps(doc, default=json_util.default) for doc in iterable]
- )
+ return joinable.join([json.dumps(doc, default=json_util.default) for doc in iterable])
@staticmethod
def transform(docs):
diff --git a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py
index 249e4b2..7a376f1 100644
--- a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py
@@ -63,22 +63,27 @@ class MySQLToS3Operator(BaseOperator):
:type header: bool
"""
- template_fields = ('s3_key', 'query',)
+ template_fields = (
+ 's3_key',
+ 'query',
+ )
template_ext = ('.sql',)
@apply_defaults
def __init__(
- self, *,
- query: str,
- s3_bucket: str,
- s3_key: str,
- mysql_conn_id: str = 'mysql_default',
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
- pd_csv_kwargs: Optional[dict] = None,
- index: Optional[bool] = False,
- header: Optional[bool] = False,
- **kwargs) -> None:
+ self,
+ *,
+ query: str,
+ s3_bucket: str,
+ s3_key: str,
+ mysql_conn_id: str = 'mysql_default',
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[bool, str]] = None,
+ pd_csv_kwargs: Optional[dict] = None,
+ index: Optional[bool] = False,
+ header: Optional[bool] = False,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.query = query
self.s3_bucket = s3_bucket
@@ -116,9 +121,7 @@ class MySQLToS3Operator(BaseOperator):
self._fix_int_dtypes(data_df)
with NamedTemporaryFile(mode='r+', suffix='.csv') as tmp_csv:
data_df.to_csv(tmp_csv.name, **self.pd_csv_kwargs)
- s3_conn.load_file(filename=tmp_csv.name,
- key=self.s3_key,
- bucket_name=self.s3_bucket)
+ s3_conn.load_file(filename=tmp_csv.name, key=self.s3_key, bucket_name=self.s3_bucket)
if s3_conn.check_for_key(self.s3_key, bucket_name=self.s3_bucket):
file_location = os.path.join(self.s3_bucket, self.s3_key)
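A minimal usage sketch of the MySQLToS3Operator signature reformatted above; the query, bucket and key values are placeholders, and pd_csv_kwargs is forwarded to DataFrame.to_csv() as shown in execute():

from airflow import DAG
from airflow.utils.dates import days_ago
from airflow.providers.amazon.aws.transfers.mysql_to_s3 import MySQLToS3Operator

with DAG(dag_id="example_mysql_to_s3", start_date=days_ago(1), schedule_interval=None) as dag:
    orders_to_s3 = MySQLToS3Operator(
        task_id="orders_to_s3",
        query="SELECT * FROM orders WHERE order_date = '{{ ds }}'",  # 'query' and 's3_key' are template fields
        s3_bucket="my-data-lake",
        s3_key="exports/orders/{{ ds }}.csv",
        mysql_conn_id="mysql_default",
        aws_conn_id="aws_default",
        pd_csv_kwargs={"sep": ","},
        index=False,
        header=True,
    )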
diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index 9f1b113..3a3e6c2 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -71,19 +71,21 @@ class RedshiftToS3Operator(BaseOperator):
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
- self, *,
- schema: str,
- table: str,
- s3_bucket: str,
- s3_key: str,
- redshift_conn_id: str = 'redshift_default',
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
- unload_options: Optional[List] = None,
- autocommit: bool = False,
- include_header: bool = False,
- table_as_file_name: bool = True, # Set to True by default for not breaking current workflows
- **kwargs) -> None:
+ self,
+ *,
+ schema: str,
+ table: str,
+ s3_bucket: str,
+ s3_key: str,
+ redshift_conn_id: str = 'redshift_default',
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[bool, str]] = None,
+ unload_options: Optional[List] = None,
+ autocommit: bool = False,
+ include_header: bool = False,
+ table_as_file_name: bool = True, # Set to True by default for not breaking current workflows
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.schema = schema
self.table = table
@@ -98,7 +100,9 @@ class RedshiftToS3Operator(BaseOperator):
self.table_as_file_name = table_as_file_name
if self.include_header and 'HEADER' not in [uo.upper().strip() for uo in self.unload_options]:
- self.unload_options = list(self.unload_options) + ['HEADER', ]
+ self.unload_options = list(self.unload_options) + [
+ 'HEADER',
+ ]
def execute(self, context):
postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
@@ -114,12 +118,14 @@ class RedshiftToS3Operator(BaseOperator):
with credentials
'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
{unload_options};
- """.format(select_query=select_query,
- s3_bucket=self.s3_bucket,
- s3_key=s3_key,
- access_key=credentials.access_key,
- secret_key=credentials.secret_key,
- unload_options=unload_options)
+ """.format(
+ select_query=select_query,
+ s3_bucket=self.s3_bucket,
+ s3_key=s3_key,
+ access_key=credentials.access_key,
+ secret_key=credentials.secret_key,
+ unload_options=unload_options,
+ )
self.log.info('Executing UNLOAD command...')
postgres_hook.run(unload_query, self.autocommit)
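Likewise for RedshiftToS3Operator, a hedged usage sketch of the reformatted signature; schema, table and S3 locations are placeholders, and include_header=True appends 'HEADER' to unload_options exactly as the constructor above does:

from airflow import DAG
from airflow.utils.dates import days_ago
from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

with DAG(dag_id="example_redshift_to_s3", start_date=days_ago(1), schedule_interval=None) as dag:
    unload_events = RedshiftToS3Operator(
        task_id="unload_events",
        schema="public",
        table="events",
        s3_bucket="my-data-lake",
        s3_key="unload/events/",
        redshift_conn_id="redshift_default",
        aws_conn_id="aws_default",
        unload_options=["CSV", "PARALLEL OFF"],
        include_header=True,
        table_as_file_name=True,   # keep the default to avoid changing existing key layouts
    )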
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 1ddbeae..3b2afd7 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -64,17 +64,19 @@ class S3ToRedshiftOperator(BaseOperator):
@apply_defaults
def __init__(
- self, *,
- schema: str,
- table: str,
- s3_bucket: str,
- s3_key: str,
- redshift_conn_id: str = 'redshift_default',
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
- copy_options: Optional[List] = None,
- autocommit: bool = False,
- **kwargs) -> None:
+ self,
+ *,
+ schema: str,
+ table: str,
+ s3_bucket: str,
+ s3_key: str,
+ redshift_conn_id: str = 'redshift_default',
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[bool, str]] = None,
+ copy_options: Optional[List] = None,
+ autocommit: bool = False,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.schema = schema
self.table = table
@@ -100,13 +102,15 @@ class S3ToRedshiftOperator(BaseOperator):
with credentials
'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
{copy_options};
- """.format(schema=self.schema,
- table=self.table,
- s3_bucket=self.s3_bucket,
- s3_key=self.s3_key,
- access_key=credentials.access_key,
- secret_key=credentials.secret_key,
- copy_options=copy_options)
+ """.format(
+ schema=self.schema,
+ table=self.table,
+ s3_bucket=self.s3_bucket,
+ s3_key=self.s3_key,
+ access_key=credentials.access_key,
+ secret_key=credentials.secret_key,
+ copy_options=copy_options,
+ )
self.log.info('Executing COPY command...')
self._postgres_hook.run(copy_query, self.autocommit)
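And the mirror-image COPY side; again a sketch with placeholder names, where copy_options feed directly into the COPY statement formatted above:

from airflow import DAG
from airflow.utils.dates import days_ago
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

with DAG(dag_id="example_s3_to_redshift", start_date=days_ago(1), schedule_interval=None) as dag:
    copy_events = S3ToRedshiftOperator(
        task_id="copy_events",
        schema="public",
        table="events",
        s3_bucket="my-data-lake",
        s3_key="unload/events/",
        redshift_conn_id="redshift_default",
        aws_conn_id="aws_default",
        copy_options=["CSV", "IGNOREHEADER 1"],
    )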
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
index fd9246d..fe87c69 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
@@ -49,13 +49,9 @@ class S3ToSFTPOperator(BaseOperator):
template_fields = ('s3_key', 'sftp_path')
@apply_defaults
- def __init__(self, *,
- s3_bucket,
- s3_key,
- sftp_path,
- sftp_conn_id='ssh_default',
- s3_conn_id='aws_default',
- **kwargs):
+ def __init__(
+ self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs
+ ):
super().__init__(**kwargs)
self.sftp_conn_id = sftp_conn_id
self.sftp_path = sftp_path
diff --git a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py
index c1b6e65..087eb74 100644
--- a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py
@@ -49,13 +49,9 @@ class SFTPToS3Operator(BaseOperator):
template_fields = ('s3_key', 'sftp_path')
@apply_defaults
- def __init__(self, *,
- s3_bucket,
- s3_key,
- sftp_path,
- sftp_conn_id='ssh_default',
- s3_conn_id='aws_default',
- **kwargs):
+ def __init__(
+ self, *, s3_bucket, s3_key, sftp_path, sftp_conn_id='ssh_default', s3_conn_id='aws_default', **kwargs
+ ):
super().__init__(**kwargs)
self.sftp_conn_id = sftp_conn_id
self.sftp_path = sftp_path
@@ -80,9 +76,4 @@ class SFTPToS3Operator(BaseOperator):
with NamedTemporaryFile("w") as f:
sftp_client.get(self.sftp_path, f.name)
- s3_hook.load_file(
- filename=f.name,
- key=self.s3_key,
- bucket_name=self.s3_bucket,
- replace=True
- )
+ s3_hook.load_file(filename=f.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=True)
diff --git a/airflow/providers/apache/cassandra/example_dags/example_cassandra_dag.py b/airflow/providers/apache/cassandra/example_dags/example_cassandra_dag.py
index ce73634..b28fa84 100644
--- a/airflow/providers/apache/cassandra/example_dags/example_cassandra_dag.py
+++ b/airflow/providers/apache/cassandra/example_dags/example_cassandra_dag.py
@@ -34,7 +34,7 @@ with DAG(
default_args=args,
schedule_interval=None,
start_date=days_ago(2),
- tags=['example']
+ tags=['example'],
) as dag:
# [START howto_operator_cassandra_table_sensor]
table_sensor = CassandraTableSensor(
diff --git a/airflow/providers/apache/cassandra/hooks/cassandra.py b/airflow/providers/apache/cassandra/hooks/cassandra.py
index b360885..71aea78 100644
--- a/airflow/providers/apache/cassandra/hooks/cassandra.py
+++ b/airflow/providers/apache/cassandra/hooks/cassandra.py
@@ -25,7 +25,10 @@ from typing import Any, Dict, Union
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster, Session
from cassandra.policies import (
- DCAwareRoundRobinPolicy, RoundRobinPolicy, TokenAwarePolicy, WhiteListRoundRobinPolicy,
+ DCAwareRoundRobinPolicy,
+ RoundRobinPolicy,
+ TokenAwarePolicy,
+ WhiteListRoundRobinPolicy,
)
from airflow.hooks.base_hook import BaseHook
@@ -81,6 +84,7 @@ class CassandraHook(BaseHook, LoggingMixin):
For details of the Cluster config, see cassandra.cluster.
"""
+
def __init__(self, cassandra_conn_id: str = 'cassandra_default'):
super().__init__()
conn = self.get_connection(cassandra_conn_id)
@@ -93,8 +97,7 @@ class CassandraHook(BaseHook, LoggingMixin):
conn_config['port'] = int(conn.port)
if conn.login:
- conn_config['auth_provider'] = PlainTextAuthProvider(
- username=conn.login, password=conn.password)
+ conn_config['auth_provider'] = PlainTextAuthProvider(username=conn.login, password=conn.password)
policy_name = conn.extra_dejson.get('load_balancing_policy', None)
policy_args = conn.extra_dejson.get('load_balancing_policy_args', {})
@@ -158,17 +161,17 @@ class CassandraHook(BaseHook, LoggingMixin):
return WhiteListRoundRobinPolicy(hosts)
if policy_name == 'TokenAwarePolicy':
- allowed_child_policies = ('RoundRobinPolicy',
- 'DCAwareRoundRobinPolicy',
- 'WhiteListRoundRobinPolicy',)
- child_policy_name = policy_args.get('child_load_balancing_policy',
- 'RoundRobinPolicy')
+ allowed_child_policies = (
+ 'RoundRobinPolicy',
+ 'DCAwareRoundRobinPolicy',
+ 'WhiteListRoundRobinPolicy',
+ )
+ child_policy_name = policy_args.get('child_load_balancing_policy', 'RoundRobinPolicy')
child_policy_args = policy_args.get('child_load_balancing_policy_args', {})
if child_policy_name not in allowed_child_policies:
return TokenAwarePolicy(RoundRobinPolicy())
else:
- child_policy = CassandraHook.get_lb_policy(child_policy_name,
- child_policy_args)
+ child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args)
return TokenAwarePolicy(child_policy)
# Fallback to default RoundRobinPolicy
@@ -186,8 +189,7 @@ class CassandraHook(BaseHook, LoggingMixin):
if '.' in table:
keyspace, table = table.split('.', 1)
cluster_metadata = self.get_conn().cluster.metadata
- return (keyspace in cluster_metadata.keyspaces and
- table in cluster_metadata.keyspaces[keyspace].tables)
+ return keyspace in cluster_metadata.keyspaces and table in cluster_metadata.keyspaces[keyspace].tables
def record_exists(self, table: str, keys: Dict[str, str]) -> bool:
"""
diff --git a/airflow/providers/apache/cassandra/sensors/record.py b/airflow/providers/apache/cassandra/sensors/record.py
index ea67ac5..bc61b29 100644
--- a/airflow/providers/apache/cassandra/sensors/record.py
+++ b/airflow/providers/apache/cassandra/sensors/record.py
@@ -53,6 +53,7 @@ class CassandraRecordSensor(BaseSensorOperator):
when connecting to Cassandra cluster
:type cassandra_conn_id: str
"""
+
template_fields = ('table', 'keys')
@apply_defaults
diff --git a/airflow/providers/apache/cassandra/sensors/table.py b/airflow/providers/apache/cassandra/sensors/table.py
index 82cd411..64129d7 100644
--- a/airflow/providers/apache/cassandra/sensors/table.py
+++ b/airflow/providers/apache/cassandra/sensors/table.py
@@ -51,6 +51,7 @@ class CassandraTableSensor(BaseSensorOperator):
when connecting to Cassandra cluster
:type cassandra_conn_id: str
"""
+
template_fields = ('table',)
@apply_defaults
diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py
index 3dbc5b9..b609c4a 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -49,7 +49,7 @@ class DruidHook(BaseHook):
self,
druid_ingest_conn_id: str = 'druid_ingest_default',
timeout: int = 1,
- max_ingestion_time: Optional[int] = None
+ max_ingestion_time: Optional[int] = None,
) -> None:
super().__init__()
@@ -71,7 +71,8 @@ class DruidHook(BaseHook):
conn_type = 'http' if not conn.conn_type else conn.conn_type
endpoint = conn.extra_dejson.get('endpoint', '')
return "{conn_type}://{host}:{port}/{endpoint}".format(
- conn_type=conn_type, host=host, port=port, endpoint=endpoint)
+ conn_type=conn_type, host=host, port=port, endpoint=endpoint
+ )
def get_auth(self) -> Optional[requests.auth.HTTPBasicAuth]:
"""
@@ -96,8 +97,7 @@ class DruidHook(BaseHook):
self.log.info("Druid ingestion spec: %s", json_index_spec)
req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth())
if req_index.status_code != 200:
- raise AirflowException('Did not get 200 when '
- 'submitting the Druid job to {}'.format(url))
+ raise AirflowException('Did not get 200 when ' 'submitting the Druid job to {}'.format(url))
req_json = req_index.json()
# Wait until the job is completed
@@ -115,8 +115,7 @@ class DruidHook(BaseHook):
if self.max_ingestion_time and sec > self.max_ingestion_time:
# ensure that the job gets killed if the max ingestion time is exceeded
requests.post("{0}/{1}/shutdown".format(url, druid_task_id), auth=self.get_auth())
- raise AirflowException('Druid ingestion took more than '
- f'{self.max_ingestion_time} seconds')
+ raise AirflowException('Druid ingestion took more than ' f'{self.max_ingestion_time} seconds')
time.sleep(self.timeout)
@@ -128,8 +127,7 @@ class DruidHook(BaseHook):
elif status == 'SUCCESS':
running = False # Great success!
elif status == 'FAILED':
- raise AirflowException('Druid indexing job failed, '
- 'check console for more info')
+ raise AirflowException('Druid indexing job failed, ' 'check console for more info')
else:
raise AirflowException(f'Could not get status of the job, got {status}')
@@ -143,6 +141,7 @@ class DruidDbApiHook(DbApiHook):
This hook is purely for users to query druid broker.
For ingestion, please use druidHook.
"""
+
conn_name_attr = 'druid_broker_conn_id'
default_conn_name = 'druid_broker_default'
supports_autocommit = False
@@ -158,7 +157,7 @@ class DruidDbApiHook(DbApiHook):
path=conn.extra_dejson.get('endpoint', '/druid/v2/sql'),
scheme=conn.extra_dejson.get('schema', 'http'),
user=conn.login,
- password=conn.password
+ password=conn.password,
)
self.log.info('Get the connection to druid broker on %s using user %s', conn.host, conn.login)
return druid_broker_conn
@@ -175,14 +174,18 @@ class DruidDbApiHook(DbApiHook):
host += ':{port}'.format(port=conn.port)
conn_type = 'druid' if not conn.conn_type else conn.conn_type
endpoint = conn.extra_dejson.get('endpoint', 'druid/v2/sql')
- return '{conn_type}://{host}/{endpoint}'.format(
- conn_type=conn_type, host=host, endpoint=endpoint)
+ return '{conn_type}://{host}/{endpoint}'.format(conn_type=conn_type, host=host, endpoint=endpoint)
def set_autocommit(self, conn: connect, autocommit: bool) -> NotImplemented:
raise NotImplementedError()
- def insert_rows(self, table: str, rows: Iterable[Tuple[str]],
- target_fields: Optional[Iterable[str]] = None,
- commit_every: int = 1000, replace: bool = False,
- **kwargs: Any) -> NotImplemented:
+ def insert_rows(
+ self,
+ table: str,
+ rows: Iterable[Tuple[str]],
+ target_fields: Optional[Iterable[str]] = None,
+ commit_every: int = 1000,
+ replace: bool = False,
+ **kwargs: Any,
+ ) -> NotImplemented:
raise NotImplementedError()
diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py
index f046ff1..1ad665d 100644
--- a/airflow/providers/apache/druid/operators/druid.py
+++ b/airflow/providers/apache/druid/operators/druid.py
@@ -34,23 +34,25 @@ class DruidOperator(BaseOperator):
accepts index jobs
:type druid_ingest_conn_id: str
"""
+
template_fields = ('json_index_file',)
template_ext = ('.json',)
@apply_defaults
- def __init__(self, *, json_index_file: str,
- druid_ingest_conn_id: str = 'druid_ingest_default',
- max_ingestion_time: Optional[int] = None,
- **kwargs: Any) -> None:
+ def __init__(
+ self,
+ *,
+ json_index_file: str,
+ druid_ingest_conn_id: str = 'druid_ingest_default',
+ max_ingestion_time: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
super().__init__(**kwargs)
self.json_index_file = json_index_file
self.conn_id = druid_ingest_conn_id
self.max_ingestion_time = max_ingestion_time
def execute(self, context: Dict[Any, Any]) -> None:
- hook = DruidHook(
- druid_ingest_conn_id=self.conn_id,
- max_ingestion_time=self.max_ingestion_time
- )
+ hook = DruidHook(druid_ingest_conn_id=self.conn_id, max_ingestion_time=self.max_ingestion_time)
self.log.info("Submitting %s", self.json_index_file)
hook.submit_indexing_job(json.loads(self.json_index_file))
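A short usage sketch for the reformatted DruidOperator; the spec filename and timeout are placeholders, and because '.json' is a template extension the file contents are rendered before json.loads() runs in execute():

from airflow import DAG
from airflow.utils.dates import days_ago
from airflow.providers.apache.druid.operators.druid import DruidOperator

with DAG(dag_id="example_druid_ingest", start_date=days_ago(1), schedule_interval=None) as dag:
    submit_ingest = DruidOperator(
        task_id="submit_ingest",
        json_index_file="ingestion_spec.json",       # templated; looked up on the DAG's template search path
        druid_ingest_conn_id="druid_ingest_default",
        max_ingestion_time=3600,                     # seconds before the hook shuts the ingestion task down
    )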
diff --git a/airflow/providers/apache/druid/operators/druid_check.py b/airflow/providers/apache/druid/operators/druid_check.py
index 2f6114d..1263788 100644
--- a/airflow/providers/apache/druid/operators/druid_check.py
+++ b/airflow/providers/apache/druid/operators/druid_check.py
@@ -58,11 +58,7 @@ class DruidCheckOperator(CheckOperator):
@apply_defaults
def __init__(
- self,
- *,
- sql: str,
- druid_broker_conn_id: str = 'druid_broker_default',
- **kwargs: Any
+ self, *, sql: str, druid_broker_conn_id: str = 'druid_broker_default', **kwargs: Any
) -> None:
super().__init__(sql=sql, **kwargs)
self.druid_broker_conn_id = druid_broker_conn_id
diff --git a/airflow/providers/apache/druid/transfers/hive_to_druid.py b/airflow/providers/apache/druid/transfers/hive_to_druid.py
index 595db0f..36d5ff4 100644
--- a/airflow/providers/apache/druid/transfers/hive_to_druid.py
+++ b/airflow/providers/apache/druid/transfers/hive_to_druid.py
@@ -84,7 +84,8 @@ class HiveToDruidOperator(BaseOperator):
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
- self, *,
+ self,
+ *,
sql: str,
druid_datasource: str,
ts_dim: str,
@@ -100,7 +101,7 @@ class HiveToDruidOperator(BaseOperator):
segment_granularity: str = "DAY",
hive_tblproperties: Optional[Dict[Any, Any]] = None,
job_properties: Optional[Dict[Any, Any]] = None,
- **kwargs: Any
+ **kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.sql = sql
@@ -111,9 +112,7 @@ class HiveToDruidOperator(BaseOperator):
self.target_partition_size = target_partition_size
self.query_granularity = query_granularity
self.segment_granularity = segment_granularity
- self.metric_spec = metric_spec or [{
- "name": "count",
- "type": "count"}]
+ self.metric_spec = metric_spec or [{"name": "count", "type": "count"}]
self.hive_cli_conn_id = hive_cli_conn_id
self.hadoop_dependency_coordinates = hadoop_dependency_coordinates
self.druid_ingest_conn_id = druid_ingest_conn_id
@@ -126,9 +125,7 @@ class HiveToDruidOperator(BaseOperator):
self.log.info("Extracting data from Hive")
hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_')
sql = self.sql.strip().strip(';')
- tblproperties = ''.join([", '{}' = '{}'"
- .format(k, v)
- for k, v in self.hive_tblproperties.items()])
+ tblproperties = ''.join([", '{}' = '{}'".format(k, v) for k, v in self.hive_tblproperties.items()])
hql = f"""\
SET mapred.output.compress=false;
SET hive.exec.compress.output=false;
@@ -155,10 +152,7 @@ class HiveToDruidOperator(BaseOperator):
druid = DruidHook(druid_ingest_conn_id=self.druid_ingest_conn_id)
try:
- index_spec = self.construct_ingest_query(
- static_path=static_path,
- columns=columns,
- )
+ index_spec = self.construct_ingest_query(static_path=static_path, columns=columns,)
self.log.info("Inserting rows into Druid, hdfs path: %s", static_path)
@@ -166,15 +160,11 @@ class HiveToDruidOperator(BaseOperator):
self.log.info("Load seems to have succeeded!")
finally:
- self.log.info(
- "Cleaning up by dropping the temp Hive table %s",
- hive_table
- )
+ self.log.info("Cleaning up by dropping the temp Hive table %s", hive_table)
hql = "DROP TABLE IF EXISTS {}".format(hive_table)
hive.run_cli(hql)
- def construct_ingest_query(self, static_path: str,
- columns: List[str]) -> Dict[str, Any]:
+ def construct_ingest_query(self, static_path: str, columns: List[str]) -> Dict[str, Any]:
"""
Builds an ingest query for an HDFS TSV load.
@@ -219,16 +209,13 @@ class HiveToDruidOperator(BaseOperator):
"dimensionsSpec": {
"dimensionExclusions": [],
"dimensions": dimensions, # list of names
- "spatialDimensions": []
+ "spatialDimensions": [],
},
- "timestampSpec": {
- "column": self.ts_dim,
- "format": "auto"
- },
- "format": "tsv"
- }
+ "timestampSpec": {"column": self.ts_dim, "format": "auto"},
+ "format": "tsv",
+ },
},
- "dataSource": self.druid_datasource
+ "dataSource": self.druid_datasource,
},
"tuningConfig": {
"type": "hadoop",
@@ -243,22 +230,14 @@ class HiveToDruidOperator(BaseOperator):
"numShards": num_shards,
},
},
- "ioConfig": {
- "inputSpec": {
- "paths": static_path,
- "type": "static"
- },
- "type": "hadoop"
- }
- }
+ "ioConfig": {"inputSpec": {"paths": static_path, "type": "static"}, "type": "hadoop"},
+ },
}
if self.job_properties:
- ingest_query_dict['spec']['tuningConfig']['jobProperties'] \
- .update(self.job_properties)
+ ingest_query_dict['spec']['tuningConfig']['jobProperties'].update(self.job_properties)
if self.hadoop_dependency_coordinates:
- ingest_query_dict['hadoopDependencyCoordinates'] \
- = self.hadoop_dependency_coordinates
+ ingest_query_dict['hadoopDependencyCoordinates'] = self.hadoop_dependency_coordinates
return ingest_query_dict
diff --git a/airflow/providers/apache/hdfs/hooks/hdfs.py b/airflow/providers/apache/hdfs/hooks/hdfs.py
index 61b3772..e13a5c7 100644
--- a/airflow/providers/apache/hdfs/hooks/hdfs.py
+++ b/airflow/providers/apache/hdfs/hooks/hdfs.py
@@ -46,18 +46,17 @@ class HDFSHook(BaseHook):
:type autoconfig: bool
"""
- def __init__(self,
- hdfs_conn_id: str = 'hdfs_default',
- proxy_user: Optional[str] = None,
- autoconfig: bool = False
- ):
+ def __init__(
+ self, hdfs_conn_id: str = 'hdfs_default', proxy_user: Optional[str] = None, autoconfig: bool = False
+ ):
super().__init__()
if not snakebite_loaded:
raise ImportError(
'This HDFSHook implementation requires snakebite, but '
'snakebite is not compatible with Python 3 '
'(as of August 2015). Please use Python 2 if you require '
- 'this hook -- or help by submitting a PR!')
+ 'this hook -- or help by submitting a PR!'
+ )
self.hdfs_conn_id = hdfs_conn_id
self.proxy_user = proxy_user
self.autoconfig = autoconfig
@@ -78,29 +77,34 @@ class HDFSHook(BaseHook):
if not effective_user:
effective_user = connections[0].login
if not autoconfig:
- autoconfig = connections[0].extra_dejson.get('autoconfig',
- False)
- hdfs_namenode_principal = connections[0].extra_dejson.get(
- 'hdfs_namenode_principal')
+ autoconfig = connections[0].extra_dejson.get('autoconfig', False)
+ hdfs_namenode_principal = connections[0].extra_dejson.get('hdfs_namenode_principal')
except AirflowException:
if not autoconfig:
raise
if autoconfig:
# will read config info from $HADOOP_HOME conf files
- client = AutoConfigClient(effective_user=effective_user,
- use_sasl=use_sasl)
+ client = AutoConfigClient(effective_user=effective_user, use_sasl=use_sasl)
elif len(connections) == 1:
- client = Client(connections[0].host, connections[0].port,
- effective_user=effective_user, use_sasl=use_sasl,
- hdfs_namenode_principal=hdfs_namenode_principal)
+ client = Client(
+ connections[0].host,
+ connections[0].port,
+ effective_user=effective_user,
+ use_sasl=use_sasl,
+ hdfs_namenode_principal=hdfs_namenode_principal,
+ )
elif len(connections) > 1:
name_node = [Namenode(conn.host, conn.port) for conn in connections]
- client = HAClient(name_node, effective_user=effective_user,
- use_sasl=use_sasl,
- hdfs_namenode_principal=hdfs_namenode_principal)
+ client = HAClient(
+ name_node,
+ effective_user=effective_user,
+ use_sasl=use_sasl,
+ hdfs_namenode_principal=hdfs_namenode_principal,
+ )
else:
- raise HDFSHookException("conn_id doesn't exist in the repository "
- "and autoconfig is not specified")
+ raise HDFSHookException(
+ "conn_id doesn't exist in the repository " "and autoconfig is not specified"
+ )
return client
diff --git a/airflow/providers/apache/hdfs/hooks/webhdfs.py b/airflow/providers/apache/hdfs/hooks/webhdfs.py
index a72c7b0..bc24601 100644
--- a/airflow/providers/apache/hdfs/hooks/webhdfs.py
+++ b/airflow/providers/apache/hdfs/hooks/webhdfs.py
@@ -52,9 +52,7 @@ class WebHDFSHook(BaseHook):
:type proxy_user: str
"""
- def __init__(self, webhdfs_conn_id: str = 'webhdfs_default',
- proxy_user: Optional[str] = None
- ):
+ def __init__(self, webhdfs_conn_id: str = 'webhdfs_default', proxy_user: Optional[str] = None):
super().__init__()
self.webhdfs_conn_id = webhdfs_conn_id
self.proxy_user = proxy_user
@@ -88,8 +86,9 @@ class WebHDFSHook(BaseHook):
self.log.error("Could not connect to %s:%s", connection.host, connection.port)
host_socket.close()
except HdfsError as hdfs_error:
- self.log.error('Read operation on namenode %s failed with error: %s',
- connection.host, hdfs_error)
+ self.log.error(
+ 'Read operation on namenode %s failed with error: %s', connection.host, hdfs_error
+ )
return None
def _get_client(self, connection: Connection) -> Any:
@@ -117,9 +116,9 @@ class WebHDFSHook(BaseHook):
status = conn.status(hdfs_path, strict=False)
return bool(status)
- def load_file(self, source: str, destination: str,
- overwrite: bool = True, parallelism: int = 1,
- **kwargs: Any) -> None:
+ def load_file(
+ self, source: str, destination: str, overwrite: bool = True, parallelism: int = 1, **kwargs: Any
+ ) -> None:
r"""
Uploads a file to HDFS.
@@ -140,9 +139,7 @@ class WebHDFSHook(BaseHook):
"""
conn = self.get_conn()
- conn.upload(hdfs_path=destination,
- local_path=source,
- overwrite=overwrite,
- n_threads=parallelism,
- **kwargs)
+ conn.upload(
+ hdfs_path=destination, local_path=source, overwrite=overwrite, n_threads=parallelism, **kwargs
+ )
self.log.debug("Uploaded file %s to %s", source, destination)
diff --git a/airflow/providers/apache/hdfs/sensors/hdfs.py b/airflow/providers/apache/hdfs/sensors/hdfs.py
index 85a8eb1..d7235dc 100644
--- a/airflow/providers/apache/hdfs/sensors/hdfs.py
+++ b/airflow/providers/apache/hdfs/sensors/hdfs.py
@@ -32,19 +32,22 @@ class HdfsSensor(BaseSensorOperator):
"""
Waits for a file or folder to land in HDFS
"""
+
template_fields = ('filepath',)
ui_color = settings.WEB_COLORS['LIGHTBLUE']
@apply_defaults
- def __init__(self,
- *,
- filepath: str,
- hdfs_conn_id: str = 'hdfs_default',
- ignored_ext: Optional[List[str]] = None,
- ignore_copying: bool = True,
- file_size: Optional[int] = None,
- hook: Type[HDFSHook] = HDFSHook,
- **kwargs: Any) -> None:
+ def __init__(
+ self,
+ *,
+ filepath: str,
+ hdfs_conn_id: str = 'hdfs_default',
+ ignored_ext: Optional[List[str]] = None,
+ ignore_copying: bool = True,
+ file_size: Optional[int] = None,
+ hook: Type[HDFSHook] = HDFSHook,
+ **kwargs: Any,
+ ) -> None:
super().__init__(**kwargs)
if ignored_ext is None:
ignored_ext = ['_COPYING_']
@@ -56,10 +59,7 @@ class HdfsSensor(BaseSensorOperator):
self.hook = hook
@staticmethod
- def filter_for_filesize(
- result: List[Dict[Any, Any]],
- size: Optional[int] = None
- ) -> List[Dict[Any, Any]]:
+ def filter_for_filesize(result: List[Dict[Any, Any]], size: Optional[int] = None) -> List[Dict[Any, Any]]:
"""
Will test the filepath result and test if its size is at least self.filesize
@@ -68,10 +68,7 @@ class HdfsSensor(BaseSensorOperator):
:return: (bool) depending on the matching criteria
"""
if size:
- log.debug(
- 'Filtering for file size >= %s in files: %s',
- size, map(lambda x: x['path'], result)
- )
+ log.debug('Filtering for file size >= %s in files: %s', size, map(lambda x: x['path'], result))
size *= settings.MEGABYTE
result = [x for x in result if x['length'] >= size]
log.debug('HdfsSensor.poke: after size filter result is %s', result)
@@ -79,9 +76,7 @@ class HdfsSensor(BaseSensorOperator):
@staticmethod
def filter_for_ignored_ext(
- result: List[Dict[Any, Any]],
- ignored_ext: List[str],
- ignore_copying: bool
+ result: List[Dict[Any, Any]], ignored_ext: List[str], ignore_copying: bool
) -> List[Dict[Any, Any]]:
"""
Will filter if instructed to do so the result to remove matching criteria
@@ -100,7 +95,8 @@ class HdfsSensor(BaseSensorOperator):
ignored_extensions_regex = re.compile(regex_builder)
log.debug(
'Filtering result for ignored extensions: %s in files %s',
- ignored_extensions_regex.pattern, map(lambda x: x['path'], result)
+ ignored_extensions_regex.pattern,
+ map(lambda x: x['path'], result),
)
result = [x for x in result if not ignored_extensions_regex.match(x['path'])]
log.debug('HdfsSensor.poke: after ext filter result is %s', result)
@@ -118,9 +114,7 @@ class HdfsSensor(BaseSensorOperator):
# here is a quick fix
result = sb_client.ls([self.filepath], include_toplevel=False)
self.log.debug('HdfsSensor.poke: result is %s', result)
- result = self.filter_for_ignored_ext(
- result, self.ignored_ext, self.ignore_copying
- )
+ result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying)
result = self.filter_for_filesize(result, self.file_size)
return bool(result)
except Exception: # pylint: disable=broad-except
@@ -134,10 +128,7 @@ class HdfsRegexSensor(HdfsSensor):
Waits for matching files by matching on regex
"""
- def __init__(self,
- regex: Pattern[str],
- *args: Any,
- **kwargs: Any) -> None:
+ def __init__(self, regex: Pattern[str], *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.regex = regex
@@ -151,11 +142,12 @@ class HdfsRegexSensor(HdfsSensor):
self.log.info(
'Poking for %s to be a directory with files matching %s', self.filepath, self.regex.pattern
)
- result = [f for f in sb_client.ls([self.filepath], include_toplevel=False) if
- f['file_type'] == 'f' and
- self.regex.match(f['path'].replace('%s/' % self.filepath, ''))]
- result = self.filter_for_ignored_ext(result, self.ignored_ext,
- self.ignore_copying)
+ result = [
+ f
+ for f in sb_client.ls([self.filepath], include_toplevel=False)
+ if f['file_type'] == 'f' and self.regex.match(f['path'].replace('%s/' % self.filepath, ''))
+ ]
+ result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying)
result = self.filter_for_filesize(result, self.file_size)
return bool(result)
@@ -165,10 +157,7 @@ class HdfsFolderSensor(HdfsSensor):
Waits for a non-empty directory
"""
- def __init__(self,
- be_empty: bool = False,
- *args: Any,
- **kwargs: Any):
+ def __init__(self, be_empty: bool = False, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.be_empty = be_empty
@@ -180,8 +169,7 @@ class HdfsFolderSensor(HdfsSensor):
"""
sb_client = self.hook(self.hdfs_conn_id).get_conn()
result = sb_client.ls([self.filepath], include_toplevel=True)
- result = self.filter_for_ignored_ext(result, self.ignored_ext,
- self.ignore_copying)
+ result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying)
result = self.filter_for_filesize(result, self.file_size)
if self.be_empty:
self.log.info('Poking for filepath %s to a empty directory', self.filepath)
diff --git a/airflow/providers/apache/hdfs/sensors/web_hdfs.py b/airflow/providers/apache/hdfs/sensors/web_hdfs.py
index 8d21b3e..edc3c8b 100644
--- a/airflow/providers/apache/hdfs/sensors/web_hdfs.py
+++ b/airflow/providers/apache/hdfs/sensors/web_hdfs.py
@@ -25,20 +25,18 @@ class WebHdfsSensor(BaseSensorOperator):
"""
Waits for a file or folder to land in HDFS
"""
+
template_fields = ('filepath',)
@apply_defaults
- def __init__(self,
- *,
- filepath: str,
- webhdfs_conn_id: str = 'webhdfs_default',
- **kwargs: Any) -> None:
+ def __init__(self, *, filepath: str, webhdfs_conn_id: str = 'webhdfs_default', **kwargs: Any) -> None:
super().__init__(**kwargs)
self.filepath = filepath
self.webhdfs_conn_id = webhdfs_conn_id
def poke(self, context: Dict[Any, Any]) -> bool:
from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook
+
hook = WebHDFSHook(self.webhdfs_conn_id)
self.log.info('Poking for file %s', self.filepath)
return hook.check_for_path(hdfs_path=self.filepath)
diff --git a/airflow/providers/apache/hive/example_dags/example_twitter_dag.py b/airflow/providers/apache/hive/example_dags/example_twitter_dag.py
index 7dc03df..8c9d1f3 100644
--- a/airflow/providers/apache/hive/example_dags/example_twitter_dag.py
+++ b/airflow/providers/apache/hive/example_dags/example_twitter_dag.py
@@ -99,20 +99,14 @@ with DAG(
# is direction(from or to)_twitterHandle_date.csv
# --------------------------------------------------------------------------------
- fetch_tweets = PythonOperator(
- task_id='fetch_tweets',
- python_callable=fetchtweets
- )
+ fetch_tweets = PythonOperator(task_id='fetch_tweets', python_callable=fetchtweets)
# --------------------------------------------------------------------------------
# Clean the eight files. In this step you can get rid of or cherry pick columns
# and different parts of the text
# --------------------------------------------------------------------------------
- clean_tweets = PythonOperator(
- task_id='clean_tweets',
- python_callable=cleantweets
- )
+ clean_tweets = PythonOperator(task_id='clean_tweets', python_callable=cleantweets)
clean_tweets << fetch_tweets
@@ -122,10 +116,7 @@ with DAG(
# complicated. You can also take a look at Web Services to do such tasks
# --------------------------------------------------------------------------------
- analyze_tweets = PythonOperator(
- task_id='analyze_tweets',
- python_callable=analyzetweets
- )
+ analyze_tweets = PythonOperator(task_id='analyze_tweets', python_callable=analyzetweets)
analyze_tweets << clean_tweets
@@ -135,10 +126,7 @@ with DAG(
# it to MySQL
# --------------------------------------------------------------------------------
- hive_to_mysql = PythonOperator(
- task_id='hive_to_mysql',
- python_callable=transfertodb
- )
+ hive_to_mysql = PythonOperator(task_id='hive_to_mysql', python_callable=transfertodb)
# --------------------------------------------------------------------------------
# The following tasks are generated using for loop. The first task puts the eight
@@ -163,19 +151,21 @@ with DAG(
load_to_hdfs = BashOperator(
task_id="put_" + channel + "_to_hdfs",
- bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f " +
- local_dir + file_name +
- hdfs_dir + channel + "/"
+ bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f "
+ + local_dir
+ + file_name
+ + hdfs_dir
+ + channel
+ + "/",
)
load_to_hdfs << analyze_tweets
load_to_hive = HiveOperator(
task_id="load_" + channel + "_to_hive",
- hql="LOAD DATA INPATH '" +
- hdfs_dir + channel + "/" + file_name + "' "
- "INTO TABLE " + channel + " "
- "PARTITION(dt='" + dt + "')"
+ hql="LOAD DATA INPATH '" + hdfs_dir + channel + "/" + file_name + "' "
+ "INTO TABLE " + channel + " "
+ "PARTITION(dt='" + dt + "')",
)
load_to_hive << load_to_hdfs
load_to_hive >> hive_to_mysql
@@ -184,19 +174,21 @@ with DAG(
file_name = "from_" + channel + "_" + yesterday.strftime("%Y-%m-%d") + ".csv"
load_to_hdfs = BashOperator(
task_id="put_" + channel + "_to_hdfs",
- bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f " +
- local_dir + file_name +
- hdfs_dir + channel + "/"
+ bash_command="HADOOP_USER_NAME=hdfs hadoop fs -put -f "
+ + local_dir
+ + file_name
+ + hdfs_dir
+ + channel
+ + "/",
)
load_to_hdfs << analyze_tweets
load_to_hive = HiveOperator(
task_id="load_" + channel + "_to_hive",
- hql="LOAD DATA INPATH '" +
- hdfs_dir + channel + "/" + file_name + "' "
- "INTO TABLE " + channel + " "
- "PARTITION(dt='" + dt + "')"
+ hql="LOAD DATA INPATH '" + hdfs_dir + channel + "/" + file_name + "' "
+ "INTO TABLE " + channel + " "
+ "PARTITION(dt='" + dt + "')",
)
load_to_hive << load_to_hdfs
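Black turned the continued '+' concatenation above into leading operators; an equivalent f-string form may read more clearly. The values below are purely illustrative, and it is assumed, as in the original expression, that local_dir and hdfs_dir carry their own separators:

# Hypothetical values mirroring the example DAG's string-concatenation pattern.
local_dir = "/tmp/"
hdfs_dir = " /tmp/"   # assumed to include the space separating the local and HDFS paths
channel = "channel_A"
file_name = "to_channel_A_2020-08-24.csv"

bash_command = f"HADOOP_USER_NAME=hdfs hadoop fs -put -f {local_dir}{file_name}{hdfs_dir}{channel}/"
print(bash_command)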
diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py
index 677ba3f..6b7042a 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -46,8 +46,10 @@ def get_context_from_env_var() -> Dict[Any, Any]:
:return: The context of interest.
"""
- return {format_map['default']: os.environ.get(format_map['env_var_format'], '')
- for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()}
+ return {
+ format_map['default']: os.environ.get(format_map['env_var_format'], '')
+ for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
+ }
class HiveCliHook(BaseHook):
@@ -82,7 +84,7 @@ class HiveCliHook(BaseHook):
run_as: Optional[str] = None,
mapred_queue: Optional[str] = None,
mapred_queue_priority: Optional[str] = None,
- mapred_job_name: Optional[str] = None
+ mapred_job_name: Optional[str] = None,
) -> None:
super().__init__()
conn = self.get_connection(hive_cli_conn_id)
@@ -98,10 +100,10 @@ class HiveCliHook(BaseHook):
if mapred_queue_priority not in HIVE_QUEUE_PRIORITIES:
raise AirflowException(
"Invalid Mapred Queue Priority. Valid values are: "
- "{}".format(', '.join(HIVE_QUEUE_PRIORITIES)))
+ "{}".format(', '.join(HIVE_QUEUE_PRIORITIES))
+ )
- self.mapred_queue = mapred_queue or conf.get('hive',
- 'default_hive_mapred_queue')
+ self.mapred_queue = mapred_queue or conf.get('hive', 'default_hive_mapred_queue')
self.mapred_queue_priority = mapred_queue_priority
self.mapred_job_name = mapred_job_name
@@ -131,18 +133,18 @@ class HiveCliHook(BaseHook):
if self.use_beeline:
hive_bin = 'beeline'
jdbc_url = "jdbc:hive2://{host}:{port}/{schema}".format(
- host=conn.host, port=conn.port, schema=conn.schema)
+ host=conn.host, port=conn.port, schema=conn.schema
+ )
if conf.get('core', 'security') == 'kerberos':
- template = conn.extra_dejson.get(
- 'principal', "hive/_HOST@EXAMPLE.COM")
+ template = conn.extra_dejson.get('principal', "hive/_HOST@EXAMPLE.COM")
if "_HOST" in template:
- template = utils.replace_hostname_pattern(
- utils.get_components(template))
+ template = utils.replace_hostname_pattern(utils.get_components(template))
proxy_user = self._get_proxy_user()
jdbc_url += ";principal={template};{proxy_user}".format(
- template=template, proxy_user=proxy_user)
+ template=template, proxy_user=proxy_user
+ )
elif self.auth:
jdbc_url += ";auth=" + self.auth
@@ -176,17 +178,15 @@ class HiveCliHook(BaseHook):
"""
if not d:
return []
- return as_flattened_list(
- zip(["-hiveconf"] * len(d),
- ["{}={}".format(k, v) for k, v in d.items()])
- )
+ return as_flattened_list(zip(["-hiveconf"] * len(d), ["{}={}".format(k, v) for k, v in d.items()]))
- def run_cli(self,
- hql: Union[str, Text],
- schema: Optional[str] = None,
- verbose: Optional[bool] = True,
- hive_conf: Optional[Dict[Any, Any]] = None
- ) -> Any:
+ def run_cli(
+ self,
+ hql: Union[str, Text],
+ schema: Optional[str] = None,
+ verbose: Optional[bool] = True,
+ hive_conf: Optional[Dict[Any, Any]] = None,
+ ) -> Any:
"""
Run an hql statement using the hive cli. If hive_conf is specified
it should be a dict and the entries will be set as key/value pairs
@@ -222,28 +222,23 @@ class HiveCliHook(BaseHook):
hive_conf_params = self._prepare_hiveconf(env_context)
if self.mapred_queue:
hive_conf_params.extend(
- ['-hiveconf',
- 'mapreduce.job.queuename={}'
- .format(self.mapred_queue),
- '-hiveconf',
- 'mapred.job.queue.name={}'
- .format(self.mapred_queue),
- '-hiveconf',
- 'tez.queue.name={}'
- .format(self.mapred_queue)
- ])
+ [
+ '-hiveconf',
+ 'mapreduce.job.queuename={}'.format(self.mapred_queue),
+ '-hiveconf',
+ 'mapred.job.queue.name={}'.format(self.mapred_queue),
+ '-hiveconf',
+ 'tez.queue.name={}'.format(self.mapred_queue),
+ ]
+ )
if self.mapred_queue_priority:
hive_conf_params.extend(
- ['-hiveconf',
- 'mapreduce.job.priority={}'
- .format(self.mapred_queue_priority)])
+ ['-hiveconf', 'mapreduce.job.priority={}'.format(self.mapred_queue_priority)]
+ )
if self.mapred_job_name:
- hive_conf_params.extend(
- ['-hiveconf',
- 'mapred.job.name={}'
- .format(self.mapred_job_name)])
+ hive_conf_params.extend(['-hiveconf', 'mapred.job.name={}'.format(self.mapred_job_name)])
hive_cmd.extend(hive_conf_params)
hive_cmd.extend(['-f', f.name])
@@ -251,11 +246,8 @@ class HiveCliHook(BaseHook):
if verbose:
self.log.info("%s", " ".join(hive_cmd))
sub_process: Any = subprocess.Popen(
- hive_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- cwd=tmp_dir,
- close_fds=True)
+ hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
+ )
self.sub_process = sub_process
stdout = ''
while True:
@@ -284,9 +276,7 @@ class HiveCliHook(BaseHook):
if query.startswith('create table'):
create.append(query_original)
- elif query.startswith(('set ',
- 'add jar ',
- 'create temporary function')):
+ elif query.startswith(('set ', 'add jar ', 'create temporary function')):
other.append(query_original)
elif query.startswith('insert'):
insert.append(query_original)
@@ -323,7 +313,7 @@ class HiveCliHook(BaseHook):
delimiter: str = ',',
encoding: str = 'utf8',
pandas_kwargs: Any = None,
- **kwargs: Any
+ **kwargs: Any,
) -> None:
"""
Loads a pandas DataFrame into hive.
@@ -348,9 +338,7 @@ class HiveCliHook(BaseHook):
:param kwargs: passed to self.load_file
"""
- def _infer_field_types_from_df(
- df: pandas.DataFrame
- ) -> Dict[Any, Any]:
+ def _infer_field_types_from_df(df: pandas.DataFrame) -> Dict[Any, Any]:
dtype_kind_hive_type = {
'b': 'BOOLEAN', # boolean
'i': 'BIGINT', # signed integer
@@ -361,7 +349,7 @@ class HiveCliHook(BaseHook):
'O': 'STRING', # object
'S': 'STRING', # (byte-)string
'U': 'STRING', # Unicode
- 'V': 'STRING' # void
+ 'V': 'STRING', # void
}
order_type = OrderedDict()
@@ -377,20 +365,20 @@ class HiveCliHook(BaseHook):
if field_dict is None:
field_dict = _infer_field_types_from_df(df)
- df.to_csv(path_or_buf=f,
- sep=delimiter,
- header=False,
- index=False,
- encoding=encoding,
- date_format="%Y-%m-%d %H:%M:%S",
- **pandas_kwargs)
+ df.to_csv(
+ path_or_buf=f,
+ sep=delimiter,
+ header=False,
+ index=False,
+ encoding=encoding,
+ date_format="%Y-%m-%d %H:%M:%S",
+ **pandas_kwargs,
+ )
f.flush()
- return self.load_file(filepath=f.name,
- table=table,
- delimiter=delimiter,
- field_dict=field_dict,
- **kwargs)
+ return self.load_file(
+ filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
+ )
def load_file(
self,
@@ -402,7 +390,7 @@ class HiveCliHook(BaseHook):
overwrite: bool = True,
partition: Optional[Dict[str, Any]] = None,
recreate: bool = False,
- tblproperties: Optional[Dict[str, Any]] = None
+ tblproperties: Optional[Dict[str, Any]] = None,
) -> None:
"""
Loads a local file into Hive
@@ -444,20 +432,16 @@ class HiveCliHook(BaseHook):
if create or recreate:
if field_dict is None:
raise ValueError("Must provide a field dict when creating a table")
- fields = ",\n ".join(
- ['`{k}` {v}'.format(k=k.strip('`'), v=v) for k, v in field_dict.items()])
- hql += "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n".format(
- table=table, fields=fields)
+ fields = ",\n ".join(['`{k}` {v}'.format(k=k.strip('`'), v=v) for k, v in field_dict.items()])
+ hql += "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n".format(table=table, fields=fields)
if partition:
- pfields = ",\n ".join(
- [p + " STRING" for p in partition])
+ pfields = ",\n ".join([p + " STRING" for p in partition])
hql += "PARTITIONED BY ({pfields})\n".format(pfields=pfields)
hql += "ROW FORMAT DELIMITED\n"
hql += "FIELDS TERMINATED BY '{delimiter}'\n".format(delimiter=delimiter)
hql += "STORED AS textfile\n"
if tblproperties is not None:
- tprops = ", ".join(
- ["'{0}'='{1}'".format(k, v) for k, v in tblproperties.items()])
+ tprops = ", ".join(["'{0}'='{1}'".format(k, v) for k, v in tblproperties.items()])
hql += "TBLPROPERTIES({tprops})\n".format(tprops=tprops)
hql += ";"
self.log.info(hql)
@@ -467,8 +451,7 @@ class HiveCliHook(BaseHook):
hql += "OVERWRITE "
hql += "INTO TABLE {table} ".format(table=table)
if partition:
- pvals = ", ".join(
- ["{0}='{1}'".format(k, v) for k, v in partition.items()])
+ pvals = ", ".join(["{0}='{1}'".format(k, v) for k, v in partition.items()])
hql += "PARTITION ({pvals})".format(pvals=pvals)
# As a workaround for HIVE-10541, add a newline character
@@ -547,6 +530,7 @@ class HiveMetastoreHook(BaseHook):
return sasl_client
from thrift_sasl import TSaslClientTransport
+
transport = TSaslClientTransport(sasl_factory, "GSSAPI", conn_socket)
else:
transport = TTransport.TBufferedTransport(conn_socket)
@@ -590,16 +574,11 @@ class HiveMetastoreHook(BaseHook):
True
"""
with self.metastore as client:
- partitions = client.get_partitions_by_filter(
- schema, table, partition, 1)
+ partitions = client.get_partitions_by_filter(schema, table, partition, 1)
return bool(partitions)
- def check_for_named_partition(self,
- schema: str,
- table: str,
- partition_name: str
- ) -> Any:
+ def check_for_named_partition(self, schema: str, table: str, partition_name: str) -> Any:
"""
Checks whether a partition with a given name exists
@@ -651,9 +630,9 @@ class HiveMetastoreHook(BaseHook):
with self.metastore as client:
return client.get_databases(pattern)
- def get_partitions(self, schema: str, table_name: str,
- partition_filter: Optional[str] = None
- ) -> List[Any]:
+ def get_partitions(
+ self, schema: str, table_name: str, partition_filter: Optional[str] = None
+ ) -> List[Any]:
"""
Returns a list of all partitions in a table. Works only
for tables with less than 32767 (java short max val).
@@ -674,21 +653,23 @@ class HiveMetastoreHook(BaseHook):
else:
if partition_filter:
parts = client.get_partitions_by_filter(
- db_name=schema, tbl_name=table_name,
- filter=partition_filter, max_parts=HiveMetastoreHook.MAX_PART_COUNT)
+ db_name=schema,
+ tbl_name=table_name,
+ filter=partition_filter,
+ max_parts=HiveMetastoreHook.MAX_PART_COUNT,
+ )
else:
parts = client.get_partitions(
- db_name=schema, tbl_name=table_name,
- max_parts=HiveMetastoreHook.MAX_PART_COUNT)
+ db_name=schema, tbl_name=table_name, max_parts=HiveMetastoreHook.MAX_PART_COUNT
+ )
pnames = [p.name for p in table.partitionKeys]
return [dict(zip(pnames, p.values)) for p in parts]
@staticmethod
- def _get_max_partition_from_part_specs(part_specs: List[Any],
- partition_key: Optional[str],
- filter_map: Optional[Dict[str, Any]]
- ) -> Any:
+ def _get_max_partition_from_part_specs(
+ part_specs: List[Any], partition_key: Optional[str], filter_map: Optional[Dict[str, Any]]
+ ) -> Any:
"""
Helper method to get max partition of partitions with partition_key
from part specs. key:value pair in filter_map will be used to
@@ -711,30 +692,36 @@ class HiveMetastoreHook(BaseHook):
# Assuming all specs have the same keys.
if partition_key not in part_specs[0].keys():
- raise AirflowException("Provided partition_key {} "
- "is not in part_specs.".format(partition_key))
+ raise AirflowException("Provided partition_key {} " "is not in part_specs.".format(partition_key))
is_subset = None
if filter_map:
is_subset = set(filter_map.keys()).issubset(set(part_specs[0].keys()))
if filter_map and not is_subset:
- raise AirflowException("Keys in provided filter_map {} "
- "are not subset of part_spec keys: {}"
- .format(', '.join(filter_map.keys()),
- ', '.join(part_specs[0].keys())))
+ raise AirflowException(
+ "Keys in provided filter_map {} "
+ "are not subset of part_spec keys: {}".format(
+ ', '.join(filter_map.keys()), ', '.join(part_specs[0].keys())
+ )
+ )
- candidates = [p_dict[partition_key] for p_dict in part_specs
- if filter_map is None or
- all(item in p_dict.items() for item in filter_map.items())]
+ candidates = [
+ p_dict[partition_key]
+ for p_dict in part_specs
+ if filter_map is None or all(item in p_dict.items() for item in filter_map.items())
+ ]
if not candidates:
return None
else:
return max(candidates)
- def max_partition(self, schema: str, table_name: str,
- field: Optional[str] = None,
- filter_map: Optional[Dict[Any, Any]] = None
- ) -> Any:
+ def max_partition(
+ self,
+ schema: str,
+ table_name: str,
+ field: Optional[str] = None,
+ filter_map: Optional[Dict[Any, Any]] = None,
+ ) -> Any:
"""
Returns the maximum value for all partitions with given field in a table.
If only one partition key exist in the table, the key will be used as field.
@@ -763,25 +750,19 @@ class HiveMetastoreHook(BaseHook):
if len(table.partitionKeys) == 1:
field = table.partitionKeys[0].name
elif not field:
- raise AirflowException("Please specify the field you want the max "
- "value for.")
+ raise AirflowException("Please specify the field you want the max " "value for.")
elif field not in key_name_set:
raise AirflowException("Provided field is not a partition key.")
if filter_map and not set(filter_map.keys()).issubset(key_name_set):
- raise AirflowException("Provided filter_map contains keys "
- "that are not partition key.")
+ raise AirflowException("Provided filter_map contains keys " "that are not partition key.")
- part_names = \
- client.get_partition_names(schema,
- table_name,
- max_parts=HiveMetastoreHook.MAX_PART_COUNT)
- part_specs = [client.partition_name_to_spec(part_name)
- for part_name in part_names]
+ part_names = client.get_partition_names(
+ schema, table_name, max_parts=HiveMetastoreHook.MAX_PART_COUNT
+ )
+ part_specs = [client.partition_name_to_spec(part_name) for part_name in part_names]
- return HiveMetastoreHook._get_max_partition_from_part_specs(part_specs,
- field,
- filter_map)
+ return HiveMetastoreHook._get_max_partition_from_part_specs(part_specs, field, filter_map)
def table_exists(self, table_name: str, db: str = 'default') -> bool:
"""
@@ -820,8 +801,9 @@ class HiveMetastoreHook(BaseHook):
"""
if self.table_exists(table_name, db):
with self.metastore as client:
- self.log.info("Dropping partition of table %s.%s matching the spec: %s",
- db, table_name, part_vals)
+ self.log.info(
+ "Dropping partition of table %s.%s matching the spec: %s", db, table_name, part_vals
+ )
return client.drop_partition(db, table_name, part_vals, delete_data)
else:
self.log.info("Table %s.%s does not exist!", db, table_name)
@@ -839,12 +821,12 @@ class HiveServer2Hook(DbApiHook):
are using impala you may need to set it to false in the
``extra`` of your connection in the UI
"""
+
conn_name_attr = 'hiveserver2_conn_id'
default_conn_name = 'hiveserver2_default'
supports_autocommit = False
- def get_conn(self, schema: Optional[str] = None
- ) -> Any:
+ def get_conn(self, schema: Optional[str] = None) -> Any:
"""
Returns a Hive connection object.
"""
@@ -864,13 +846,13 @@ class HiveServer2Hook(DbApiHook):
# pyhive uses GSSAPI instead of KERBEROS as a auth_mechanism identifier
if auth_mechanism == 'GSSAPI':
self.log.warning(
- "Detected deprecated 'GSSAPI' for authMechanism "
- "for %s. Please use 'KERBEROS' instead",
- self.hiveserver2_conn_id # type: ignore
+ "Detected deprecated 'GSSAPI' for authMechanism " "for %s. Please use 'KERBEROS' instead",
+ self.hiveserver2_conn_id, # type: ignore
)
auth_mechanism = 'KERBEROS'
from pyhive.hive import connect
+
return connect(
host=db.host,
port=db.port,
@@ -878,14 +860,20 @@ class HiveServer2Hook(DbApiHook):
kerberos_service_name=kerberos_service_name,
username=db.login or username,
password=db.password,
- database=schema or db.schema or 'default')
+ database=schema or db.schema or 'default',
+ )
# pylint: enable=no-member
- def _get_results(self, hql: Union[str, Text, List[str]], schema: str = 'default',
- fetch_size: Optional[int] = None,
- hive_conf: Optional[Dict[Any, Any]] = None) -> Any:
+ def _get_results(
+ self,
+ hql: Union[str, Text, List[str]],
+ schema: str = 'default',
+ fetch_size: Optional[int] = None,
+ hive_conf: Optional[Dict[Any, Any]] = None,
+ ) -> Any:
from pyhive.exc import ProgrammingError
+
if isinstance(hql, str):
hql = [hql]
previous_description = None
@@ -908,17 +896,19 @@ class HiveServer2Hook(DbApiHook):
cur.execute(statement)
# we only get results of statements that returns
lowered_statement = statement.lower().strip()
- if (lowered_statement.startswith('select') or
- lowered_statement.startswith('with') or
- lowered_statement.startswith('show') or
- (lowered_statement.startswith('set') and
- '=' not in lowered_statement)):
+ if (
+ lowered_statement.startswith('select')
+ or lowered_statement.startswith('with')
+ or lowered_statement.startswith('show')
+ or (lowered_statement.startswith('set') and '=' not in lowered_statement)
+ ):
description = cur.description
if previous_description and previous_description != description:
message = '''The statements are producing different descriptions:
Current: {}
- Previous: {}'''.format(repr(description),
- repr(previous_description))
+ Previous: {}'''.format(
+ repr(description), repr(previous_description)
+ )
raise ValueError(message)
elif not previous_description:
previous_description = description
@@ -931,10 +921,13 @@ class HiveServer2Hook(DbApiHook):
except ProgrammingError:
self.log.debug("get_results returned no records")
- def get_results(self, hql: Union[str, Text], schema: str = 'default',
- fetch_size: Optional[int] = None,
- hive_conf: Optional[Dict[Any, Any]] = None
- ) -> Dict[str, Any]:
+ def get_results(
+ self,
+ hql: Union[str, Text],
+ schema: str = 'default',
+ fetch_size: Optional[int] = None,
+ hive_conf: Optional[Dict[Any, Any]] = None,
+ ) -> Dict[str, Any]:
"""
Get results of the provided hql in target schema.
@@ -949,13 +942,9 @@ class HiveServer2Hook(DbApiHook):
:return: results of hql execution, dict with data (list of results) and header
:rtype: dict
"""
- results_iter = self._get_results(hql, schema,
- fetch_size=fetch_size, hive_conf=hive_conf)
+ results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
header = next(results_iter)
- results = {
- 'data': list(results_iter),
- 'header': header
- }
+ results = {'data': list(results_iter), 'header': header}
return results
def to_csv(
@@ -967,7 +956,7 @@ class HiveServer2Hook(DbApiHook):
lineterminator: str = '\r\n',
output_header: bool = True,
fetch_size: int = 1000,
- hive_conf: Optional[Dict[Any, Any]] = None
+ hive_conf: Optional[Dict[Any, Any]] = None,
) -> None:
"""
Execute hql in target schema and write results to a csv file.
@@ -991,17 +980,13 @@ class HiveServer2Hook(DbApiHook):
"""
- results_iter = self._get_results(hql, schema,
- fetch_size=fetch_size, hive_conf=hive_conf)
+ results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
header = next(results_iter)
message = None
i = 0
with open(csv_filepath, 'wb') as file:
- writer = csv.writer(file,
- delimiter=delimiter,
- lineterminator=lineterminator,
- encoding='utf-8')
+ writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator, encoding='utf-8')
try:
if output_header:
self.log.debug('Cursor description is %s', header)
@@ -1021,10 +1006,9 @@ class HiveServer2Hook(DbApiHook):
self.log.info("Done. Loaded a total of %s rows.", i)
- def get_records(self, hql: Union[str, Text],
- schema: str = 'default',
- hive_conf: Optional[Dict[Any, Any]] = None
- ) -> Any:
+ def get_records(
+ self, hql: Union[str, Text], schema: str = 'default', hive_conf: Optional[Dict[Any, Any]] = None
+ ) -> Any:
"""
Get a set of records from a Hive query.
@@ -1044,11 +1028,13 @@ class HiveServer2Hook(DbApiHook):
"""
return self.get_results(hql, schema=schema, hive_conf=hive_conf)['data']
- def get_pandas_df(self, hql: Union[str, Text], # type: ignore
- schema: str = 'default',
- hive_conf: Optional[Dict[Any, Any]] = None,
- **kwargs
- ) -> pandas.DataFrame:
+ def get_pandas_df( # type: ignore
+ self,
+ hql: Union[str, Text],
+ schema: str = 'default',
+ hive_conf: Optional[Dict[Any, Any]] = None,
+ **kwargs,
+ ) -> pandas.DataFrame:
"""
Get a pandas dataframe from a Hive query
diff --git a/airflow/providers/apache/hive/operators/hive.py b/airflow/providers/apache/hive/operators/hive.py
index 48d7574..1db5e99 100644
--- a/airflow/providers/apache/hive/operators/hive.py
+++ b/airflow/providers/apache/hive/operators/hive.py
@@ -62,26 +62,37 @@ class HiveOperator(BaseOperator):
:type mapred_job_name: str
"""
- template_fields = ('hql', 'schema', 'hive_cli_conn_id', 'mapred_queue',
- 'hiveconfs', 'mapred_job_name', 'mapred_queue_priority')
- template_ext = ('.hql', '.sql',)
+ template_fields = (
+ 'hql',
+ 'schema',
+ 'hive_cli_conn_id',
+ 'mapred_queue',
+ 'hiveconfs',
+ 'mapred_job_name',
+ 'mapred_queue_priority',
+ )
+ template_ext = (
+ '.hql',
+ '.sql',
+ )
ui_color = '#f0e4ec'
# pylint: disable=too-many-arguments
@apply_defaults
def __init__(
- self, *,
- hql: str,
- hive_cli_conn_id: str = 'hive_cli_default',
- schema: str = 'default',
- hiveconfs: Optional[Dict[Any, Any]] = None,
- hiveconf_jinja_translate: bool = False,
- script_begin_tag: Optional[str] = None,
- run_as_owner: bool = False,
- mapred_queue: Optional[str] = None,
- mapred_queue_priority: Optional[str] = None,
- mapred_job_name: Optional[str] = None,
- **kwargs: Any
+ self,
+ *,
+ hql: str,
+ hive_cli_conn_id: str = 'hive_cli_default',
+ schema: str = 'default',
+ hiveconfs: Optional[Dict[Any, Any]] = None,
+ hiveconf_jinja_translate: bool = False,
+ script_begin_tag: Optional[str] = None,
+ run_as_owner: bool = False,
+ mapred_queue: Optional[str] = None,
+ mapred_queue_priority: Optional[str] = None,
+ mapred_job_name: Optional[str] = None,
+ **kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.hql = hql
@@ -97,8 +108,10 @@ class HiveOperator(BaseOperator):
self.mapred_queue_priority = mapred_queue_priority
self.mapred_job_name = mapred_job_name
self.mapred_job_name_template = conf.get(
- 'hive', 'mapred_job_name_template',
- fallback="Airflow HiveOperator task for {hostname}.{dag_id}.{task_id}.{execution_date}")
+ 'hive',
+ 'mapred_job_name_template',
+ fallback="Airflow HiveOperator task for {hostname}.{dag_id}.{task_id}.{execution_date}",
+ )
# assigned lazily - just for consistency we can create the attribute with a
# `None` initial value, later it will be populated by the execute method.
@@ -115,12 +128,12 @@ class HiveOperator(BaseOperator):
run_as=self.run_as,
mapred_queue=self.mapred_queue,
mapred_queue_priority=self.mapred_queue_priority,
- mapred_job_name=self.mapred_job_name)
+ mapred_job_name=self.mapred_job_name,
+ )
def prepare_template(self) -> None:
if self.hiveconf_jinja_translate:
- self.hql = re.sub(
- r"(\$\{(hiveconf:)?([ a-zA-Z0-9_]*)\})", r"{{ \g<3> }}", self.hql)
+ self.hql = re.sub(r"(\$\{(hiveconf:)?([ a-zA-Z0-9_]*)\})", r"{{ \g<3> }}", self.hql)
if self.script_begin_tag and self.script_begin_tag in self.hql:
self.hql = "\n".join(self.hql.split(self.script_begin_tag)[1:])
@@ -131,10 +144,12 @@ class HiveOperator(BaseOperator):
# set the mapred_job_name if it's not set with dag, task, execution time info
if not self.mapred_job_name:
ti = context['ti']
- self.hook.mapred_job_name = self.mapred_job_name_template\
- .format(dag_id=ti.dag_id, task_id=ti.task_id,
- execution_date=ti.execution_date.isoformat(),
- hostname=ti.hostname.split('.')[0])
+ self.hook.mapred_job_name = self.mapred_job_name_template.format(
+ dag_id=ti.dag_id,
+ task_id=ti.task_id,
+ execution_date=ti.execution_date.isoformat(),
+ hostname=ti.hostname.split('.')[0],
+ )
if self.hiveconf_jinja_translate:
self.hiveconfs = context_to_airflow_vars(context)
@@ -160,6 +175,7 @@ class HiveOperator(BaseOperator):
"""
Reset airflow environment variables to prevent existing ones from impacting behavior.
"""
- blank_env_vars = {value['env_var_format']: '' for value in
- operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()}
+ blank_env_vars = {
+ value['env_var_format']: '' for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
+ }
os.environ.update(blank_env_vars)
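The operator's keyword arguments are likewise untouched by the reformat. A short DAG sketch under stated assumptions (dag id, schedule and HQL are illustrative; thanks to the template_ext tuple above, hql could also point at a templated .hql or .sql file):

    from airflow import DAG
    from airflow.providers.apache.hive.operators.hive import HiveOperator
    from airflow.utils.dates import days_ago

    with DAG(dag_id='example_hive', start_date=days_ago(1), schedule_interval=None) as dag:
        count_rows = HiveOperator(
            task_id='count_rows',
            hql="SELECT COUNT(*) FROM my_db.my_table WHERE ds = '{{ ds }}'",  # hql is a template field
            schema='my_db',
            hive_cli_conn_id='hive_cli_default',
            mapred_queue='default',
        )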
diff --git a/airflow/providers/apache/hive/operators/hive_stats.py b/airflow/providers/apache/hive/operators/hive_stats.py
index 6fc689e..4dfef2c 100644
--- a/airflow/providers/apache/hive/operators/hive_stats.py
+++ b/airflow/providers/apache/hive/operators/hive_stats.py
@@ -63,24 +63,25 @@ class HiveStatsCollectionOperator(BaseOperator):
ui_color = '#aff7a6'
@apply_defaults
- def __init__(self, *,
- table: str,
- partition: Any,
- extra_exprs: Optional[Dict[str, Any]] = None,
- excluded_columns: Optional[List[str]] = None,
- assignment_func: Optional[Callable[[str, str], Optional[Dict[Any, Any]]]] = None,
- metastore_conn_id: str = 'metastore_default',
- presto_conn_id: str = 'presto_default',
- mysql_conn_id: str = 'airflow_db',
- **kwargs: Any
- ) -> None:
+ def __init__(
+ self,
+ *,
+ table: str,
+ partition: Any,
+ extra_exprs: Optional[Dict[str, Any]] = None,
+ excluded_columns: Optional[List[str]] = None,
+ assignment_func: Optional[Callable[[str, str], Optional[Dict[Any, Any]]]] = None,
+ metastore_conn_id: str = 'metastore_default',
+ presto_conn_id: str = 'presto_default',
+ mysql_conn_id: str = 'airflow_db',
+ **kwargs: Any,
+ ) -> None:
if 'col_blacklist' in kwargs:
warnings.warn(
'col_blacklist kwarg passed to {c} (task_id: {t}) is deprecated, please rename it to '
- 'excluded_columns instead'.format(
- c=self.__class__.__name__, t=kwargs.get('task_id')),
+ 'excluded_columns instead'.format(c=self.__class__.__name__, t=kwargs.get('task_id')),
category=FutureWarning,
- stacklevel=2
+ stacklevel=2,
)
excluded_columns = kwargs.pop('col_blacklist')
super().__init__(**kwargs)
@@ -121,9 +122,7 @@ class HiveStatsCollectionOperator(BaseOperator):
table = metastore.get_table(table_name=self.table)
field_types = {col.name: col.type for col in table.sd.cols}
- exprs: Any = {
- ('', 'count'): 'COUNT(*)'
- }
+ exprs: Any = {('', 'count'): 'COUNT(*)'}
for col, col_type in list(field_types.items()):
if self.assignment_func:
assign_exprs = self.assignment_func(col, col_type)
@@ -134,14 +133,13 @@ class HiveStatsCollectionOperator(BaseOperator):
exprs.update(assign_exprs)
exprs.update(self.extra_exprs)
exprs = OrderedDict(exprs)
- exprs_str = ",\n ".join([
- v + " AS " + k[0] + '__' + k[1]
- for k, v in exprs.items()])
+ exprs_str = ",\n ".join([v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()])
where_clause_ = ["{} = '{}'".format(k, v) for k, v in self.partition.items()]
where_clause = " AND\n ".join(where_clause_)
sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
- exprs_str=exprs_str, table=self.table, where_clause=where_clause)
+ exprs_str=exprs_str, table=self.table, where_clause=where_clause
+ )
presto = PrestoHook(presto_conn_id=self.presto_conn_id)
self.log.info('Executing SQL check: %s', sql)
@@ -161,7 +159,9 @@ class HiveStatsCollectionOperator(BaseOperator):
partition_repr='{part_json}' AND
dttm='{dttm}'
LIMIT 1;
- """.format(table=self.table, part_json=part_json, dttm=self.dttm)
+ """.format(
+ table=self.table, part_json=part_json, dttm=self.dttm
+ )
if mysql.get_records(sql):
sql = """
DELETE FROM hive_stats
@@ -169,22 +169,17 @@ class HiveStatsCollectionOperator(BaseOperator):
table_name='{table}' AND
partition_repr='{part_json}' AND
dttm='{dttm}';
- """.format(table=self.table, part_json=part_json, dttm=self.dttm)
+ """.format(
+ table=self.table, part_json=part_json, dttm=self.dttm
+ )
mysql.run(sql)
self.log.info("Pivoting and loading cells into the Airflow db")
- rows = [(self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1])
- for r in zip(exprs, row)]
+ rows = [
+ (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) for r in zip(exprs, row)
+ ]
mysql.insert_rows(
table='hive_stats',
rows=rows,
- target_fields=[
- 'ds',
- 'dttm',
- 'table_name',
- 'partition_repr',
- 'col',
- 'metric',
- 'value',
- ]
+ target_fields=['ds', 'dttm', 'table_name', 'partition_repr', 'col', 'metric', 'value',],
)
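As the loop above suggests, assignment_func returns a dict keyed by (column, metric) tuples whose values are SQL expressions, matching the built-in {('', 'count'): 'COUNT(*)'} entry. A hedged sketch (table, partition values and metric choices are illustrative; returning None is expected to fall back to the operator's default expressions):

    from airflow.providers.apache.hive.operators.hive_stats import HiveStatsCollectionOperator

    def assign_metrics(col, col_type):
        # Extra metrics for numeric columns, keyed by (column, metric) tuples.
        if col_type in ('int', 'bigint', 'double'):
            return {(col, 'sum'): 'SUM({})'.format(col), (col, 'max'): 'MAX({})'.format(col)}
        return None

    collect_stats = HiveStatsCollectionOperator(
        task_id='collect_stats',
        table='my_db.my_table',
        partition={'ds': '2020-01-01'},
        assignment_func=assign_metrics,
        metastore_conn_id='metastore_default',
        presto_conn_id='presto_default',
        mysql_conn_id='airflow_db',
    )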
diff --git a/airflow/providers/apache/hive/sensors/hive_partition.py b/airflow/providers/apache/hive/sensors/hive_partition.py
index 8e1b827..15843b8 100644
--- a/airflow/providers/apache/hive/sensors/hive_partition.py
+++ b/airflow/providers/apache/hive/sensors/hive_partition.py
@@ -42,19 +42,26 @@ class HivePartitionSensor(BaseSensorOperator):
connection id
:type metastore_conn_id: str
"""
- template_fields = ('schema', 'table', 'partition',)
+
+ template_fields = (
+ 'schema',
+ 'table',
+ 'partition',
+ )
ui_color = '#C5CAE9'
@apply_defaults
- def __init__(self, *,
- table: str,
- partition: Optional[str] = "ds='{{ ds }}'",
- metastore_conn_id: str = 'metastore_default',
- schema: str = 'default',
- poke_interval: int = 60 * 3,
- **kwargs: Any):
- super().__init__(
- poke_interval=poke_interval, **kwargs)
+ def __init__(
+ self,
+ *,
+ table: str,
+ partition: Optional[str] = "ds='{{ ds }}'",
+ metastore_conn_id: str = 'metastore_default',
+ schema: str = 'default',
+ poke_interval: int = 60 * 3,
+ **kwargs: Any,
+ ):
+ super().__init__(poke_interval=poke_interval, **kwargs)
if not partition:
partition = "ds='{{ ds }}'"
self.metastore_conn_id = metastore_conn_id
@@ -65,11 +72,7 @@ class HivePartitionSensor(BaseSensorOperator):
def poke(self, context: Dict[str, Any]) -> bool:
if '.' in self.table:
self.schema, self.table = self.table.split('.')
- self.log.info(
- 'Poking for table %s.%s, partition %s', self.schema, self.table, self.partition
- )
+ self.log.info('Poking for table %s.%s, partition %s', self.schema, self.table, self.partition)
if not hasattr(self, 'hook'):
- hook = HiveMetastoreHook(
- metastore_conn_id=self.metastore_conn_id)
- return hook.check_for_partition(
- self.schema, self.table, self.partition)
+ hook = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
+ return hook.check_for_partition(self.schema, self.table, self.partition)
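A minimal sketch of the sensor inside an existing with DAG(...) block (schema and table are illustrative; the partition default and the 3-minute poke_interval come from the signature above):

    from airflow.providers.apache.hive.sensors.hive_partition import HivePartitionSensor

    wait_for_partition = HivePartitionSensor(
        task_id='wait_for_partition',
        table='my_db.my_table',            # a dotted name is split into schema and table in poke()
        partition="ds='{{ ds }}'",
        metastore_conn_id='metastore_default',
    )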
diff --git a/airflow/providers/apache/hive/sensors/metastore_partition.py b/airflow/providers/apache/hive/sensors/metastore_partition.py
index 1e54440..31376ad 100644
--- a/airflow/providers/apache/hive/sensors/metastore_partition.py
+++ b/airflow/providers/apache/hive/sensors/metastore_partition.py
@@ -41,16 +41,20 @@ class MetastorePartitionSensor(SqlSensor):
:param mysql_conn_id: a reference to the MySQL conn_id for the metastore
:type mysql_conn_id: str
"""
+
template_fields = ('partition_name', 'table', 'schema')
ui_color = '#8da7be'
@apply_defaults
- def __init__(self, *,
- table: str,
- partition_name: str,
- schema: str = "default",
- mysql_conn_id: str = "metastore_mysql",
- **kwargs: Any):
+ def __init__(
+ self,
+ *,
+ table: str,
+ partition_name: str,
+ schema: str = "default",
+ mysql_conn_id: str = "metastore_mysql",
+ **kwargs: Any,
+ ):
self.partition_name = partition_name
self.table = table
@@ -78,5 +82,7 @@ class MetastorePartitionSensor(SqlSensor):
B0.TBL_NAME = '{self.table}' AND
C0.NAME = '{self.schema}' AND
A0.PART_NAME = '{self.partition_name}';
- """.format(self=self)
+ """.format(
+ self=self
+ )
return super().poke(context)
diff --git a/airflow/providers/apache/hive/sensors/named_hive_partition.py b/airflow/providers/apache/hive/sensors/named_hive_partition.py
index f69e2b2..23d9466 100644
--- a/airflow/providers/apache/hive/sensors/named_hive_partition.py
+++ b/airflow/providers/apache/hive/sensors/named_hive_partition.py
@@ -42,14 +42,16 @@ class NamedHivePartitionSensor(BaseSensorOperator):
ui_color = '#8d99ae'
@apply_defaults
- def __init__(self, *,
- partition_names: List[str],
- metastore_conn_id: str = 'metastore_default',
- poke_interval: int = 60 * 3,
- hook: Any = None,
- **kwargs: Any):
- super().__init__(
- poke_interval=poke_interval, **kwargs)
+ def __init__(
+ self,
+ *,
+ partition_names: List[str],
+ metastore_conn_id: str = 'metastore_default',
+ poke_interval: int = 60 * 3,
+ hook: Any = None,
+ **kwargs: Any,
+ ):
+ super().__init__(poke_interval=poke_interval, **kwargs)
self.next_index_to_poke = 0
if isinstance(partition_names, str):
@@ -74,8 +76,7 @@ class NamedHivePartitionSensor(BaseSensorOperator):
schema, table_partition = first_split
second_split = table_partition.split('/', 1)
if len(second_split) == 1:
- raise ValueError('Could not parse ' + partition +
- 'into table, partition')
+ raise ValueError('Could not parse ' + partition + ' into table, partition')
else:
table, partition = second_split
return schema, table, partition
@@ -84,14 +85,13 @@ class NamedHivePartitionSensor(BaseSensorOperator):
"""Check for a named partition."""
if not self.hook:
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
- self.hook = HiveMetastoreHook(
- metastore_conn_id=self.metastore_conn_id)
+
+ self.hook = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
schema, table, partition = self.parse_partition_name(partition)
self.log.info('Poking for %s.%s/%s', schema, table, partition)
- return self.hook.check_for_named_partition(
- schema, table, partition)
+ return self.hook.check_for_named_partition(schema, table, partition)
def poke(self, context: Dict[str, Any]) -> bool:
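parse_partition_name above implies partition names of the form schema.table/partition-spec. A hedged sketch (table names and dates are illustrative):

    from airflow.providers.apache.hive.sensors.named_hive_partition import NamedHivePartitionSensor

    wait_for_partitions = NamedHivePartitionSensor(
        task_id='wait_for_partitions',
        partition_names=[
            'default.my_table/ds=2020-01-01',
            'default.my_other_table/ds=2020-01-01',
        ],
        metastore_conn_id='metastore_default',
    )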
diff --git a/airflow/providers/apache/hive/transfers/hive_to_mysql.py b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
index 724c791..11d81ec 100644
--- a/airflow/providers/apache/hive/transfers/hive_to_mysql.py
+++ b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
@@ -67,16 +67,19 @@ class HiveToMySqlOperator(BaseOperator):
ui_color = '#a0e08c'
@apply_defaults
- def __init__(self, *,
- sql: str,
- mysql_table: str,
- hiveserver2_conn_id: str = 'hiveserver2_default',
- mysql_conn_id: str = 'mysql_default',
- mysql_preoperator: Optional[str] = None,
- mysql_postoperator: Optional[str] = None,
- bulk_load: bool = False,
- hive_conf: Optional[Dict] = None,
- **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ sql: str,
+ mysql_table: str,
+ hiveserver2_conn_id: str = 'hiveserver2_default',
+ mysql_conn_id: str = 'mysql_default',
+ mysql_preoperator: Optional[str] = None,
+ mysql_postoperator: Optional[str] = None,
+ bulk_load: bool = False,
+ hive_conf: Optional[Dict] = None,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.sql = sql
self.mysql_table = mysql_table
@@ -96,12 +99,14 @@ class HiveToMySqlOperator(BaseOperator):
hive_conf.update(self.hive_conf)
if self.bulk_load:
tmp_file = NamedTemporaryFile()
- hive.to_csv(self.sql,
- tmp_file.name,
- delimiter='\t',
- lineterminator='\n',
- output_header=False,
- hive_conf=hive_conf)
+ hive.to_csv(
+ self.sql,
+ tmp_file.name,
+ delimiter='\t',
+ lineterminator='\n',
+ output_header=False,
+ hive_conf=hive_conf,
+ )
else:
hive_results = hive.get_records(self.sql, hive_conf=hive_conf)
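Putting the constructor above to use, a hedged sketch inside a DAG (query, target table and the TRUNCATE preoperator are illustrative); with bulk_load=True the rows travel via the tab-delimited CSV path shown in this hunk:

    from airflow.providers.apache.hive.transfers.hive_to_mysql import HiveToMySqlOperator

    hive_to_mysql = HiveToMySqlOperator(
        task_id='hive_to_mysql',
        sql='SELECT name, value FROM my_db.my_table',
        mysql_table='analytics.my_table_copy',
        mysql_preoperator='TRUNCATE TABLE analytics.my_table_copy',
        bulk_load=True,
        hiveserver2_conn_id='hiveserver2_default',
        mysql_conn_id='mysql_default',
    )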
diff --git a/airflow/providers/apache/hive/transfers/hive_to_samba.py b/airflow/providers/apache/hive/transfers/hive_to_samba.py
index 5f08b83..dc93297 100644
--- a/airflow/providers/apache/hive/transfers/hive_to_samba.py
+++ b/airflow/providers/apache/hive/transfers/hive_to_samba.py
@@ -45,15 +45,21 @@ class HiveToSambaOperator(BaseOperator):
"""
template_fields = ('hql', 'destination_filepath')
- template_ext = ('.hql', '.sql',)
+ template_ext = (
+ '.hql',
+ '.sql',
+ )
@apply_defaults
- def __init__(self, *,
- hql: str,
- destination_filepath: str,
- samba_conn_id: str = 'samba_default',
- hiveserver2_conn_id: str = 'hiveserver2_default',
- **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ hql: str,
+ destination_filepath: str,
+ samba_conn_id: str = 'samba_default',
+ hiveserver2_conn_id: str = 'hiveserver2_default',
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.hiveserver2_conn_id = hiveserver2_conn_id
self.samba_conn_id = samba_conn_id
diff --git a/airflow/providers/apache/hive/transfers/mssql_to_hive.py b/airflow/providers/apache/hive/transfers/mssql_to_hive.py
index 01a9327..8f32ca2 100644
--- a/airflow/providers/apache/hive/transfers/mssql_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/mssql_to_hive.py
@@ -76,17 +76,20 @@ class MsSqlToHiveOperator(BaseOperator):
ui_color = '#a0e08c'
@apply_defaults
- def __init__(self, *,
- sql: str,
- hive_table: str,
- create: bool = True,
- recreate: bool = False,
- partition: Optional[Dict] = None,
- delimiter: str = chr(1),
- mssql_conn_id: str = 'mssql_default',
- hive_cli_conn_id: str = 'hive_cli_default',
- tblproperties: Optional[Dict] = None,
- **kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ sql: str,
+ hive_table: str,
+ create: bool = True,
+ recreate: bool = False,
+ partition: Optional[Dict] = None,
+ delimiter: str = chr(1),
+ mssql_conn_id: str = 'mssql_default',
+ hive_cli_conn_id: str = 'hive_cli_default',
+ tblproperties: Optional[Dict] = None,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.sql = sql
self.hive_table = hive_table
@@ -138,4 +141,5 @@ class MsSqlToHiveOperator(BaseOperator):
partition=self.partition,
delimiter=self.delimiter,
recreate=self.recreate,
- tblproperties=self.tblproperties)
+ tblproperties=self.tblproperties,
+ )
diff --git a/airflow/providers/apache/hive/transfers/mysql_to_hive.py b/airflow/providers/apache/hive/transfers/mysql_to_hive.py
index 99650ec..25aa802 100644
--- a/airflow/providers/apache/hive/transfers/mysql_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/mysql_to_hive.py
@@ -85,21 +85,22 @@ class MySqlToHiveOperator(BaseOperator):
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
- self,
- *,
- sql: str,
- hive_table: str,
- create: bool = True,
- recreate: bool = False,
- partition: Optional[Dict] = None,
- delimiter: str = chr(1),
- quoting: Optional[str] = None,
- quotechar: str = '"',
- escapechar: Optional[str] = None,
- mysql_conn_id: str = 'mysql_default',
- hive_cli_conn_id: str = 'hive_cli_default',
- tblproperties: Optional[Dict] = None,
- **kwargs) -> None:
+ self,
+ *,
+ sql: str,
+ hive_table: str,
+ create: bool = True,
+ recreate: bool = False,
+ partition: Optional[Dict] = None,
+ delimiter: str = chr(1),
+ quoting: Optional[str] = None,
+ quotechar: str = '"',
+ escapechar: Optional[str] = None,
+ mysql_conn_id: str = 'mysql_default',
+ hive_cli_conn_id: str = 'hive_cli_default',
+ tblproperties: Optional[Dict] = None,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.sql = sql
self.hive_table = hive_table
@@ -146,11 +147,14 @@ class MySqlToHiveOperator(BaseOperator):
cursor = conn.cursor()
cursor.execute(self.sql)
with NamedTemporaryFile("wb") as f:
- csv_writer = csv.writer(f, delimiter=self.delimiter,
- quoting=self.quoting,
- quotechar=self.quotechar,
- escapechar=self.escapechar,
- encoding="utf-8")
+ csv_writer = csv.writer(
+ f,
+ delimiter=self.delimiter,
+ quoting=self.quoting,
+ quotechar=self.quotechar,
+ escapechar=self.escapechar,
+ encoding="utf-8",
+ )
field_dict = OrderedDict()
for field in cursor.description:
field_dict[field[0]] = self.type_map(field[1])
@@ -167,4 +171,5 @@ class MySqlToHiveOperator(BaseOperator):
partition=self.partition,
delimiter=self.delimiter,
recreate=self.recreate,
- tblproperties=self.tblproperties)
+ tblproperties=self.tblproperties,
+ )
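The reverse direction keeps the same shape. A hedged sketch (source query, target table and partition values are illustrative); the operator dumps the MySQL result to a temporary CSV via the csv.writer call above and then hands it to Hive:

    from airflow.providers.apache.hive.transfers.mysql_to_hive import MySqlToHiveOperator

    mysql_to_hive = MySqlToHiveOperator(
        task_id='mysql_to_hive',
        sql='SELECT id, name FROM users',
        hive_table='staging.users',
        create=True,
        partition={'ds': '2020-01-01'},
        mysql_conn_id='mysql_default',
        hive_cli_conn_id='hive_cli_default',
    )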
diff --git a/airflow/providers/apache/hive/transfers/s3_to_hive.py b/airflow/providers/apache/hive/transfers/s3_to_hive.py
index 6c730a0..844777e 100644
--- a/airflow/providers/apache/hive/transfers/s3_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/s3_to_hive.py
@@ -107,25 +107,26 @@ class S3ToHiveOperator(BaseOperator): # pylint: disable=too-many-instance-attri
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
- self,
- *,
- s3_key: str,
- field_dict: Dict,
- hive_table: str,
- delimiter: str = ',',
- create: bool = True,
- recreate: bool = False,
- partition: Optional[Dict] = None,
- headers: bool = False,
- check_headers: bool = False,
- wildcard_match: bool = False,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
- hive_cli_conn_id: str = 'hive_cli_default',
- input_compressed: bool = False,
- tblproperties: Optional[Dict] = None,
- select_expression: Optional[str] = None,
- **kwargs) -> None:
+ self,
+ *,
+ s3_key: str,
+ field_dict: Dict,
+ hive_table: str,
+ delimiter: str = ',',
+ create: bool = True,
+ recreate: bool = False,
+ partition: Optional[Dict] = None,
+ headers: bool = False,
+ check_headers: bool = False,
+ wildcard_match: bool = False,
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[bool, str]] = None,
+ hive_cli_conn_id: str = 'hive_cli_default',
+ input_compressed: bool = False,
+ tblproperties: Optional[Dict] = None,
+ select_expression: Optional[str] = None,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.s3_key = s3_key
self.field_dict = field_dict
@@ -144,10 +145,8 @@ class S3ToHiveOperator(BaseOperator): # pylint: disable=too-many-instance-attri
self.tblproperties = tblproperties
self.select_expression = select_expression
- if (self.check_headers and
- not (self.field_dict is not None and self.headers)):
- raise AirflowException("To check_headers provide " +
- "field_dict and headers")
+ if self.check_headers and not (self.field_dict is not None and self.headers):
+ raise AirflowException("To check_headers provide " + "field_dict and headers")
def execute(self, context):
# Downloading file from S3
@@ -165,18 +164,13 @@ class S3ToHiveOperator(BaseOperator): # pylint: disable=too-many-instance-attri
s3_key_object = s3_hook.get_key(self.s3_key)
_, file_ext = os.path.splitext(s3_key_object.key)
- if (self.select_expression and self.input_compressed and
- file_ext.lower() != '.gz'):
- raise AirflowException("GZIP is the only compression " +
- "format Amazon S3 Select supports")
+ if self.select_expression and self.input_compressed and file_ext.lower() != '.gz':
+ raise AirflowException("GZIP is the only compression " + "format Amazon S3 Select supports")
- with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
- NamedTemporaryFile(mode="wb",
- dir=tmp_dir,
- suffix=file_ext) as f:
- self.log.info(
- "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name
- )
+ with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir, NamedTemporaryFile(
+ mode="wb", dir=tmp_dir, suffix=file_ext
+ ) as f:
+ self.log.info("Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name)
if self.select_expression:
option = {}
if self.headers:
@@ -192,7 +186,7 @@ class S3ToHiveOperator(BaseOperator): # pylint: disable=too-many-instance-attri
bucket_name=s3_key_object.bucket_name,
key=s3_key_object.key,
expression=self.select_expression,
- input_serialization=input_serialization
+ input_serialization=input_serialization,
)
f.write(content.encode("utf-8"))
else:
@@ -209,14 +203,13 @@ class S3ToHiveOperator(BaseOperator): # pylint: disable=too-many-instance-attri
partition=self.partition,
delimiter=self.delimiter,
recreate=self.recreate,
- tblproperties=self.tblproperties)
+ tblproperties=self.tblproperties,
+ )
else:
# Decompressing file
if self.input_compressed:
self.log.info("Uncompressing file %s", f.name)
- fn_uncompressed = uncompress_file(f.name,
- file_ext,
- tmp_dir)
+ fn_uncompressed = uncompress_file(f.name, file_ext, tmp_dir)
self.log.info("Uncompressed to %s", fn_uncompressed)
# uncompressed file available now so deleting
# compressed file to save disk space
@@ -233,20 +226,19 @@ class S3ToHiveOperator(BaseOperator): # pylint: disable=too-many-instance-attri
# Deleting top header row
self.log.info("Removing header from file %s", fn_uncompressed)
- headless_file = (
- self._delete_top_row_and_compress(fn_uncompressed,
- file_ext,
- tmp_dir))
+ headless_file = self._delete_top_row_and_compress(fn_uncompressed, file_ext, tmp_dir)
self.log.info("Headless file %s", headless_file)
self.log.info("Loading file %s into Hive", headless_file)
- hive_hook.load_file(headless_file,
- self.hive_table,
- field_dict=self.field_dict,
- create=self.create,
- partition=self.partition,
- delimiter=self.delimiter,
- recreate=self.recreate,
- tblproperties=self.tblproperties)
+ hive_hook.load_file(
+ headless_file,
+ self.hive_table,
+ field_dict=self.field_dict,
+ create=self.create,
+ partition=self.partition,
+ delimiter=self.delimiter,
+ recreate=self.recreate,
+ tblproperties=self.tblproperties,
+ )
def _get_top_row_as_list(self, file_name):
with open(file_name, 'rt') as file:
@@ -263,22 +255,19 @@ class S3ToHiveOperator(BaseOperator): # pylint: disable=too-many-instance-attri
"Headers count mismatch File headers:\n %s\nField names: \n %s\n", header_list, field_names
)
return False
- test_field_match = [h1.lower() == h2.lower()
- for h1, h2 in zip(header_list, field_names)]
+ test_field_match = [h1.lower() == h2.lower() for h1, h2 in zip(header_list, field_names)]
if not all(test_field_match):
self.log.warning(
"Headers do not match field names File headers:\n %s\nField names: \n %s\n",
- header_list, field_names
+ header_list,
+ field_names,
)
return False
else:
return True
@staticmethod
- def _delete_top_row_and_compress(
- input_file_name,
- output_file_ext,
- dest_dir):
+ def _delete_top_row_and_compress(input_file_name, output_file_ext, dest_dir):
# When output_file_ext is not defined, file is not compressed
open_fn = open
if output_file_ext.lower() == '.gz':
diff --git a/airflow/providers/apache/hive/transfers/vertica_to_hive.py b/airflow/providers/apache/hive/transfers/vertica_to_hive.py
index 02a4f80..66c9790 100644
--- a/airflow/providers/apache/hive/transfers/vertica_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/vertica_to_hive.py
@@ -73,17 +73,18 @@ class VerticaToHiveOperator(BaseOperator):
@apply_defaults
def __init__(
- self,
- *,
- sql,
- hive_table,
- create=True,
- recreate=False,
- partition=None,
- delimiter=chr(1),
- vertica_conn_id='vertica_default',
- hive_cli_conn_id='hive_cli_default',
- **kwargs):
+ self,
+ *,
+ sql,
+ hive_table,
+ create=True,
+ recreate=False,
+ partition=None,
+ delimiter=chr(1),
+ vertica_conn_id='vertica_default',
+ hive_cli_conn_id='hive_cli_default',
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.sql = sql
self.hive_table = hive_table
@@ -127,8 +128,7 @@ class VerticaToHiveOperator(BaseOperator):
for field in cursor.description:
col_count += 1
col_position = "Column{position}".format(position=col_count)
- field_dict[col_position if field[0] == '' else field[0]] = \
- self.type_map(field[1])
+ field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor.iterate())
f.flush()
cursor.close()
@@ -141,4 +141,5 @@ class VerticaToHiveOperator(BaseOperator):
create=self.create,
partition=self.partition,
delimiter=self.delimiter,
- recreate=self.recreate)
+ recreate=self.recreate,
+ )
diff --git a/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py b/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py
index 8ff685a..f7901d4 100644
--- a/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py
+++ b/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py
@@ -34,7 +34,7 @@ dag = DAG(
default_args=args,
schedule_interval=None,
start_date=days_ago(1),
- tags=['example']
+ tags=['example'],
)
@@ -49,11 +49,7 @@ def gen_build_time(**kwargs):
ti.xcom_push(key='date_end', value='1325433600000')
-gen_build_time_task = PythonOperator(
- python_callable=gen_build_time,
- task_id='gen_build_time',
- dag=dag
-)
+gen_build_time_task = PythonOperator(python_callable=gen_build_time, task_id='gen_build_time', dag=dag)
build_task1 = KylinCubeOperator(
task_id="kylin_build_1",
diff --git a/airflow/providers/apache/kylin/hooks/kylin.py b/airflow/providers/apache/kylin/hooks/kylin.py
index 8a880e3..991815d 100644
--- a/airflow/providers/apache/kylin/hooks/kylin.py
+++ b/airflow/providers/apache/kylin/hooks/kylin.py
@@ -33,11 +33,13 @@ class KylinHook(BaseHook):
:param dsn: dsn
:type dsn: Optional[str]
"""
- def __init__(self,
- kylin_conn_id: Optional[str] = 'kylin_default',
- project: Optional[str] = None,
- dsn: Optional[str] = None
- ):
+
+ def __init__(
+ self,
+ kylin_conn_id: Optional[str] = 'kylin_default',
+ project: Optional[str] = None,
+ dsn: Optional[str] = None,
+ ):
super().__init__()
self.kylin_conn_id = kylin_conn_id
self.project = project
@@ -49,9 +51,14 @@ class KylinHook(BaseHook):
return kylinpy.create_kylin(self.dsn)
else:
self.project = self.project if self.project else conn.schema
- return kylinpy.Kylin(conn.host, username=conn.login,
- password=conn.password, port=conn.port,
- project=self.project, **conn.extra_dejson)
+ return kylinpy.Kylin(
+ conn.host,
+ username=conn.login,
+ password=conn.password,
+ port=conn.port,
+ project=self.project,
+ **conn.extra_dejson,
+ )
def cube_run(self, datasource_name, op, **op_args):
"""
diff --git a/airflow/providers/apache/kylin/operators/kylin_cube.py b/airflow/providers/apache/kylin/operators/kylin_cube.py
index 5a8cdbc..a732689 100644
--- a/airflow/providers/apache/kylin/operators/kylin_cube.py
+++ b/airflow/providers/apache/kylin/operators/kylin_cube.py
@@ -87,31 +87,50 @@ class KylinCubeOperator(BaseOperator):
:type eager_error_status: tuple
"""
- template_fields = ('project', 'cube', 'dsn', 'command', 'start_time', 'end_time',
- 'segment_name', 'offset_start', 'offset_end')
+ template_fields = (
+ 'project',
+ 'cube',
+ 'dsn',
+ 'command',
+ 'start_time',
+ 'end_time',
+ 'segment_name',
+ 'offset_start',
+ 'offset_end',
+ )
ui_color = '#E79C46'
- build_command = {'fullbuild', 'build', 'merge', 'refresh', 'build_streaming',
- 'merge_streaming', 'refresh_streaming'}
+ build_command = {
+ 'fullbuild',
+ 'build',
+ 'merge',
+ 'refresh',
+ 'build_streaming',
+ 'merge_streaming',
+ 'refresh_streaming',
+ }
jobs_end_status = {"FINISHED", "ERROR", "DISCARDED", "KILLED", "SUICIDAL", "STOPPED"}
# pylint: disable=too-many-arguments,inconsistent-return-statements
@apply_defaults
- def __init__(self, *,
- kylin_conn_id: Optional[str] = 'kylin_default',
- project: Optional[str] = None,
- cube: Optional[str] = None,
- dsn: Optional[str] = None,
- command: Optional[str] = None,
- start_time: Optional[str] = None,
- end_time: Optional[str] = None,
- offset_start: Optional[str] = None,
- offset_end: Optional[str] = None,
- segment_name: Optional[str] = None,
- is_track_job: Optional[bool] = False,
- interval: int = 60,
- timeout: int = 60 * 60 * 24,
- eager_error_status=("ERROR", "DISCARDED", "KILLED", "SUICIDAL", "STOPPED"),
- **kwargs):
+ def __init__(
+ self,
+ *,
+ kylin_conn_id: Optional[str] = 'kylin_default',
+ project: Optional[str] = None,
+ cube: Optional[str] = None,
+ dsn: Optional[str] = None,
+ command: Optional[str] = None,
+ start_time: Optional[str] = None,
+ end_time: Optional[str] = None,
+ offset_start: Optional[str] = None,
+ offset_end: Optional[str] = None,
+ segment_name: Optional[str] = None,
+ is_track_job: Optional[bool] = False,
+ interval: int = 60,
+ timeout: int = 60 * 60 * 24,
+ eager_error_status=("ERROR", "DISCARDED", "KILLED", "SUICIDAL", "STOPPED"),
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.kylin_conn_id = kylin_conn_id
self.project = project
@@ -135,15 +154,18 @@ class KylinCubeOperator(BaseOperator):
_support_invoke_command = kylinpy.CubeSource.support_invoke_command
if self.command.lower() not in _support_invoke_command:
- raise AirflowException('Kylin:Command {} can not match kylin command list {}'.format(
- self.command, _support_invoke_command))
+ raise AirflowException(
+ 'Kylin:Command {} can not match kylin command list {}'.format(
+ self.command, _support_invoke_command
+ )
+ )
kylinpy_params = {
'start': datetime.fromtimestamp(int(self.start_time) / 1000) if self.start_time else None,
'end': datetime.fromtimestamp(int(self.end_time) / 1000) if self.end_time else None,
'name': self.segment_name,
'offset_start': int(self.offset_start) if self.offset_start else None,
- 'offset_end': int(self.offset_end) if self.offset_end else None
+ 'offset_end': int(self.offset_end) if self.offset_end else None,
}
rsp_data = _hook.cube_run(self.cube, self.command.lower(), **kylinpy_params)
if self.is_track_job and self.command.lower() in self.build_command:
@@ -162,8 +184,7 @@ class KylinCubeOperator(BaseOperator):
job_status = _hook.get_job_status(job_id)
self.log.info('Kylin job status is %s ', job_status)
if job_status in self.jobs_error_status:
- raise AirflowException(
- 'Kylin job {} status {} is error '.format(job_id, job_status))
+ raise AirflowException('Kylin job {} status {} is error '.format(job_id, job_status))
if self.do_xcom_push:
return rsp_data
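A hedged sketch of the operator (project and cube names are illustrative; the epoch-millisecond window format matches the example DAG above). With is_track_job=True the task is expected to poll the job every interval seconds until it reaches one of the end statuses listed above or timeout expires:

    from airflow.providers.apache.kylin.operators.kylin_cube import KylinCubeOperator

    kylin_build = KylinCubeOperator(
        task_id='kylin_build_full',
        kylin_conn_id='kylin_default',
        project='learn_kylin',
        cube='kylin_sales_cube',
        command='build',
        start_time='1325347200000',   # epoch milliseconds
        end_time='1325433600000',
        is_track_job=True,
        interval=60,
        timeout=60 * 60 * 24,
    )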
diff --git a/airflow/providers/apache/livy/example_dags/example_livy.py b/airflow/providers/apache/livy/example_dags/example_livy.py
index 9e561c9..e8245e2 100644
--- a/airflow/providers/apache/livy/example_dags/example_livy.py
+++ b/airflow/providers/apache/livy/example_dags/example_livy.py
@@ -25,17 +25,10 @@ from airflow import DAG
from airflow.providers.apache.livy.operators.livy import LivyOperator
from airflow.utils.dates import days_ago
-args = {
- 'owner': 'airflow',
- 'email': ['airflow@example.com'],
- 'depends_on_past': False
-}
+args = {'owner': 'airflow', 'email': ['airflow@example.com'], 'depends_on_past': False}
with DAG(
- dag_id='example_livy_operator',
- default_args=args,
- schedule_interval='@daily',
- start_date=days_ago(5),
+ dag_id='example_livy_operator', default_args=args, schedule_interval='@daily', start_date=days_ago(5),
) as dag:
livy_java_task = LivyOperator(
@@ -45,9 +38,7 @@ with DAG(
file='/spark-examples.jar',
args=[10],
num_executors=1,
- conf={
- 'spark.shuffle.compress': 'false',
- },
+ conf={'spark.shuffle.compress': 'false',},
class_name='org.apache.spark.examples.SparkPi',
)
diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py
index f003087..1d614c0 100644
--- a/airflow/providers/apache/livy/hooks/livy.py
+++ b/airflow/providers/apache/livy/hooks/livy.py
@@ -34,6 +34,7 @@ class BatchState(Enum):
"""
Batch session states
"""
+
NOT_STARTED = 'not_started'
STARTING = 'starting'
RUNNING = 'running'
@@ -65,10 +66,7 @@ class LivyHook(HttpHook, LoggingMixin):
BatchState.ERROR,
}
- _def_headers = {
- 'Content-Type': 'application/json',
- 'Accept': 'application/json'
- }
+ _def_headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}
def __init__(self, livy_conn_id: str = 'livy_default') -> None:
super(LivyHook, self).__init__(http_conn_id=livy_conn_id)
@@ -93,7 +91,7 @@ class LivyHook(HttpHook, LoggingMixin):
method: str = 'GET',
data: Optional[Any] = None,
headers: Optional[Dict[str, Any]] = None,
- extra_options: Optional[Dict[Any, Any]] = None
+ extra_options: Optional[Dict[Any, Any]] = None,
) -> Any:
"""
Wrapper for HttpHook, allows to change method on the same HttpHook
@@ -138,20 +136,17 @@ class LivyHook(HttpHook, LoggingMixin):
self.get_conn()
self.log.info("Submitting job %s to %s", batch_submit_body, self.base_url)
- response = self.run_method(
- method='POST',
- endpoint='/batches',
- data=batch_submit_body
- )
+ response = self.run_method(method='POST', endpoint='/batches', data=batch_submit_body)
self.log.debug("Got response: %s", response.text)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as err:
- raise AirflowException("Could not submit batch. Status code: {}. Message: '{}'".format(
- err.response.status_code,
- err.response.text
- ))
+ raise AirflowException(
+ "Could not submit batch. Status code: {}. Message: '{}'".format(
+ err.response.status_code, err.response.text
+ )
+ )
batch_id = self._parse_post_response(response.json())
if batch_id is None:
@@ -178,10 +173,9 @@ class LivyHook(HttpHook, LoggingMixin):
response.raise_for_status()
except requests.exceptions.HTTPError as err:
self.log.warning("Got status code %d for session %d", err.response.status_code, session_id)
- raise AirflowException("Unable to fetch batch with id: {}. Message: {}".format(
- session_id,
- err.response.text
- ))
+ raise AirflowException(
+ "Unable to fetch batch with id: {}. Message: {}".format(session_id, err.response.text)
+ )
return response.json()
@@ -203,10 +197,9 @@ class LivyHook(HttpHook, LoggingMixin):
response.raise_for_status()
except requests.exceptions.HTTPError as err:
self.log.warning("Got status code %d for session %d", err.response.status_code, session_id)
- raise AirflowException("Unable to fetch batch with id: {}. Message: {}".format(
- session_id,
- err.response.text
- ))
+ raise AirflowException(
+ "Unable to fetch batch with id: {}. Message: {}".format(session_id, err.response.text)
+ )
jresp = response.json()
if 'state' not in jresp:
@@ -225,19 +218,17 @@ class LivyHook(HttpHook, LoggingMixin):
self._validate_session_id(session_id)
self.log.info("Deleting batch session %d", session_id)
- response = self.run_method(
- method='DELETE',
- endpoint='/batches/{}'.format(session_id)
- )
+ response = self.run_method(method='DELETE', endpoint='/batches/{}'.format(session_id))
try:
response.raise_for_status()
except requests.exceptions.HTTPError as err:
self.log.warning("Got status code %d for session %d", err.response.status_code, session_id)
- raise AirflowException("Could not kill the batch with session id: {}. Message: {}".format(
- session_id,
- err.response.text
- ))
+ raise AirflowException(
+ "Could not kill the batch with session id: {}. Message: {}".format(
+ session_id, err.response.text
+ )
+ )
return response.json()
@@ -283,7 +274,7 @@ class LivyHook(HttpHook, LoggingMixin):
num_executors: Optional[Union[int, str]] = None,
queue: Optional[str] = None,
proxy_user: Optional[str] = None,
- conf: Optional[Dict[Any, Any]] = None
+ conf: Optional[Dict[Any, Any]] = None,
) -> Any:
"""
Build the post batch request body.
@@ -386,9 +377,11 @@ class LivyHook(HttpHook, LoggingMixin):
:return: true if valid
:rtype: bool
"""
- if vals is None or \
- not isinstance(vals, (tuple, list)) or \
- any(1 for val in vals if not isinstance(val, (str, int, float))):
+ if (
+ vals is None
+ or not isinstance(vals, (tuple, list))
+ or any(1 for val in vals if not isinstance(val, (str, int, float)))
+ ):
raise ValueError("List of strings expected")
return True
diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py
index 16be339..cbaaec2 100644
--- a/airflow/providers/apache/livy/operators/livy.py
+++ b/airflow/providers/apache/livy/operators/livy.py
@@ -74,7 +74,8 @@ class LivyOperator(BaseOperator):
@apply_defaults
def __init__(
- self, *,
+ self,
+ *,
file: str,
class_name: Optional[str] = None,
args: Optional[Sequence[Union[str, int, float]]] = None,
@@ -93,7 +94,7 @@ class LivyOperator(BaseOperator):
proxy_user: Optional[str] = None,
livy_conn_id: str = 'livy_default',
polling_interval: int = 0,
- **kwargs: Any
+ **kwargs: Any,
) -> None:
# pylint: disable-msg=too-many-arguments
@@ -115,7 +116,7 @@ class LivyOperator(BaseOperator):
'queue': queue,
'name': name,
'conf': conf,
- 'proxy_user': proxy_user
+ 'proxy_user': proxy_user,
}
self._livy_conn_id = livy_conn_id
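A hedged sketch mirroring the example DAG earlier in this diff (jar path and class name are illustrative); with a non-zero polling_interval the operator is expected to wait for the batch itself rather than delegating to LivySensor:

    from airflow.providers.apache.livy.operators.livy import LivyOperator

    livy_pi_job = LivyOperator(
        task_id='livy_pi_job',
        livy_conn_id='livy_default',
        file='/spark-examples.jar',
        class_name='org.apache.spark.examples.SparkPi',
        args=[10],
        num_executors=1,
        conf={'spark.shuffle.compress': 'false'},
        polling_interval=30,
    )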
diff --git a/airflow/providers/apache/livy/sensors/livy.py b/airflow/providers/apache/livy/sensors/livy.py
index b9d0bc4..ba29b7f 100644
--- a/airflow/providers/apache/livy/sensors/livy.py
+++ b/airflow/providers/apache/livy/sensors/livy.py
@@ -39,10 +39,7 @@ class LivySensor(BaseSensorOperator):
@apply_defaults
def __init__(
- self, *,
- batch_id: Union[int, str],
- livy_conn_id: str = 'livy_default',
- **kwargs: Any
+ self, *, batch_id: Union[int, str], livy_conn_id: str = 'livy_default', **kwargs: Any
) -> None:
super().__init__(**kwargs)
self._livy_conn_id = livy_conn_id
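The sensor takes the Livy batch id directly; in practice that id typically comes from the submitting LivyOperator's XCom, but a literal works for a sketch:

    from airflow.providers.apache.livy.sensors.livy import LivySensor

    wait_for_batch = LivySensor(
        task_id='wait_for_batch',
        batch_id=42,               # illustrative; usually pulled from the submit task's XCom
        livy_conn_id='livy_default',
    )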
diff --git a/airflow/providers/apache/pig/example_dags/example_pig.py b/airflow/providers/apache/pig/example_dags/example_pig.py
index 8917f86..368135a 100644
--- a/airflow/providers/apache/pig/example_dags/example_pig.py
+++ b/airflow/providers/apache/pig/example_dags/example_pig.py
@@ -31,12 +31,7 @@ dag = DAG(
default_args=args,
schedule_interval=None,
start_date=days_ago(2),
- tags=['example']
+ tags=['example'],
)
-run_this = PigOperator(
- task_id="run_example_pig_script",
- pig="ls /;",
- pig_opts="-x local",
- dag=dag,
-)
+run_this = PigOperator(task_id="run_example_pig_script", pig="ls /;", pig_opts="-x local", dag=dag,)
diff --git a/airflow/providers/apache/pig/hooks/pig.py b/airflow/providers/apache/pig/hooks/pig.py
index 8baee6c..4152dd2 100644
--- a/airflow/providers/apache/pig/hooks/pig.py
+++ b/airflow/providers/apache/pig/hooks/pig.py
@@ -33,17 +33,14 @@ class PigCliHook(BaseHook):
"""
- def __init__(
- self,
- pig_cli_conn_id: str = "pig_cli_default") -> None:
+ def __init__(self, pig_cli_conn_id: str = "pig_cli_default") -> None:
super().__init__()
conn = self.get_connection(pig_cli_conn_id)
self.pig_properties = conn.extra_dejson.get('pig_properties', '')
self.conn = conn
self.sub_process = None
- def run_cli(self, pig: str, pig_opts: Optional[str] = None,
- verbose: bool = True) -> Any:
+ def run_cli(self, pig: str, pig_opts: Optional[str] = None, verbose: bool = True) -> Any:
"""
Run a pig script using the pig cli
@@ -75,11 +72,8 @@ class PigCliHook(BaseHook):
if verbose:
self.log.info("%s", " ".join(pig_cmd))
sub_process: Any = subprocess.Popen(
- pig_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- cwd=tmp_dir,
- close_fds=True)
+ pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
+ )
self.sub_process = sub_process
stdout = ''
for line in iter(sub_process.stdout.readline, b''):
diff --git a/airflow/providers/apache/pig/operators/pig.py b/airflow/providers/apache/pig/operators/pig.py
index 3f3c578..6d0f74e 100644
--- a/airflow/providers/apache/pig/operators/pig.py
+++ b/airflow/providers/apache/pig/operators/pig.py
@@ -42,17 +42,22 @@ class PigOperator(BaseOperator):
"""
template_fields = ('pig',)
- template_ext = ('.pig', '.piglatin',)
+ template_ext = (
+ '.pig',
+ '.piglatin',
+ )
ui_color = '#f0e4ec'
@apply_defaults
def __init__(
- self, *,
- pig: str,
- pig_cli_conn_id: str = 'pig_cli_default',
- pigparams_jinja_translate: bool = False,
- pig_opts: Optional[str] = None,
- **kwargs: Any) -> None:
+ self,
+ *,
+ pig: str,
+ pig_cli_conn_id: str = 'pig_cli_default',
+ pigparams_jinja_translate: bool = False,
+ pig_opts: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
super().__init__(**kwargs)
self.pigparams_jinja_translate = pigparams_jinja_translate
@@ -63,8 +68,7 @@ class PigOperator(BaseOperator):
def prepare_template(self):
if self.pigparams_jinja_translate:
- self.pig = re.sub(
- r"(\$([a-zA-Z_][a-zA-Z0-9_]*))", r"{{ \g<2> }}", self.pig)
+ self.pig = re.sub(r"(\$([a-zA-Z_][a-zA-Z0-9_]*))", r"{{ \g<2> }}", self.pig)
def execute(self, context):
self.log.info('Executing: %s', self.pig)
diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py
index 2436945..248c058 100644
--- a/airflow/providers/apache/pinot/hooks/pinot.py
+++ b/airflow/providers/apache/pinot/hooks/pinot.py
@@ -54,26 +54,26 @@ class PinotAdminHook(BaseHook):
:type pinot_admin_system_exit: bool
"""
- def __init__(self,
- conn_id: str = "pinot_admin_default",
- cmd_path: str = "pinot-admin.sh",
- pinot_admin_system_exit: bool = False
- ) -> None:
+ def __init__(
+ self,
+ conn_id: str = "pinot_admin_default",
+ cmd_path: str = "pinot-admin.sh",
+ pinot_admin_system_exit: bool = False,
+ ) -> None:
super().__init__()
conn = self.get_connection(conn_id)
self.host = conn.host
self.port = str(conn.port)
self.cmd_path = conn.extra_dejson.get("cmd_path", cmd_path)
- self.pinot_admin_system_exit = conn.extra_dejson.get("pinot_admin_system_exit",
- pinot_admin_system_exit)
+ self.pinot_admin_system_exit = conn.extra_dejson.get(
+ "pinot_admin_system_exit", pinot_admin_system_exit
+ )
self.conn = conn
def get_conn(self) -> Any:
return self.conn
- def add_schema(self, schema_file: str,
- with_exec: Optional[bool] = True
- ) -> Any:
+ def add_schema(self, schema_file: str, with_exec: Optional[bool] = True) -> Any:
"""
Add Pinot schema by running the AddSchema command
@@ -90,9 +90,7 @@ class PinotAdminHook(BaseHook):
cmd += ["-exec"]
self.run_cli(cmd)
- def add_table(self, file_path: str,
- with_exec: Optional[bool] = True
- ) -> Any:
+ def add_table(self, file_path: str, with_exec: Optional[bool] = True) -> Any:
"""
Add Pinot table by running the AddTable command
@@ -110,26 +108,27 @@ class PinotAdminHook(BaseHook):
self.run_cli(cmd)
# pylint: disable=too-many-arguments
- def create_segment(self,
- generator_config_file: Optional[str] = None,
- data_dir: Optional[str] = None,
- segment_format: Optional[str] = None,
- out_dir: Optional[str] = None,
- overwrite: Optional[str] = None,
- table_name: Optional[str] = None,
- segment_name: Optional[str] = None,
- time_column_name: Optional[str] = None,
- schema_file: Optional[str] = None,
- reader_config_file: Optional[str] = None,
- enable_star_tree_index: Optional[str] = None,
- star_tree_index_spec_file: Optional[str] = None,
- hll_size: Optional[str] = None,
- hll_columns: Optional[str] = None,
- hll_suffix: Optional[str] = None,
- num_threads: Optional[str] = None,
- post_creation_verification: Optional[str] = None,
- retry: Optional[str] = None
- ) -> Any:
+ def create_segment(
+ self,
+ generator_config_file: Optional[str] = None,
+ data_dir: Optional[str] = None,
+ segment_format: Optional[str] = None,
+ out_dir: Optional[str] = None,
+ overwrite: Optional[str] = None,
+ table_name: Optional[str] = None,
+ segment_name: Optional[str] = None,
+ time_column_name: Optional[str] = None,
+ schema_file: Optional[str] = None,
+ reader_config_file: Optional[str] = None,
+ enable_star_tree_index: Optional[str] = None,
+ star_tree_index_spec_file: Optional[str] = None,
+ hll_size: Optional[str] = None,
+ hll_columns: Optional[str] = None,
+ hll_suffix: Optional[str] = None,
+ num_threads: Optional[str] = None,
+ post_creation_verification: Optional[str] = None,
+ retry: Optional[str] = None,
+ ) -> Any:
"""
Create Pinot segment by running the CreateSegment command
"""
@@ -191,8 +190,7 @@ class PinotAdminHook(BaseHook):
self.run_cli(cmd)
- def upload_segment(self, segment_dir: str, table_name: Optional[str] = None
- ) -> Any:
+ def upload_segment(self, segment_dir: str, table_name: Optional[str] = None) -> Any:
"""
Upload a segment by running the UploadSegment command
@@ -230,11 +228,8 @@ class PinotAdminHook(BaseHook):
self.log.info(" ".join(command))
sub_process = subprocess.Popen(
- command,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- close_fds=True,
- env=env)
+ command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True, env=env
+ )
stdout = ""
if sub_process.stdout:
@@ -248,8 +243,9 @@ class PinotAdminHook(BaseHook):
# As of Pinot v0.1.0, either of "Error: ..." or "Exception caught: ..."
# is expected to be in the output messages. See:
# https://github.com/apache/incubator-pinot/blob/release-0.1.0/pinot-tools/src/main/java/org/apache/pinot/tools/admin/PinotAdministrator.java#L98-L101
- if ((self.pinot_admin_system_exit and sub_process.returncode) or
- ("Error" in stdout or "Exception" in stdout)):
+ if (self.pinot_admin_system_exit and sub_process.returncode) or (
+ "Error" in stdout or "Exception" in stdout
+ ):
raise AirflowException(stdout)
return stdout
@@ -259,6 +255,7 @@ class PinotDbApiHook(DbApiHook):
"""
Connect to pinot db (https://github.com/apache/incubator-pinot) to issue pql
"""
+
conn_name_attr = 'pinot_broker_conn_id'
default_conn_name = 'pinot_broker_default'
supports_autocommit = False
@@ -274,10 +271,9 @@ class PinotDbApiHook(DbApiHook):
host=conn.host,
port=conn.port,
path=conn.extra_dejson.get('endpoint', '/pql'),
- scheme=conn.extra_dejson.get('schema', 'http')
+ scheme=conn.extra_dejson.get('schema', 'http'),
)
- self.log.info('Get the connection to pinot '
- 'broker on %s', conn.host)
+ self.log.info('Get the connection to pinot ' 'broker on %s', conn.host)
return pinot_broker_conn
def get_uri(self) -> str:
@@ -292,12 +288,9 @@ class PinotDbApiHook(DbApiHook):
host += ':{port}'.format(port=conn.port)
conn_type = 'http' if not conn.conn_type else conn.conn_type
endpoint = conn.extra_dejson.get('endpoint', 'pql')
- return '{conn_type}://{host}/{endpoint}'.format(
- conn_type=conn_type, host=host, endpoint=endpoint)
+ return '{conn_type}://{host}/{endpoint}'.format(conn_type=conn_type, host=host, endpoint=endpoint)
- def get_records(self, sql: str,
- parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None
- ) -> Any:
+ def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
"""
Executes the sql and returns a set of records.
@@ -311,9 +304,7 @@ class PinotDbApiHook(DbApiHook):
cur.execute(sql)
return cur.fetchall()
- def get_first(self, sql: str,
- parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None
- ) -> Any:
+ def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
"""
Executes the sql and returns the first resulting row.
@@ -330,9 +321,13 @@ class PinotDbApiHook(DbApiHook):
def set_autocommit(self, conn: Connection, autocommit: Any) -> Any:
raise NotImplementedError()
- def insert_rows(self, table: str, rows: str,
- target_fields: Optional[str] = None,
- commit_every: int = 1000,
- replace: bool = False,
- **kwargs: Any) -> Any:
+ def insert_rows(
+ self,
+ table: str,
+ rows: str,
+ target_fields: Optional[str] = None,
+ commit_every: int = 1000,
+ replace: bool = False,
+ **kwargs: Any,
+ ) -> Any:
raise NotImplementedError()
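The two hooks cover ingestion and querying respectively. A hedged sketch (file paths, table name and PQL are illustrative; PinotDbApiHook is assumed to fall back to its pinot_broker_default connection when constructed without arguments):

    from airflow.providers.apache.pinot.hooks.pinot import PinotAdminHook, PinotDbApiHook

    # Ingestion side: drive pinot-admin.sh through the admin connection.
    admin = PinotAdminHook(conn_id='pinot_admin_default', cmd_path='pinot-admin.sh')
    admin.add_schema('/path/to/schema.json', with_exec=True)
    admin.add_table('/path/to/table-config.json', with_exec=True)

    # Query side: issue PQL through the broker endpoint.
    broker = PinotDbApiHook()
    rows = broker.get_records('SELECT COUNT(*) FROM my_table')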
diff --git a/airflow/providers/apache/spark/example_dags/example_spark_dag.py b/airflow/providers/apache/spark/example_dags/example_spark_dag.py
index 5d279e0..982a773 100644
--- a/airflow/providers/apache/spark/example_dags/example_spark_dag.py
+++ b/airflow/providers/apache/spark/example_dags/example_spark_dag.py
@@ -35,12 +35,11 @@ with DAG(
default_args=args,
schedule_interval=None,
start_date=days_ago(2),
- tags=['example']
+ tags=['example'],
) as dag:
# [START howto_operator_spark_submit]
submit_job = SparkSubmitOperator(
- application="${SPARK_HOME}/examples/src/main/python/pi.py",
- task_id="submit_job"
+ application="${SPARK_HOME}/examples/src/main/python/pi.py", task_id="submit_job"
)
# [END howto_operator_spark_submit]
@@ -53,7 +52,7 @@ with DAG(
metastore_table="bar",
save_mode="overwrite",
save_format="JSON",
- task_id="jdbc_to_spark_job"
+ task_id="jdbc_to_spark_job",
)
spark_to_jdbc_job = SparkJDBCOperator(
@@ -63,14 +62,10 @@ with DAG(
jdbc_driver="org.postgresql.Driver",
metastore_table="bar",
save_mode="append",
- task_id="spark_to_jdbc_job"
+ task_id="spark_to_jdbc_job",
)
# [END howto_operator_spark_jdbc]
# [START howto_operator_spark_sql]
- sql_job = SparkSqlOperator(
- sql="SELECT * FROM bar",
- master="local",
- task_id="sql_job"
- )
+ sql_job = SparkSqlOperator(sql="SELECT * FROM bar", master="local", task_id="sql_job")
# [END howto_operator_spark_sql]
diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc.py b/airflow/providers/apache/spark/hooks/spark_jdbc.py
index 8ec3f49..2100b64 100644
--- a/airflow/providers/apache/spark/hooks/spark_jdbc.py
+++ b/airflow/providers/apache/spark/hooks/spark_jdbc.py
@@ -113,38 +113,39 @@ class SparkJDBCHook(SparkSubmitHook):
"""
# pylint: disable=too-many-arguments,too-many-locals
- def __init__(self,
- spark_app_name: str = 'airflow-spark-jdbc',
- spark_conn_id: str = 'spark-default',
- spark_conf: Optional[Dict[str, Any]] = None,
- spark_py_files: Optional[str] = None,
- spark_files: Optional[str] = None,
- spark_jars: Optional[str] = None,
- num_executors: Optional[int] = None,
- executor_cores: Optional[int] = None,
- executor_memory: Optional[str] = None,
- driver_memory: Optional[str] = None,
- verbose: bool = False,
- principal: Optional[str] = None,
- keytab: Optional[str] = None,
- cmd_type: str = 'spark_to_jdbc',
- jdbc_table: Optional[str] = None,
- jdbc_conn_id: str = 'jdbc-default',
- jdbc_driver: Optional[str] = None,
- metastore_table: Optional[str] = None,
- jdbc_truncate: bool = False,
- save_mode: Optional[str] = None,
- save_format: Optional[str] = None,
- batch_size: Optional[int] = None,
- fetch_size: Optional[int] = None,
- num_partitions: Optional[int] = None,
- partition_column: Optional[str] = None,
- lower_bound: Optional[str] = None,
- upper_bound: Optional[str] = None,
- create_table_column_types: Optional[str] = None,
- *args: Any,
- **kwargs: Any
- ):
+ def __init__(
+ self,
+ spark_app_name: str = 'airflow-spark-jdbc',
+ spark_conn_id: str = 'spark-default',
+ spark_conf: Optional[Dict[str, Any]] = None,
+ spark_py_files: Optional[str] = None,
+ spark_files: Optional[str] = None,
+ spark_jars: Optional[str] = None,
+ num_executors: Optional[int] = None,
+ executor_cores: Optional[int] = None,
+ executor_memory: Optional[str] = None,
+ driver_memory: Optional[str] = None,
+ verbose: bool = False,
+ principal: Optional[str] = None,
+ keytab: Optional[str] = None,
+ cmd_type: str = 'spark_to_jdbc',
+ jdbc_table: Optional[str] = None,
+ jdbc_conn_id: str = 'jdbc-default',
+ jdbc_driver: Optional[str] = None,
+ metastore_table: Optional[str] = None,
+ jdbc_truncate: bool = False,
+ save_mode: Optional[str] = None,
+ save_format: Optional[str] = None,
+ batch_size: Optional[int] = None,
+ fetch_size: Optional[int] = None,
+ num_partitions: Optional[int] = None,
+ partition_column: Optional[str] = None,
+ lower_bound: Optional[str] = None,
+ upper_bound: Optional[str] = None,
+ create_table_column_types: Optional[str] = None,
+ *args: Any,
+ **kwargs: Any,
+ ):
super().__init__(*args, **kwargs)
self._name = spark_app_name
self._conn_id = spark_conn_id
@@ -177,12 +178,7 @@ class SparkJDBCHook(SparkSubmitHook):
self._jdbc_connection = self._resolve_jdbc_connection()
def _resolve_jdbc_connection(self) -> Dict[str, Any]:
- conn_data = {'url': '',
- 'schema': '',
- 'conn_prefix': '',
- 'user': '',
- 'password': ''
- }
+ conn_data = {'url': '', 'schema': '', 'conn_prefix': '', 'user': '', 'password': ''}
try:
conn = self.get_connection(self._jdbc_conn_id)
if conn.port:
@@ -196,8 +192,7 @@ class SparkJDBCHook(SparkSubmitHook):
conn_data['conn_prefix'] = extra.get('conn_prefix', '')
except AirflowException:
self.log.debug(
- "Could not load jdbc connection string %s, defaulting to %s",
- self._jdbc_conn_id, ""
+ "Could not load jdbc connection string %s, defaulting to %s", self._jdbc_conn_id, ""
)
return conn_data
@@ -205,9 +200,10 @@ class SparkJDBCHook(SparkSubmitHook):
arguments = []
arguments += ["-cmdType", self._cmd_type]
if self._jdbc_connection['url']:
- arguments += ['-url', "{0}{1}/{2}".format(
- jdbc_conn['conn_prefix'], jdbc_conn['url'], jdbc_conn['schema']
- )]
+ arguments += [
+ '-url',
+ "{0}{1}/{2}".format(jdbc_conn['conn_prefix'], jdbc_conn['url'], jdbc_conn['schema']),
+ ]
if self._jdbc_connection['user']:
arguments += ['-user', self._jdbc_connection['user']]
if self._jdbc_connection['password']:
@@ -226,12 +222,16 @@ class SparkJDBCHook(SparkSubmitHook):
arguments += ['-fetchsize', str(self._fetch_size)]
if self._num_partitions:
arguments += ['-numPartitions', str(self._num_partitions)]
- if (self._partition_column and self._lower_bound and
- self._upper_bound and self._num_partitions):
+ if self._partition_column and self._lower_bound and self._upper_bound and self._num_partitions:
# these 3 parameters need to be used all together to take effect.
- arguments += ['-partitionColumn', self._partition_column,
- '-lowerBound', self._lower_bound,
- '-upperBound', self._upper_bound]
+ arguments += [
+ '-partitionColumn',
+ self._partition_column,
+ '-lowerBound',
+ self._lower_bound,
+ '-upperBound',
+ self._upper_bound,
+ ]
if self._save_mode:
arguments += ['-saveMode', self._save_mode]
if self._save_format:
@@ -244,10 +244,8 @@ class SparkJDBCHook(SparkSubmitHook):
"""
Submit Spark JDBC job
"""
- self._application_args = \
- self._build_jdbc_application_arguments(self._jdbc_connection)
- self.submit(application=os.path.dirname(os.path.abspath(__file__)) +
- "/spark_jdbc_script.py")
+ self._application_args = self._build_jdbc_application_arguments(self._jdbc_connection)
+ self.submit(application=os.path.dirname(os.path.abspath(__file__)) + "/spark_jdbc_script.py")
def get_conn(self) -> Any:
pass
diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py
index 3a9f56a..ffc9a3e 100644
--- a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py
+++ b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py
@@ -25,12 +25,14 @@ SPARK_WRITE_TO_JDBC: str = "spark_to_jdbc"
SPARK_READ_FROM_JDBC: str = "jdbc_to_spark"
-def set_common_options(spark_source: Any,
- url: str = 'localhost:5432',
- jdbc_table: str = 'default.default',
- user: str = 'root',
- password: str = 'root',
- driver: str = 'driver') -> Any:
+def set_common_options(
+ spark_source: Any,
+ url: str = 'localhost:5432',
+ jdbc_table: str = 'default.default',
+ user: str = 'root',
+ password: str = 'root',
+ driver: str = 'driver',
+) -> Any:
"""
Get Spark source from JDBC connection
@@ -42,36 +44,36 @@ def set_common_options(spark_source: Any,
:param driver: JDBC resource driver
"""
- spark_source = spark_source \
- .format('jdbc') \
- .option('url', url) \
- .option('dbtable', jdbc_table) \
- .option('user', user) \
- .option('password', password) \
+ spark_source = (
+ spark_source.format('jdbc')
+ .option('url', url)
+ .option('dbtable', jdbc_table)
+ .option('user', user)
+ .option('password', password)
.option('driver', driver)
+ )
return spark_source
# pylint: disable=too-many-arguments
-def spark_write_to_jdbc(spark_session: SparkSession,
- url: str,
- user: str,
- password: str,
- metastore_table: str,
- jdbc_table: str,
- driver: Any,
- truncate: bool,
- save_mode: str,
- batch_size: int,
- num_partitions: int,
- create_table_column_types: str) -> None:
+def spark_write_to_jdbc(
+ spark_session: SparkSession,
+ url: str,
+ user: str,
+ password: str,
+ metastore_table: str,
+ jdbc_table: str,
+ driver: Any,
+ truncate: bool,
+ save_mode: str,
+ batch_size: int,
+ num_partitions: int,
+ create_table_column_types: str,
+) -> None:
"""
Transfer data from Spark to JDBC source
"""
- writer = spark_session \
- .table(metastore_table) \
- .write \
-
+ writer = spark_session.table(metastore_table).write
# first set common options
writer = set_common_options(writer, url, jdbc_table, user, password, driver)
@@ -85,26 +87,26 @@ def spark_write_to_jdbc(spark_session: SparkSession,
if create_table_column_types:
writer = writer.option("createTableColumnTypes", create_table_column_types)
- writer \
- .save(mode=save_mode)
+ writer.save(mode=save_mode)
# pylint: disable=too-many-arguments
-def spark_read_from_jdbc(spark_session: SparkSession,
- url: str,
- user: str,
- password: str,
- metastore_table: str,
- jdbc_table: str,
- driver: Any,
- save_mode: str,
- save_format: str,
- fetch_size: int,
- num_partitions: int,
- partition_column: str,
- lower_bound: str,
- upper_bound: str
- ) -> None:
+def spark_read_from_jdbc(
+ spark_session: SparkSession,
+ url: str,
+ user: str,
+ password: str,
+ metastore_table: str,
+ jdbc_table: str,
+ driver: Any,
+ save_mode: str,
+ save_format: str,
+ fetch_size: int,
+ num_partitions: int,
+ partition_column: str,
+ lower_bound: str,
+ upper_bound: str,
+) -> None:
"""
Transfer data from JDBC source to Spark
"""
@@ -118,15 +120,13 @@ def spark_read_from_jdbc(spark_session: SparkSession,
if num_partitions:
reader = reader.option('numPartitions', num_partitions)
if partition_column and lower_bound and upper_bound:
- reader = reader \
- .option('partitionColumn', partition_column) \
- .option('lowerBound', lower_bound) \
+ reader = (
+ reader.option('partitionColumn', partition_column)
+ .option('lowerBound', lower_bound)
.option('upperBound', upper_bound)
+ )
- reader \
- .load() \
- .write \
- .saveAsTable(metastore_table, format=save_format, mode=save_mode)
+ reader.load().write.saveAsTable(metastore_table, format=save_format, mode=save_mode)
def _parse_arguments(args: Optional[List[str]] = None) -> Any:
@@ -148,16 +148,12 @@ def _parse_arguments(args: Optional[List[str]] = None) -> Any:
parser.add_argument('-partitionColumn', dest='partition_column', action='store')
parser.add_argument('-lowerBound', dest='lower_bound', action='store')
parser.add_argument('-upperBound', dest='upper_bound', action='store')
- parser.add_argument('-createTableColumnTypes',
- dest='create_table_column_types', action='store')
+ parser.add_argument('-createTableColumnTypes', dest='create_table_column_types', action='store')
return parser.parse_args(args=args)
def _create_spark_session(arguments: Any) -> SparkSession:
- return SparkSession.builder \
- .appName(arguments.name) \
- .enableHiveSupport() \
- .getOrCreate()
+ return SparkSession.builder.appName(arguments.name).enableHiveSupport().getOrCreate()
def _run_spark(arguments: Any) -> None:
@@ -165,33 +161,37 @@ def _run_spark(arguments: Any) -> None:
spark = _create_spark_session(arguments)
if arguments.cmd_type == SPARK_WRITE_TO_JDBC:
- spark_write_to_jdbc(spark,
- arguments.url,
- arguments.user,
- arguments.password,
- arguments.metastore_table,
- arguments.jdbc_table,
- arguments.jdbc_driver,
- arguments.truncate,
- arguments.save_mode,
- arguments.batch_size,
- arguments.num_partitions,
- arguments.create_table_column_types)
+ spark_write_to_jdbc(
+ spark,
+ arguments.url,
+ arguments.user,
+ arguments.password,
+ arguments.metastore_table,
+ arguments.jdbc_table,
+ arguments.jdbc_driver,
+ arguments.truncate,
+ arguments.save_mode,
+ arguments.batch_size,
+ arguments.num_partitions,
+ arguments.create_table_column_types,
+ )
elif arguments.cmd_type == SPARK_READ_FROM_JDBC:
- spark_read_from_jdbc(spark,
- arguments.url,
- arguments.user,
- arguments.password,
- arguments.metastore_table,
- arguments.jdbc_table,
- arguments.jdbc_driver,
- arguments.save_mode,
- arguments.save_format,
- arguments.fetch_size,
- arguments.num_partitions,
- arguments.partition_column,
- arguments.lower_bound,
- arguments.upper_bound)
+ spark_read_from_jdbc(
+ spark,
+ arguments.url,
+ arguments.user,
+ arguments.password,
+ arguments.metastore_table,
+ arguments.jdbc_table,
+ arguments.jdbc_driver,
+ arguments.save_mode,
+ arguments.save_format,
+ arguments.fetch_size,
+ arguments.num_partitions,
+ arguments.partition_column,
+ arguments.lower_bound,
+ arguments.upper_bound,
+ )
if __name__ == "__main__": # pragma: no cover
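For context on how the reformatted helpers in this script fit together, here is a minimal sketch of calling spark_read_from_jdbc directly; the JDBC URL, credentials, table names, and bounds are placeholders and not values from the repository:

    from pyspark.sql import SparkSession

    from airflow.providers.apache.spark.hooks.spark_jdbc_script import spark_read_from_jdbc

    # Placeholder connection details; substitute a real JDBC endpoint and tables.
    spark = SparkSession.builder.appName("example-jdbc-read").enableHiveSupport().getOrCreate()
    spark_read_from_jdbc(
        spark_session=spark,
        url="jdbc:postgresql://localhost:5432/example",
        user="example_user",
        password="example_password",
        metastore_table="default.example_target",
        jdbc_table="public.example_source",
        driver="org.postgresql.Driver",
        save_mode="overwrite",
        save_format="parquet",
        fetch_size=1000,
        num_partitions=4,
        partition_column="id",
        lower_bound="1",
        upper_bound="100000",
    )

The keyword names mirror the function signature shown in the hunk above; as that code shows, the partition column, lower bound, and upper bound options are only applied when all three are provided.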
diff --git a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py
index c0491dd..cceb2bc 100644
--- a/airflow/providers/apache/spark/hooks/spark_sql.py
+++ b/airflow/providers/apache/spark/hooks/spark_sql.py
@@ -57,21 +57,22 @@ class SparkSqlHook(BaseHook):
"""
# pylint: disable=too-many-arguments
- def __init__(self,
- sql: str,
- conf: Optional[str] = None,
- conn_id: str = 'spark_sql_default',
- total_executor_cores: Optional[int] = None,
- executor_cores: Optional[int] = None,
- executor_memory: Optional[str] = None,
- keytab: Optional[str] = None,
- principal: Optional[str] = None,
- master: str = 'yarn',
- name: str = 'default-name',
- num_executors: Optional[int] = None,
- verbose: bool = True,
- yarn_queue: str = 'default'
- ) -> None:
+ def __init__(
+ self,
+ sql: str,
+ conf: Optional[str] = None,
+ conn_id: str = 'spark_sql_default',
+ total_executor_cores: Optional[int] = None,
+ executor_cores: Optional[int] = None,
+ executor_memory: Optional[str] = None,
+ keytab: Optional[str] = None,
+ principal: Optional[str] = None,
+ master: str = 'yarn',
+ name: str = 'default-name',
+ num_executors: Optional[int] = None,
+ verbose: bool = True,
+ yarn_queue: str = 'default',
+ ) -> None:
super().__init__()
self._sql = sql
self._conf = conf
@@ -152,10 +153,7 @@ class SparkSqlHook(BaseHook):
:type kwargs: dict
"""
spark_sql_cmd = self._prepare_command(cmd)
- self._sp = subprocess.Popen(spark_sql_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- **kwargs)
+ self._sp = subprocess.Popen(spark_sql_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs)
for line in iter(self._sp.stdout): # type: ignore
self.log.info(line)
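As a usage note for the __init__ signature shown above, a minimal sketch of constructing the hook and running a query; the SQL text, application name, and connection id are illustrative placeholders, and run_query is the hook's existing entry point (not part of this hunk):

    from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook

    # Illustrative values only; keyword names follow the __init__ signature above.
    hook = SparkSqlHook(
        sql="SELECT COUNT(*) FROM example_db.example_table",
        conn_id="spark_sql_default",
        master="yarn",
        name="example-spark-sql",
        verbose=True,
        yarn_queue="default",
    )
    hook.run_query()  # builds the spark-sql command and streams its output to the log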
diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py
index 0319f6d..13983e9 100644
--- a/airflow/providers/apache/spark/hooks/spark_submit.py
+++ b/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -105,33 +105,34 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
"""
# pylint: disable=too-many-arguments,too-many-locals,too-many-branches
- def __init__(self,
- conf: Optional[Dict[str, Any]] = None,
- conn_id: str = 'spark_default',
- files: Optional[str] = None,
- py_files: Optional[str] = None,
- archives: Optional[str] = None,
- driver_class_path: Optional[str] = None,
- jars: Optional[str] = None,
- java_class: Optional[str] = None,
- packages: Optional[str] = None,
- exclude_packages: Optional[str] = None,
- repositories: Optional[str] = None,
- total_executor_cores: Optional[int] = None,
- executor_cores: Optional[int] = None,
- executor_memory: Optional[str] = None,
- driver_memory: Optional[str] = None,
- keytab: Optional[str] = None,
- principal: Optional[str] = None,
- proxy_user: Optional[str] = None,
- name: str = 'default-name',
- num_executors: Optional[int] = None,
- status_poll_interval: int = 1,
- application_args: Optional[List[Any]] = None,
- env_vars: Optional[Dict[str, Any]] = None,
- verbose: bool = False,
- spark_binary: Optional[str] = None
- ) -> None:
+ def __init__(
+ self,
+ conf: Optional[Dict[str, Any]] = None,
+ conn_id: str = 'spark_default',
+ files: Optional[str] = None,
+ py_files: Optional[str] = None,
+ archives: Optional[str] = None,
+ driver_class_path: Optional[str] = None,
+ jars: Optional[str] = None,
+ java_class: Optional[str] = None,
+ packages: Optional[str] = None,
+ exclude_packages: Optional[str] = None,
+ repositories: Optional[str] = None,
+ total_executor_cores: Optional[int] = None,
+ executor_cores: Optional[int] = None,
+ executor_memory: Optional[str] = None,
+ driver_memory: Optional[str] = None,
+ keytab: Optional[str] = None,
+ principal: Optional[str] = None,
+ proxy_user: Optional[str] = None,
+ name: str = 'default-name',
+ num_executors: Optional[int] = None,
+ status_poll_interval: int = 1,
+ application_args: Optional[List[Any]] = None,
+ env_vars: Optional[Dict[str, Any]] = None,
+ verbose: bool = False,
+ spark_binary: Optional[str] = None,
+ ) -> None:
super().__init__()
self._conf = conf or {}
self._conn_id = conn_id
@@ -168,7 +169,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
if self._is_kubernetes and kube_client is None:
raise RuntimeError(
"{} specified by kubernetes dependencies are not installed!".format(
- self._connection['master']))
+ self._connection['master']
+ )
+ )
self._should_track_driver_status = self._resolve_should_track_driver_status()
self._driver_id: Optional[str] = None
@@ -182,17 +185,18 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
subsequent spark-submit status requests after the initial spark-submit request
:return: if the driver status should be tracked
"""
- return ('spark://' in self._connection['master'] and
- self._connection['deploy_mode'] == 'cluster')
+ return 'spark://' in self._connection['master'] and self._connection['deploy_mode'] == 'cluster'
def _resolve_connection(self) -> Dict[str, Any]:
# Build from connection master or default to yarn if not available
- conn_data = {'master': 'yarn',
- 'queue': None,
- 'deploy_mode': None,
- 'spark_home': None,
- 'spark_binary': self._spark_binary or "spark-submit",
- 'namespace': None}
+ conn_data = {
+ 'master': 'yarn',
+ 'queue': None,
+ 'deploy_mode': None,
+ 'spark_home': None,
+ 'spark_binary': self._spark_binary or "spark-submit",
+ 'namespace': None,
+ }
try:
# Master can be local, yarn, spark://HOST:PORT, mesos://HOST:PORT and
@@ -208,13 +212,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
conn_data['queue'] = extra.get('queue', None)
conn_data['deploy_mode'] = extra.get('deploy-mode', None)
conn_data['spark_home'] = extra.get('spark-home', None)
- conn_data['spark_binary'] = self._spark_binary or \
- extra.get('spark-binary', "spark-submit")
+ conn_data['spark_binary'] = self._spark_binary or extra.get('spark-binary', "spark-submit")
conn_data['namespace'] = extra.get('namespace')
except AirflowException:
self.log.info(
- "Could not load connection string %s, defaulting to %s",
- self._conn_id, conn_data['master']
+ "Could not load connection string %s, defaulting to %s", self._conn_id, conn_data['master']
)
if 'spark.kubernetes.namespace' in self._conf:
@@ -230,8 +232,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
# the spark_home; otherwise assume that spark-submit is present in the path to
# the executing user
if self._connection['spark_home']:
- connection_cmd = [os.path.join(self._connection['spark_home'], 'bin',
- self._connection['spark_binary'])]
+ connection_cmd = [
+ os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary'])
+ ]
else:
connection_cmd = [self._connection['spark_binary']]
@@ -242,18 +245,18 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
# where key contains password (case insensitive), e.g. HivePassword='abc'
connection_cmd_masked = re.sub(
r"("
- r"\S*?" # Match all non-whitespace characters before...
+ r"\S*?" # Match all non-whitespace characters before...
r"(?:secret|password)" # ...literally a "secret" or "password"
- # word (not capturing them).
- r"\S*?" # All non-whitespace characters before either...
- r"(?:=|\s+)" # ...an equal sign or whitespace characters
- # (not capturing them).
- r"(['\"]?)" # An optional single or double quote.
- r")" # This is the end of the first capturing group.
- r"(?:(?!\2\s).)*" # All characters between optional quotes
- # (matched above); if the value is quoted,
- # it may contain whitespace.
- r"(\2)", # Optional matching quote.
+ # word (not capturing them).
+ r"\S*?" # All non-whitespace characters before either...
+ r"(?:=|\s+)" # ...an equal sign or whitespace characters
+ # (not capturing them).
+ r"(['\"]?)" # An optional single or double quote.
+ r")" # This is the end of the first capturing group.
+ r"(?:(?!\2\s).)*" # All characters between optional quotes
+ # (matched above); if the value is quoted,
+ # it may contain whitespace.
+ r"(\2)", # Optional matching quote.
r'\1******\3',
' '.join(connection_cmd),
flags=re.I,
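To make the masking pattern above easier to follow, here is a small standalone sketch with the same regex collapsed into a single string; the sample command line is made up for illustration:

    import re

    # The same pattern as in the hook, joined into one raw string.
    mask = (
        r"("
        r"\S*?(?:secret|password)\S*?(?:=|\s+)(['\"]?)"
        r")"
        r"(?:(?!\2\s).)*"
        r"(\2)"
    )

    sample = "spark-submit --conf spark.hadoop.HivePassword='abc def' app.py"
    print(re.sub(mask, r'\1******\3', sample, flags=re.I))
    # spark-submit --conf spark.hadoop.HivePassword='******' app.py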
@@ -284,17 +287,16 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
else:
tmpl = "spark.kubernetes.driverEnv.{}={}"
for key in self._env_vars:
- connection_cmd += [
- "--conf",
- tmpl.format(key, str(self._env_vars[key]))]
+ connection_cmd += ["--conf", tmpl.format(key, str(self._env_vars[key]))]
elif self._env_vars and self._connection['deploy_mode'] != "cluster":
self._env = self._env_vars # Do it on Popen of the process
elif self._env_vars and self._connection['deploy_mode'] == "cluster":
- raise AirflowException(
- "SparkSubmitHook env_vars is not supported in standalone-cluster mode.")
+ raise AirflowException("SparkSubmitHook env_vars is not supported in standalone-cluster mode.")
if self._is_kubernetes and self._connection['namespace']:
- connection_cmd += ["--conf", "spark.kubernetes.namespace={}".format(
- self._connection['namespace'])]
+ connection_cmd += [
+ "--conf",
+ "spark.kubernetes.namespace={}".format(self._connection['namespace']),
+ ]
if self._files:
connection_cmd += ["--files", self._files]
if self._py_files:
@@ -364,8 +366,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
"--max-time",
str(curl_max_wait_time),
"{host}/v1/submissions/status/{submission_id}".format(
- host=spark_host,
- submission_id=self._driver_id)]
+ host=spark_host, submission_id=self._driver_id
+ ),
+ ]
self.log.info(connection_cmd)
# The driver id so we can poll for its status
@@ -373,8 +376,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
pass
else:
raise AirflowException(
- "Invalid status: attempted to poll driver " +
- "status but no driver id is known. Giving up.")
+ "Invalid status: attempted to poll driver "
+ + "status but no driver id is known. Giving up."
+ )
else:
@@ -388,8 +392,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
connection_cmd += ["--status", self._driver_id]
else:
raise AirflowException(
- "Invalid status: attempted to poll driver " +
- "status but no driver id is known. Giving up.")
+ "Invalid status: attempted to poll driver "
+ + "status but no driver id is known. Giving up."
+ )
self.log.debug("Poll driver status cmd: %s", connection_cmd)
@@ -410,12 +415,14 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
env.update(self._env)
kwargs["env"] = env
- self._submit_sp = subprocess.Popen(spark_submit_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=-1,
- universal_newlines=True,
- **kwargs)
+ self._submit_sp = subprocess.Popen(
+ spark_submit_cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ bufsize=-1,
+ universal_newlines=True,
+ **kwargs,
+ )
self._process_spark_submit_log(iter(self._submit_sp.stdout)) # type: ignore
returncode = self._submit_sp.wait()
@@ -442,8 +449,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
if self._should_track_driver_status:
if self._driver_id is None:
raise AirflowException(
- "No driver id is known: something went wrong when executing " +
- "the spark submit command"
+ "No driver id is known: something went wrong when executing " + "the spark submit command"
)
# We start with the SUBMITTED status as initial status
@@ -454,8 +460,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
if self._driver_status != "FINISHED":
raise AirflowException(
- "ERROR : Driver {} badly exited with status {}"
- .format(self._driver_id, self._driver_status)
+ "ERROR : Driver {} badly exited with status {}".format(
+ self._driver_id, self._driver_status
+ )
)
def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
@@ -479,8 +486,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
match = re.search('(application[0-9_]+)', line)
if match:
self._yarn_application_id = match.groups()[0]
- self.log.info("Identified spark driver id: %s",
- self._yarn_application_id)
+ self.log.info("Identified spark driver id: %s", self._yarn_application_id)
# If we run Kubernetes cluster mode, we want to extract the driver pod id
# from the logs so we can kill the application when we stop it unexpectedly
@@ -488,8 +494,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
match = re.search(r'\s*pod name: ((.+?)-([a-z0-9]+)-driver)', line)
if match:
self._kubernetes_driver_pod = match.groups()[0]
- self.log.info("Identified spark driver pod: %s",
- self._kubernetes_driver_pod)
+ self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod)
# Store the Spark Exit code
match_exit_code = re.search(r'\s*[eE]xit code: (\d+)', line)
@@ -520,8 +525,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
# Check if the log line is about the driver status and extract the status.
if "driverState" in line:
- self._driver_status = line.split(' : ')[1] \
- .replace(',', '').replace('\"', '').strip()
+ self._driver_status = line.split(' : ')[1].replace(',', '').replace('\"', '').strip()
driver_found = True
self.log.debug("spark driver status log: %s", line)
@@ -566,8 +570,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
max_missed_job_status_reports = 10
# Keep polling as long as the driver is processing
- while self._driver_status not in ["FINISHED", "UNKNOWN",
- "KILLED", "FAILED", "ERROR"]:
+ while self._driver_status not in ["FINISHED", "UNKNOWN", "KILLED", "FAILED", "ERROR"]:
# Sleep for n seconds as we do not want to spam the cluster
time.sleep(self._status_poll_interval)
@@ -575,12 +578,13 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self.log.debug("polling status of spark driver with id %s", self._driver_id)
poll_drive_status_cmd = self._build_track_driver_status_command()
- status_process: Any = subprocess.Popen(poll_drive_status_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT,
- bufsize=-1,
- universal_newlines=True
- )
+ status_process: Any = subprocess.Popen(
+ poll_drive_status_cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ bufsize=-1,
+ universal_newlines=True,
+ )
self._process_spark_status_log(iter(status_process.stdout))
returncode = status_process.wait()
@@ -590,8 +594,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
missed_job_status_reports += 1
else:
raise AirflowException(
- "Failed to poll for the driver status {} times: returncode = {}"
- .format(max_missed_job_status_reports, returncode)
+ "Failed to poll for the driver status {} times: returncode = {}".format(
+ max_missed_job_status_reports, returncode
+ )
)
def _build_spark_driver_kill_command(self) -> List[str]:
@@ -604,9 +609,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
# the spark_home; otherwise assume that spark-submit is present in the path to
# the executing user
if self._connection['spark_home']:
- connection_cmd = [os.path.join(self._connection['spark_home'],
- 'bin',
- self._connection['spark_binary'])]
+ connection_cmd = [
+ os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary'])
+ ]
else:
connection_cmd = [self._connection['spark_binary']]
@@ -633,20 +638,18 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self.log.info('Killing driver %s on cluster', self._driver_id)
kill_cmd = self._build_spark_driver_kill_command()
- driver_kill = subprocess.Popen(kill_cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
+ driver_kill = subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- self.log.info("Spark driver %s killed with return code: %s",
- self._driver_id, driver_kill.wait())
... 101273 lines suppressed ...