You are viewing a plain-text version of this content. The link to the canonical version is not preserved in this plain-text rendering.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/10/09 09:33:16 UTC
[airflow] branch master updated: Strict type check for Microsoft (#11359)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new d2754ef Strict type check for Microsoft (#11359)
d2754ef is described below
commit d2754ef76958f8df4dcb6974e2cd2c1edb17935e
Author: Satyasheel <ml...@users.noreply.github.com>
AuthorDate: Fri Oct 9 10:31:53 2020 +0100
Strict type check for Microsoft (#11359)
---
.../microsoft/azure/log/wasb_task_handler.py | 22 +++++++----
.../microsoft/azure/operators/adls_list.py | 4 +-
airflow/providers/microsoft/azure/operators/adx.py | 6 +--
.../microsoft/azure/operators/azure_batch.py | 4 +-
.../azure/operators/azure_container_instances.py | 18 ++++-----
.../microsoft/azure/operators/azure_cosmos.py | 3 +-
.../microsoft/azure/operators/wasb_delete_blob.py | 4 +-
.../microsoft/azure/secrets/azure_key_vault.py | 4 +-
.../microsoft/azure/sensors/azure_cosmos.py | 3 +-
airflow/providers/microsoft/azure/sensors/wasb.py | 8 ++--
.../microsoft/azure/transfers/azure_blob_to_gcs.py | 2 +-
.../microsoft/azure/transfers/file_to_wasb.py | 4 +-
.../microsoft/azure/transfers/local_to_adls.py | 2 +-
.../azure/transfers/oracle_to_azure_data_lake.py | 4 +-
airflow/providers/microsoft/mssql/hooks/mssql.py | 12 +++---
.../providers/microsoft/mssql/operators/mssql.py | 14 ++++---
airflow/providers/microsoft/winrm/hooks/winrm.py | 43 +++++++++++-----------
.../providers/microsoft/winrm/operators/winrm.py | 28 ++++++++++----
18 files changed, 104 insertions(+), 81 deletions(-)
diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
index 292e34d..5e3dc40 100644
--- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py
+++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py
@@ -17,6 +17,7 @@
# under the License.
import os
import shutil
+from typing import Optional, Tuple, Dict
from azure.common import AzureHttpError
from cached_property import cached_property
@@ -34,8 +35,13 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
"""
def __init__(
- self, base_log_folder, wasb_log_folder, wasb_container, filename_template, delete_local_copy
- ):
+ self,
+ base_log_folder: str,
+ wasb_log_folder: str,
+ wasb_container: str,
+ filename_template: str,
+ delete_local_copy: str,
+ ) -> None:
super().__init__(base_log_folder, filename_template)
self.wasb_container = wasb_container
self.remote_base = wasb_log_folder
@@ -63,14 +69,14 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
remote_conn_id,
)
- def set_context(self, ti):
+ def set_context(self, ti) -> None:
super().set_context(ti)
# Local location and remote location is needed to open and
# upload local log file to Wasb remote storage.
self.log_relative_path = self._render_filename(ti, ti.try_number)
self.upload_on_close = not ti.raw
- def close(self):
+ def close(self) -> None:
"""
Close and upload local log file to remote storage Wasb.
"""
@@ -99,7 +105,7 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
# Mark closed so we don't double write if close is called twice
self.closed = True
- def _read(self, ti, try_number, metadata=None):
+ def _read(self, ti, try_number: str, metadata: Optional[str] = None) -> Tuple[str, Dict[str, bool]]:
"""
Read logs of given task instance and try_number from Wasb remote storage.
If failed, read the log from task instance host machine.
@@ -125,7 +131,7 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
else:
return super()._read(ti, try_number)
- def wasb_log_exists(self, remote_log_location):
+ def wasb_log_exists(self, remote_log_location: str) -> bool:
"""
Check if remote_log_location exists in remote storage
@@ -138,7 +144,7 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
pass
return False
- def wasb_read(self, remote_log_location, return_error=False):
+ def wasb_read(self, remote_log_location: str, return_error: bool = False):
"""
Returns the log found at the remote_log_location. Returns '' if no
logs are found or there is an error.
@@ -158,7 +164,7 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
if return_error:
return msg
- def wasb_write(self, log, remote_log_location, append=True):
+ def wasb_write(self, log: str, remote_log_location: str, append: bool = True) -> None:
"""
Writes the log to the remote_log_location. Fails silently if no hook
was created.
diff --git a/airflow/providers/microsoft/azure/operators/adls_list.py b/airflow/providers/microsoft/azure/operators/adls_list.py
index ad97557..b42f29f 100644
--- a/airflow/providers/microsoft/azure/operators/adls_list.py
+++ b/airflow/providers/microsoft/azure/operators/adls_list.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Sequence
+from typing import Sequence
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
@@ -58,7 +58,7 @@ class AzureDataLakeStorageListOperator(BaseOperator):
self.path = path
self.azure_data_lake_conn_id = azure_data_lake_conn_id
- def execute(self, context: Dict[Any, Any]) -> List:
+ def execute(self, context: dict) -> list:
hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id)
diff --git a/airflow/providers/microsoft/azure/operators/adx.py b/airflow/providers/microsoft/azure/operators/adx.py
index db1e485..e5a8c46 100644
--- a/airflow/providers/microsoft/azure/operators/adx.py
+++ b/airflow/providers/microsoft/azure/operators/adx.py
@@ -18,7 +18,7 @@
#
"""This module contains Azure Data Explorer operators"""
-from typing import Any, Dict, Optional
+from typing import Optional
from azure.kusto.data._models import KustoResultTable
@@ -52,7 +52,7 @@ class AzureDataExplorerQueryOperator(BaseOperator):
*,
query: str,
database: str,
- options: Optional[Dict] = None,
+ options: Optional[dict] = None,
azure_data_explorer_conn_id: str = 'azure_data_explorer_default',
**kwargs,
) -> None:
@@ -66,7 +66,7 @@ class AzureDataExplorerQueryOperator(BaseOperator):
"""Returns new instance of AzureDataExplorerHook"""
return AzureDataExplorerHook(self.azure_data_explorer_conn_id)
- def execute(self, context: Dict[Any, Any]) -> KustoResultTable:
+ def execute(self, context: dict) -> KustoResultTable:
"""
Run KQL Query on Azure Data Explorer (Kusto).
Returns `PrimaryResult` of Query v2 HTTP response contents
diff --git a/airflow/providers/microsoft/azure/operators/azure_batch.py b/airflow/providers/microsoft/azure/operators/azure_batch.py
index 433aa08..762547a 100644
--- a/airflow/providers/microsoft/azure/operators/azure_batch.py
+++ b/airflow/providers/microsoft/azure/operators/azure_batch.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
from azure.batch import models as batch_models
@@ -266,7 +266,7 @@ class AzureBatchOperator(BaseOperator):
"Some required parameters are missing.Please you must set " "all the required parameters. "
)
- def execute(self, context: Dict[Any, Any]) -> None:
+ def execute(self, context: dict) -> None:
self._check_inputs()
self.hook.connection.config.retry_policy = self.batch_max_retries
diff --git a/airflow/providers/microsoft/azure/operators/azure_container_instances.py b/airflow/providers/microsoft/azure/operators/azure_container_instances.py
index fd11c41..b0ff593 100644
--- a/airflow/providers/microsoft/azure/operators/azure_container_instances.py
+++ b/airflow/providers/microsoft/azure/operators/azure_container_instances.py
@@ -19,7 +19,7 @@
import re
from collections import namedtuple
from time import sleep
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, List, Optional, Sequence, Union, Dict
from azure.mgmt.containerinstance.models import (
Container,
@@ -44,9 +44,9 @@ Volume = namedtuple(
)
-DEFAULT_ENVIRONMENT_VARIABLES = {} # type: Dict[str, str]
-DEFAULT_SECURED_VARIABLES = [] # type: Sequence[str]
-DEFAULT_VOLUMES = [] # type: Sequence[Volume]
+DEFAULT_ENVIRONMENT_VARIABLES: Dict[str, str] = {}
+DEFAULT_SECURED_VARIABLES: Sequence[str] = []
+DEFAULT_VOLUMES: Sequence[Volume] = []
DEFAULT_MEMORY_IN_GB = 2.0
DEFAULT_CPU = 1.0
@@ -136,9 +136,9 @@ class AzureContainerInstancesOperator(BaseOperator):
name: str,
image: str,
region: str,
- environment_variables: Optional[Dict[Any, Any]] = None,
+ environment_variables: Optional[dict] = None,
secured_variables: Optional[str] = None,
- volumes: Optional[List[Any]] = None,
+ volumes: Optional[list] = None,
memory_in_gb: Optional[Any] = None,
cpu: Optional[Any] = None,
gpu: Optional[Any] = None,
@@ -168,7 +168,7 @@ class AzureContainerInstancesOperator(BaseOperator):
self._ci_hook: Any = None
self.tags = tags
- def execute(self, context: Dict[Any, Any]) -> int:
+ def execute(self, context: dict) -> int:
# Check name again in case it was templated.
self._check_name(self.name)
@@ -181,7 +181,7 @@ class AzureContainerInstancesOperator(BaseOperator):
if self.registry_conn_id:
registry_hook = AzureContainerRegistryHook(self.registry_conn_id)
- image_registry_credentials: Optional[List[Any]] = [
+ image_registry_credentials: Optional[list] = [
registry_hook.connection,
]
else:
@@ -327,7 +327,7 @@ class AzureContainerInstancesOperator(BaseOperator):
sleep(1)
- def _log_last(self, logs: Optional[List[Any]], last_line_logged: Any) -> Optional[Any]:
+ def _log_last(self, logs: Optional[list], last_line_logged: Any) -> Optional[Any]:
if logs:
# determine the last line which was logged before
last_line_index = 0
diff --git a/airflow/providers/microsoft/azure/operators/azure_cosmos.py b/airflow/providers/microsoft/azure/operators/azure_cosmos.py
index 23d5fee..df22c96 100644
--- a/airflow/providers/microsoft/azure/operators/azure_cosmos.py
+++ b/airflow/providers/microsoft/azure/operators/azure_cosmos.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
@@ -56,7 +55,7 @@ class AzureCosmosInsertDocumentOperator(BaseOperator):
self.document = document
self.azure_cosmos_conn_id = azure_cosmos_conn_id
- def execute(self, context: Dict[Any, Any]) -> None:
+ def execute(self, context: dict) -> None:
# Create the hook
hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id)
diff --git a/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py b/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py
index 5e5d6f2..be4f3cf 100644
--- a/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py
+++ b/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Any, Dict
+from typing import Any
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
@@ -66,7 +66,7 @@ class WasbDeleteBlobOperator(BaseOperator):
self.is_prefix = is_prefix
self.ignore_if_missing = ignore_if_missing
- def execute(self, context: Dict[Any, Any]) -> None:
+ def execute(self, context: dict) -> None:
self.log.info('Deleting blob: %s\nin wasb://%s', self.blob_name, self.container_name)
hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
diff --git a/airflow/providers/microsoft/azure/secrets/azure_key_vault.py b/airflow/providers/microsoft/azure/secrets/azure_key_vault.py
index 34ccaf5..9d98959 100644
--- a/airflow/providers/microsoft/azure/secrets/azure_key_vault.py
+++ b/airflow/providers/microsoft/azure/secrets/azure_key_vault.py
@@ -62,7 +62,7 @@ class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin):
vault_url: str = '',
sep: str = '-',
**kwargs,
- ):
+ ) -> None:
super().__init__()
self.vault_url = vault_url
self.connections_prefix = connections_prefix.rstrip(sep)
@@ -72,7 +72,7 @@ class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin):
self.kwargs = kwargs
@cached_property
- def client(self):
+ def client(self) -> SecretClient:
"""
Create a Azure Key Vault client.
"""
diff --git a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py b/airflow/providers/microsoft/azure/sensors/azure_cosmos.py
index 1b7eab2..f833ad0 100644
--- a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py
+++ b/airflow/providers/microsoft/azure/sensors/azure_cosmos.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
@@ -61,7 +60,7 @@ class AzureCosmosDocumentSensor(BaseSensorOperator):
self.collection_name = collection_name
self.document_id = document_id
- def poke(self, context: Dict[Any, Any]) -> bool:
+ def poke(self, context: dict) -> bool:
self.log.info("*** Intering poke")
hook = AzureCosmosDBHook(self.azure_cosmos_conn_id)
return hook.get_document(self.document_id, self.database_name, self.collection_name) is not None
diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py
index 57d016b..0685059 100644
--- a/airflow/providers/microsoft/azure/sensors/wasb.py
+++ b/airflow/providers/microsoft/azure/sensors/wasb.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Any, Dict, Optional
+from typing import Optional
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
@@ -49,7 +49,7 @@ class WasbBlobSensor(BaseSensorOperator):
wasb_conn_id: str = 'wasb_default',
check_options: Optional[dict] = None,
**kwargs,
- ):
+ ) -> None:
super().__init__(**kwargs)
if check_options is None:
check_options = {}
@@ -58,7 +58,7 @@ class WasbBlobSensor(BaseSensorOperator):
self.blob_name = blob_name
self.check_options = check_options
- def poke(self, context: Dict[Any, Any]):
+ def poke(self, context: dict):
self.log.info('Poking for blob: %s\nin wasb://%s', self.blob_name, self.container_name)
hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
return hook.check_for_blob(self.container_name, self.blob_name, **self.check_options)
@@ -99,7 +99,7 @@ class WasbPrefixSensor(BaseSensorOperator):
self.prefix = prefix
self.check_options = check_options
- def poke(self, context: Dict[Any, Any]) -> bool:
+ def poke(self, context: dict) -> bool:
self.log.info('Poking for prefix: %s in wasb://%s', self.prefix, self.container_name)
hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
return hook.check_for_prefix(self.container_name, self.prefix, **self.check_options)
diff --git a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py
index 1f407dd..a33a922 100644
--- a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py
+++ b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py
@@ -105,7 +105,7 @@ class AzureBlobStorageToGCSOperator(BaseOperator):
"filename",
)
- def execute(self, context):
+ def execute(self, context: dict) -> str:
azure_hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
diff --git a/airflow/providers/microsoft/azure/transfers/file_to_wasb.py b/airflow/providers/microsoft/azure/transfers/file_to_wasb.py
index 0fb08b7..c099faa 100644
--- a/airflow/providers/microsoft/azure/transfers/file_to_wasb.py
+++ b/airflow/providers/microsoft/azure/transfers/file_to_wasb.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Any, Dict, Optional
+from typing import Optional
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
@@ -62,7 +62,7 @@ class FileToWasbOperator(BaseOperator):
self.wasb_conn_id = wasb_conn_id
self.load_options = load_options
- def execute(self, context: Dict[Any, Any]) -> None:
+ def execute(self, context: dict) -> None:
"""Upload a file to Azure Blob Storage."""
hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
self.log.info(
diff --git a/airflow/providers/microsoft/azure/transfers/local_to_adls.py b/airflow/providers/microsoft/azure/transfers/local_to_adls.py
index 98b2749..755a171 100644
--- a/airflow/providers/microsoft/azure/transfers/local_to_adls.py
+++ b/airflow/providers/microsoft/azure/transfers/local_to_adls.py
@@ -85,7 +85,7 @@ class LocalToAzureDataLakeStorageOperator(BaseOperator):
self.extra_upload_options = extra_upload_options
self.azure_data_lake_conn_id = azure_data_lake_conn_id
- def execute(self, context: Dict[Any, Any]) -> None:
+ def execute(self, context: dict) -> None:
if '**' in self.local_path:
raise AirflowException("Recursive glob patterns using `**` are not supported")
if not self.extra_upload_options:
diff --git a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py
index 153173a..5071dbf 100644
--- a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py
+++ b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py
@@ -18,7 +18,7 @@
import os
from tempfile import TemporaryDirectory
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union
import unicodecsv as csv
@@ -103,7 +103,7 @@ class OracleToAzureDataLakeOperator(BaseOperator):
csv_writer.writerows(cursor)
csvfile.flush()
- def execute(self, context: Dict[Any, Any]) -> None:
+ def execute(self, context: dict) -> None:
oracle_hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
azure_data_lake_hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id)
diff --git a/airflow/providers/microsoft/mssql/hooks/mssql.py b/airflow/providers/microsoft/mssql/hooks/mssql.py
index 4bee8ab..24331707 100644
--- a/airflow/providers/microsoft/mssql/hooks/mssql.py
+++ b/airflow/providers/microsoft/mssql/hooks/mssql.py
@@ -54,7 +54,7 @@ class MsSqlHook(DbApiHook):
default_conn_name = 'mssql_default'
supports_autocommit = True
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args, **kwargs) -> None:
warnings.warn(
(
"This class is deprecated and will be removed in Airflow 2.0.\n"
@@ -67,11 +67,13 @@ class MsSqlHook(DbApiHook):
super().__init__(*args, **kwargs)
self.schema = kwargs.pop("schema", None)
- def get_conn(self):
+ def get_conn(self) -> pymssql.connect:
"""
Returns a mssql connection object
"""
- conn = self.get_connection(self.mssql_conn_id) # pylint: disable=no-member
+ conn = self.get_connection(
+ self.mssql_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
+ )
# pylint: disable=c-extension-no-member
conn = pymssql.connect(
server=conn.host,
@@ -82,8 +84,8 @@ class MsSqlHook(DbApiHook):
)
return conn
- def set_autocommit(self, conn, autocommit):
+ def set_autocommit(self, conn: pymssql.connect, autocommit: bool) -> None:
conn.autocommit(autocommit)
- def get_autocommit(self, conn):
+ def get_autocommit(self, conn: pymssql.connect):
return conn.autocommit_state
diff --git a/airflow/providers/microsoft/mssql/operators/mssql.py b/airflow/providers/microsoft/mssql/operators/mssql.py
index 25d6815..2341b75 100644
--- a/airflow/providers/microsoft/mssql/operators/mssql.py
+++ b/airflow/providers/microsoft/mssql/operators/mssql.py
@@ -68,9 +68,9 @@ class MsSqlOperator(BaseOperator):
self.parameters = parameters
self.autocommit = autocommit
self.database = database
- self._hook = None
+ self._hook: Optional[Union[MsSqlHook, OdbcHook]] = None
- def get_hook(self):
+ def get_hook(self) -> Optional[Union[MsSqlHook, OdbcHook]]:
"""
Will retrieve hook as determined by Connection.
@@ -81,13 +81,15 @@ class MsSqlOperator(BaseOperator):
if not self._hook:
conn = MsSqlHook.get_connection(conn_id=self.mssql_conn_id)
try:
- self._hook: Union[MsSqlHook, OdbcHook] = conn.get_hook()
- self._hook.schema = self.database
+ self._hook = conn.get_hook()
+ self._hook.schema = self.database # type: ignore[union-attr]
except AirflowException:
self._hook = MsSqlHook(mssql_conn_id=self.mssql_conn_id, schema=self.database)
return self._hook
- def execute(self, context):
+ def execute(self, context: dict) -> None:
self.log.info('Executing: %s', self.sql)
hook = self.get_hook()
- hook.run(sql=self.sql, autocommit=self.autocommit, parameters=self.parameters)
+ hook.run( # type: ignore[union-attr]
+ sql=self.sql, autocommit=self.autocommit, parameters=self.parameters
+ )
diff --git a/airflow/providers/microsoft/winrm/hooks/winrm.py b/airflow/providers/microsoft/winrm/hooks/winrm.py
index ad6e5ca..4adcd28 100644
--- a/airflow/providers/microsoft/winrm/hooks/winrm.py
+++ b/airflow/providers/microsoft/winrm/hooks/winrm.py
@@ -18,6 +18,7 @@
#
"""Hook for winrm remote execution."""
import getpass
+from typing import Optional
from winrm.protocol import Protocol
@@ -90,27 +91,27 @@ class WinRMHook(BaseHook):
def __init__(
self,
- ssh_conn_id=None,
- endpoint=None,
- remote_host=None,
- remote_port=5985,
- transport='plaintext',
- username=None,
- password=None,
- service='HTTP',
- keytab=None,
- ca_trust_path=None,
- cert_pem=None,
- cert_key_pem=None,
- server_cert_validation='validate',
- kerberos_delegation=False,
- read_timeout_sec=30,
- operation_timeout_sec=20,
- kerberos_hostname_override=None,
- message_encryption='auto',
- credssp_disable_tlsv1_2=False,
- send_cbt=True,
- ):
+ ssh_conn_id: Optional[str] = None,
+ endpoint: Optional[str] = None,
+ remote_host: Optional[str] = None,
+ remote_port: int = 5985,
+ transport: str = 'plaintext',
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ service: str = 'HTTP',
+ keytab: Optional[str] = None,
+ ca_trust_path: Optional[str] = None,
+ cert_pem: Optional[str] = None,
+ cert_key_pem: Optional[str] = None,
+ server_cert_validation: str = 'validate',
+ kerberos_delegation: bool = False,
+ read_timeout_sec: int = 30,
+ operation_timeout_sec: int = 20,
+ kerberos_hostname_override: Optional[str] = None,
+ message_encryption: Optional[str] = 'auto',
+ credssp_disable_tlsv1_2: bool = False,
+ send_cbt: bool = True,
+ ) -> None:
super().__init__()
self.ssh_conn_id = ssh_conn_id
self.endpoint = endpoint
diff --git a/airflow/providers/microsoft/winrm/operators/winrm.py b/airflow/providers/microsoft/winrm/operators/winrm.py
index a0c2c76..8e4b507 100644
--- a/airflow/providers/microsoft/winrm/operators/winrm.py
+++ b/airflow/providers/microsoft/winrm/operators/winrm.py
@@ -18,6 +18,7 @@
import logging
from base64 import b64encode
+from typing import Optional, Union
from winrm.exceptions import WinRMOperationTimeoutError
@@ -53,8 +54,15 @@ class WinRMOperator(BaseOperator):
@apply_defaults
def __init__(
- self, *, winrm_hook=None, ssh_conn_id=None, remote_host=None, command=None, timeout=10, **kwargs
- ):
+ self,
+ *,
+ winrm_hook: Optional[WinRMHook] = None,
+ ssh_conn_id: Optional[str] = None,
+ remote_host: Optional[str] = None,
+ command: Optional[str] = None,
+ timeout: int = 10,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.winrm_hook = winrm_hook
self.ssh_conn_id = ssh_conn_id
@@ -62,7 +70,7 @@ class WinRMOperator(BaseOperator):
self.command = command
self.timeout = timeout
- def execute(self, context):
+ def execute(self, context: dict) -> Union[list, str]:
if self.ssh_conn_id and not self.winrm_hook:
self.log.info("Hook not found, creating...")
self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id)
@@ -81,7 +89,9 @@ class WinRMOperator(BaseOperator):
# pylint: disable=too-many-nested-blocks
try:
self.log.info("Running command: '%s'...", self.command)
- command_id = self.winrm_hook.winrm_protocol.run_command(winrm_client, self.command)
+ command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined]
+ winrm_client, self.command
+ )
# See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
stdout_buffer = []
@@ -95,7 +105,9 @@ class WinRMOperator(BaseOperator):
stderr,
return_code,
command_done,
- ) = self.winrm_hook.winrm_protocol._raw_get_command_output(winrm_client, command_id)
+ ) = self.winrm_hook.winrm_protocol._raw_get_command_output( # type: ignore[attr-defined]
+ winrm_client, command_id
+ )
# Only buffer stdout if we need to so that we minimize memory usage.
if self.do_xcom_push:
@@ -111,8 +123,10 @@ class WinRMOperator(BaseOperator):
# long-running process, just silently retry
pass
- self.winrm_hook.winrm_protocol.cleanup_command(winrm_client, command_id)
- self.winrm_hook.winrm_protocol.close_shell(winrm_client)
+ self.winrm_hook.winrm_protocol.cleanup_command( # type: ignore[attr-defined]
+ winrm_client, command_id
+ )
+ self.winrm_hook.winrm_protocol.close_shell(winrm_client) # type: ignore[attr-defined]
except Exception as e:
raise AirflowException("WinRM operator error: {0}".format(str(e)))