This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push: new d2754ef Strict type check for Microsoft (#11359) d2754ef is described below commit d2754ef76958f8df4dcb6974e2cd2c1edb17935e Author: Satyasheel <mlgr...@users.noreply.github.com> AuthorDate: Fri Oct 9 10:31:53 2020 +0100 Strict type check for Microsoft (#11359) --- .../microsoft/azure/log/wasb_task_handler.py | 22 +++++++---- .../microsoft/azure/operators/adls_list.py | 4 +- airflow/providers/microsoft/azure/operators/adx.py | 6 +-- .../microsoft/azure/operators/azure_batch.py | 4 +- .../azure/operators/azure_container_instances.py | 18 ++++----- .../microsoft/azure/operators/azure_cosmos.py | 3 +- .../microsoft/azure/operators/wasb_delete_blob.py | 4 +- .../microsoft/azure/secrets/azure_key_vault.py | 4 +- .../microsoft/azure/sensors/azure_cosmos.py | 3 +- airflow/providers/microsoft/azure/sensors/wasb.py | 8 ++-- .../microsoft/azure/transfers/azure_blob_to_gcs.py | 2 +- .../microsoft/azure/transfers/file_to_wasb.py | 4 +- .../microsoft/azure/transfers/local_to_adls.py | 2 +- .../azure/transfers/oracle_to_azure_data_lake.py | 4 +- airflow/providers/microsoft/mssql/hooks/mssql.py | 12 +++--- .../providers/microsoft/mssql/operators/mssql.py | 14 ++++--- airflow/providers/microsoft/winrm/hooks/winrm.py | 43 +++++++++++----------- .../providers/microsoft/winrm/operators/winrm.py | 28 ++++++++++---- 18 files changed, 104 insertions(+), 81 deletions(-) diff --git a/airflow/providers/microsoft/azure/log/wasb_task_handler.py b/airflow/providers/microsoft/azure/log/wasb_task_handler.py index 292e34d..5e3dc40 100644 --- a/airflow/providers/microsoft/azure/log/wasb_task_handler.py +++ b/airflow/providers/microsoft/azure/log/wasb_task_handler.py @@ -17,6 +17,7 @@ # under the License. import os import shutil +from typing import Optional, Tuple, Dict from azure.common import AzureHttpError from cached_property import cached_property @@ -34,8 +35,13 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin): """ def __init__( - self, base_log_folder, wasb_log_folder, wasb_container, filename_template, delete_local_copy - ): + self, + base_log_folder: str, + wasb_log_folder: str, + wasb_container: str, + filename_template: str, + delete_local_copy: str, + ) -> None: super().__init__(base_log_folder, filename_template) self.wasb_container = wasb_container self.remote_base = wasb_log_folder @@ -63,14 +69,14 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin): remote_conn_id, ) - def set_context(self, ti): + def set_context(self, ti) -> None: super().set_context(ti) # Local location and remote location is needed to open and # upload local log file to Wasb remote storage. self.log_relative_path = self._render_filename(ti, ti.try_number) self.upload_on_close = not ti.raw - def close(self): + def close(self) -> None: """ Close and upload local log file to remote storage Wasb. """ @@ -99,7 +105,7 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin): # Mark closed so we don't double write if close is called twice self.closed = True - def _read(self, ti, try_number, metadata=None): + def _read(self, ti, try_number: str, metadata: Optional[str] = None) -> Tuple[str, Dict[str, bool]]: """ Read logs of given task instance and try_number from Wasb remote storage. If failed, read the log from task instance host machine. @@ -125,7 +131,7 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin): else: return super()._read(ti, try_number) - def wasb_log_exists(self, remote_log_location): + def wasb_log_exists(self, remote_log_location: str) -> bool: """ Check if remote_log_location exists in remote storage @@ -138,7 +144,7 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin): pass return False - def wasb_read(self, remote_log_location, return_error=False): + def wasb_read(self, remote_log_location: str, return_error: bool = False): """ Returns the log found at the remote_log_location. Returns '' if no logs are found or there is an error. @@ -158,7 +164,7 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin): if return_error: return msg - def wasb_write(self, log, remote_log_location, append=True): + def wasb_write(self, log: str, remote_log_location: str, append: bool = True) -> None: """ Writes the log to the remote_log_location. Fails silently if no hook was created. diff --git a/airflow/providers/microsoft/azure/operators/adls_list.py b/airflow/providers/microsoft/azure/operators/adls_list.py index ad97557..b42f29f 100644 --- a/airflow/providers/microsoft/azure/operators/adls_list.py +++ b/airflow/providers/microsoft/azure/operators/adls_list.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Sequence +from typing import Sequence from airflow.models import BaseOperator from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook @@ -58,7 +58,7 @@ class AzureDataLakeStorageListOperator(BaseOperator): self.path = path self.azure_data_lake_conn_id = azure_data_lake_conn_id - def execute(self, context: Dict[Any, Any]) -> List: + def execute(self, context: dict) -> list: hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) diff --git a/airflow/providers/microsoft/azure/operators/adx.py b/airflow/providers/microsoft/azure/operators/adx.py index db1e485..e5a8c46 100644 --- a/airflow/providers/microsoft/azure/operators/adx.py +++ b/airflow/providers/microsoft/azure/operators/adx.py @@ -18,7 +18,7 @@ # """This module contains Azure Data Explorer operators""" -from typing import Any, Dict, Optional +from typing import Optional from azure.kusto.data._models import KustoResultTable @@ -52,7 +52,7 @@ class AzureDataExplorerQueryOperator(BaseOperator): *, query: str, database: str, - options: Optional[Dict] = None, + options: Optional[dict] = None, azure_data_explorer_conn_id: str = 'azure_data_explorer_default', **kwargs, ) -> None: @@ -66,7 +66,7 @@ class AzureDataExplorerQueryOperator(BaseOperator): """Returns new instance of AzureDataExplorerHook""" return AzureDataExplorerHook(self.azure_data_explorer_conn_id) - def execute(self, context: Dict[Any, Any]) -> KustoResultTable: + def execute(self, context: dict) -> KustoResultTable: """ Run KQL Query on Azure Data Explorer (Kusto). Returns `PrimaryResult` of Query v2 HTTP response contents diff --git a/airflow/providers/microsoft/azure/operators/azure_batch.py b/airflow/providers/microsoft/azure/operators/azure_batch.py index 433aa08..762547a 100644 --- a/airflow/providers/microsoft/azure/operators/azure_batch.py +++ b/airflow/providers/microsoft/azure/operators/azure_batch.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. # -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from azure.batch import models as batch_models @@ -266,7 +266,7 @@ class AzureBatchOperator(BaseOperator): "Some required parameters are missing.Please you must set " "all the required parameters. " ) - def execute(self, context: Dict[Any, Any]) -> None: + def execute(self, context: dict) -> None: self._check_inputs() self.hook.connection.config.retry_policy = self.batch_max_retries diff --git a/airflow/providers/microsoft/azure/operators/azure_container_instances.py b/airflow/providers/microsoft/azure/operators/azure_container_instances.py index fd11c41..b0ff593 100644 --- a/airflow/providers/microsoft/azure/operators/azure_container_instances.py +++ b/airflow/providers/microsoft/azure/operators/azure_container_instances.py @@ -19,7 +19,7 @@ import re from collections import namedtuple from time import sleep -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Union, Dict from azure.mgmt.containerinstance.models import ( Container, @@ -44,9 +44,9 @@ Volume = namedtuple( ) -DEFAULT_ENVIRONMENT_VARIABLES = {} # type: Dict[str, str] -DEFAULT_SECURED_VARIABLES = [] # type: Sequence[str] -DEFAULT_VOLUMES = [] # type: Sequence[Volume] +DEFAULT_ENVIRONMENT_VARIABLES: Dict[str, str] = {} +DEFAULT_SECURED_VARIABLES: Sequence[str] = [] +DEFAULT_VOLUMES: Sequence[Volume] = [] DEFAULT_MEMORY_IN_GB = 2.0 DEFAULT_CPU = 1.0 @@ -136,9 +136,9 @@ class AzureContainerInstancesOperator(BaseOperator): name: str, image: str, region: str, - environment_variables: Optional[Dict[Any, Any]] = None, + environment_variables: Optional[dict] = None, secured_variables: Optional[str] = None, - volumes: Optional[List[Any]] = None, + volumes: Optional[list] = None, memory_in_gb: Optional[Any] = None, cpu: Optional[Any] = None, gpu: Optional[Any] = None, @@ -168,7 +168,7 @@ class AzureContainerInstancesOperator(BaseOperator): self._ci_hook: Any = None self.tags = tags - def execute(self, context: Dict[Any, Any]) -> int: + def execute(self, context: dict) -> int: # Check name again in case it was templated. self._check_name(self.name) @@ -181,7 +181,7 @@ class AzureContainerInstancesOperator(BaseOperator): if self.registry_conn_id: registry_hook = AzureContainerRegistryHook(self.registry_conn_id) - image_registry_credentials: Optional[List[Any]] = [ + image_registry_credentials: Optional[list] = [ registry_hook.connection, ] else: @@ -327,7 +327,7 @@ class AzureContainerInstancesOperator(BaseOperator): sleep(1) - def _log_last(self, logs: Optional[List[Any]], last_line_logged: Any) -> Optional[Any]: + def _log_last(self, logs: Optional[list], last_line_logged: Any) -> Optional[Any]: if logs: # determine the last line which was logged before last_line_index = 0 diff --git a/airflow/providers/microsoft/azure/operators/azure_cosmos.py b/airflow/providers/microsoft/azure/operators/azure_cosmos.py index 23d5fee..df22c96 100644 --- a/airflow/providers/microsoft/azure/operators/azure_cosmos.py +++ b/airflow/providers/microsoft/azure/operators/azure_cosmos.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict from airflow.models import BaseOperator from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook @@ -56,7 +55,7 @@ class AzureCosmosInsertDocumentOperator(BaseOperator): self.document = document self.azure_cosmos_conn_id = azure_cosmos_conn_id - def execute(self, context: Dict[Any, Any]) -> None: + def execute(self, context: dict) -> None: # Create the hook hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id) diff --git a/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py b/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py index 5e5d6f2..be4f3cf 100644 --- a/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py +++ b/airflow/providers/microsoft/azure/operators/wasb_delete_blob.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. # -from typing import Any, Dict +from typing import Any from airflow.models import BaseOperator from airflow.providers.microsoft.azure.hooks.wasb import WasbHook @@ -66,7 +66,7 @@ class WasbDeleteBlobOperator(BaseOperator): self.is_prefix = is_prefix self.ignore_if_missing = ignore_if_missing - def execute(self, context: Dict[Any, Any]) -> None: + def execute(self, context: dict) -> None: self.log.info('Deleting blob: %s\nin wasb://%s', self.blob_name, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) diff --git a/airflow/providers/microsoft/azure/secrets/azure_key_vault.py b/airflow/providers/microsoft/azure/secrets/azure_key_vault.py index 34ccaf5..9d98959 100644 --- a/airflow/providers/microsoft/azure/secrets/azure_key_vault.py +++ b/airflow/providers/microsoft/azure/secrets/azure_key_vault.py @@ -62,7 +62,7 @@ class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin): vault_url: str = '', sep: str = '-', **kwargs, - ): + ) -> None: super().__init__() self.vault_url = vault_url self.connections_prefix = connections_prefix.rstrip(sep) @@ -72,7 +72,7 @@ class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin): self.kwargs = kwargs @cached_property - def client(self): + def client(self) -> SecretClient: """ Create a Azure Key Vault client. """ diff --git a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py b/airflow/providers/microsoft/azure/sensors/azure_cosmos.py index 1b7eab2..f833ad0 100644 --- a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py +++ b/airflow/providers/microsoft/azure/sensors/azure_cosmos.py @@ -15,7 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -61,7 +60,7 @@ class AzureCosmosDocumentSensor(BaseSensorOperator): self.collection_name = collection_name self.document_id = document_id - def poke(self, context: Dict[Any, Any]) -> bool: + def poke(self, context: dict) -> bool: self.log.info("*** Intering poke") hook = AzureCosmosDBHook(self.azure_cosmos_conn_id) return hook.get_document(self.document_id, self.database_name, self.collection_name) is not None diff --git a/airflow/providers/microsoft/azure/sensors/wasb.py b/airflow/providers/microsoft/azure/sensors/wasb.py index 57d016b..0685059 100644 --- a/airflow/providers/microsoft/azure/sensors/wasb.py +++ b/airflow/providers/microsoft/azure/sensors/wasb.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. # -from typing import Any, Dict, Optional +from typing import Optional from airflow.providers.microsoft.azure.hooks.wasb import WasbHook from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -49,7 +49,7 @@ class WasbBlobSensor(BaseSensorOperator): wasb_conn_id: str = 'wasb_default', check_options: Optional[dict] = None, **kwargs, - ): + ) -> None: super().__init__(**kwargs) if check_options is None: check_options = {} @@ -58,7 +58,7 @@ class WasbBlobSensor(BaseSensorOperator): self.blob_name = blob_name self.check_options = check_options - def poke(self, context: Dict[Any, Any]): + def poke(self, context: dict): self.log.info('Poking for blob: %s\nin wasb://%s', self.blob_name, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) return hook.check_for_blob(self.container_name, self.blob_name, **self.check_options) @@ -99,7 +99,7 @@ class WasbPrefixSensor(BaseSensorOperator): self.prefix = prefix self.check_options = check_options - def poke(self, context: Dict[Any, Any]) -> bool: + def poke(self, context: dict) -> bool: self.log.info('Poking for prefix: %s in wasb://%s', self.prefix, self.container_name) hook = WasbHook(wasb_conn_id=self.wasb_conn_id) return hook.check_for_prefix(self.container_name, self.prefix, **self.check_options) diff --git a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py index 1f407dd..a33a922 100644 --- a/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py +++ b/airflow/providers/microsoft/azure/transfers/azure_blob_to_gcs.py @@ -105,7 +105,7 @@ class AzureBlobStorageToGCSOperator(BaseOperator): "filename", ) - def execute(self, context): + def execute(self, context: dict) -> str: azure_hook = WasbHook(wasb_conn_id=self.wasb_conn_id) gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, diff --git a/airflow/providers/microsoft/azure/transfers/file_to_wasb.py b/airflow/providers/microsoft/azure/transfers/file_to_wasb.py index 0fb08b7..c099faa 100644 --- a/airflow/providers/microsoft/azure/transfers/file_to_wasb.py +++ b/airflow/providers/microsoft/azure/transfers/file_to_wasb.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. # -from typing import Any, Dict, Optional +from typing import Optional from airflow.models import BaseOperator from airflow.providers.microsoft.azure.hooks.wasb import WasbHook @@ -62,7 +62,7 @@ class FileToWasbOperator(BaseOperator): self.wasb_conn_id = wasb_conn_id self.load_options = load_options - def execute(self, context: Dict[Any, Any]) -> None: + def execute(self, context: dict) -> None: """Upload a file to Azure Blob Storage.""" hook = WasbHook(wasb_conn_id=self.wasb_conn_id) self.log.info( diff --git a/airflow/providers/microsoft/azure/transfers/local_to_adls.py b/airflow/providers/microsoft/azure/transfers/local_to_adls.py index 98b2749..755a171 100644 --- a/airflow/providers/microsoft/azure/transfers/local_to_adls.py +++ b/airflow/providers/microsoft/azure/transfers/local_to_adls.py @@ -85,7 +85,7 @@ class LocalToAzureDataLakeStorageOperator(BaseOperator): self.extra_upload_options = extra_upload_options self.azure_data_lake_conn_id = azure_data_lake_conn_id - def execute(self, context: Dict[Any, Any]) -> None: + def execute(self, context: dict) -> None: if '**' in self.local_path: raise AirflowException("Recursive glob patterns using `**` are not supported") if not self.extra_upload_options: diff --git a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py index 153173a..5071dbf 100644 --- a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py +++ b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py @@ -18,7 +18,7 @@ import os from tempfile import TemporaryDirectory -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import unicodecsv as csv @@ -103,7 +103,7 @@ class OracleToAzureDataLakeOperator(BaseOperator): csv_writer.writerows(cursor) csvfile.flush() - def execute(self, context: Dict[Any, Any]) -> None: + def execute(self, context: dict) -> None: oracle_hook = OracleHook(oracle_conn_id=self.oracle_conn_id) azure_data_lake_hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id) diff --git a/airflow/providers/microsoft/mssql/hooks/mssql.py b/airflow/providers/microsoft/mssql/hooks/mssql.py index 4bee8ab..24331707 100644 --- a/airflow/providers/microsoft/mssql/hooks/mssql.py +++ b/airflow/providers/microsoft/mssql/hooks/mssql.py @@ -54,7 +54,7 @@ class MsSqlHook(DbApiHook): default_conn_name = 'mssql_default' supports_autocommit = True - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: warnings.warn( ( "This class is deprecated and will be removed in Airflow 2.0.\n" @@ -67,11 +67,13 @@ class MsSqlHook(DbApiHook): super().__init__(*args, **kwargs) self.schema = kwargs.pop("schema", None) - def get_conn(self): + def get_conn(self) -> pymssql.connect: """ Returns a mssql connection object """ - conn = self.get_connection(self.mssql_conn_id) # pylint: disable=no-member + conn = self.get_connection( + self.mssql_conn_id # type: ignore[attr-defined] # pylint: disable=no-member + ) # pylint: disable=c-extension-no-member conn = pymssql.connect( server=conn.host, @@ -82,8 +84,8 @@ class MsSqlHook(DbApiHook): ) return conn - def set_autocommit(self, conn, autocommit): + def set_autocommit(self, conn: pymssql.connect, autocommit: bool) -> None: conn.autocommit(autocommit) - def get_autocommit(self, conn): + def get_autocommit(self, conn: pymssql.connect): return conn.autocommit_state diff --git a/airflow/providers/microsoft/mssql/operators/mssql.py b/airflow/providers/microsoft/mssql/operators/mssql.py index 25d6815..2341b75 100644 --- a/airflow/providers/microsoft/mssql/operators/mssql.py +++ b/airflow/providers/microsoft/mssql/operators/mssql.py @@ -68,9 +68,9 @@ class MsSqlOperator(BaseOperator): self.parameters = parameters self.autocommit = autocommit self.database = database - self._hook = None + self._hook: Optional[Union[MsSqlHook, OdbcHook]] = None - def get_hook(self): + def get_hook(self) -> Optional[Union[MsSqlHook, OdbcHook]]: """ Will retrieve hook as determined by Connection. @@ -81,13 +81,15 @@ class MsSqlOperator(BaseOperator): if not self._hook: conn = MsSqlHook.get_connection(conn_id=self.mssql_conn_id) try: - self._hook: Union[MsSqlHook, OdbcHook] = conn.get_hook() - self._hook.schema = self.database + self._hook = conn.get_hook() + self._hook.schema = self.database # type: ignore[union-attr] except AirflowException: self._hook = MsSqlHook(mssql_conn_id=self.mssql_conn_id, schema=self.database) return self._hook - def execute(self, context): + def execute(self, context: dict) -> None: self.log.info('Executing: %s', self.sql) hook = self.get_hook() - hook.run(sql=self.sql, autocommit=self.autocommit, parameters=self.parameters) + hook.run( # type: ignore[union-attr] + sql=self.sql, autocommit=self.autocommit, parameters=self.parameters + ) diff --git a/airflow/providers/microsoft/winrm/hooks/winrm.py b/airflow/providers/microsoft/winrm/hooks/winrm.py index ad6e5ca..4adcd28 100644 --- a/airflow/providers/microsoft/winrm/hooks/winrm.py +++ b/airflow/providers/microsoft/winrm/hooks/winrm.py @@ -18,6 +18,7 @@ # """Hook for winrm remote execution.""" import getpass +from typing import Optional from winrm.protocol import Protocol @@ -90,27 +91,27 @@ class WinRMHook(BaseHook): def __init__( self, - ssh_conn_id=None, - endpoint=None, - remote_host=None, - remote_port=5985, - transport='plaintext', - username=None, - password=None, - service='HTTP', - keytab=None, - ca_trust_path=None, - cert_pem=None, - cert_key_pem=None, - server_cert_validation='validate', - kerberos_delegation=False, - read_timeout_sec=30, - operation_timeout_sec=20, - kerberos_hostname_override=None, - message_encryption='auto', - credssp_disable_tlsv1_2=False, - send_cbt=True, - ): + ssh_conn_id: Optional[str] = None, + endpoint: Optional[str] = None, + remote_host: Optional[str] = None, + remote_port: int = 5985, + transport: str = 'plaintext', + username: Optional[str] = None, + password: Optional[str] = None, + service: str = 'HTTP', + keytab: Optional[str] = None, + ca_trust_path: Optional[str] = None, + cert_pem: Optional[str] = None, + cert_key_pem: Optional[str] = None, + server_cert_validation: str = 'validate', + kerberos_delegation: bool = False, + read_timeout_sec: int = 30, + operation_timeout_sec: int = 20, + kerberos_hostname_override: Optional[str] = None, + message_encryption: Optional[str] = 'auto', + credssp_disable_tlsv1_2: bool = False, + send_cbt: bool = True, + ) -> None: super().__init__() self.ssh_conn_id = ssh_conn_id self.endpoint = endpoint diff --git a/airflow/providers/microsoft/winrm/operators/winrm.py b/airflow/providers/microsoft/winrm/operators/winrm.py index a0c2c76..8e4b507 100644 --- a/airflow/providers/microsoft/winrm/operators/winrm.py +++ b/airflow/providers/microsoft/winrm/operators/winrm.py @@ -18,6 +18,7 @@ import logging from base64 import b64encode +from typing import Optional, Union from winrm.exceptions import WinRMOperationTimeoutError @@ -53,8 +54,15 @@ class WinRMOperator(BaseOperator): @apply_defaults def __init__( - self, *, winrm_hook=None, ssh_conn_id=None, remote_host=None, command=None, timeout=10, **kwargs - ): + self, + *, + winrm_hook: Optional[WinRMHook] = None, + ssh_conn_id: Optional[str] = None, + remote_host: Optional[str] = None, + command: Optional[str] = None, + timeout: int = 10, + **kwargs, + ) -> None: super().__init__(**kwargs) self.winrm_hook = winrm_hook self.ssh_conn_id = ssh_conn_id @@ -62,7 +70,7 @@ class WinRMOperator(BaseOperator): self.command = command self.timeout = timeout - def execute(self, context): + def execute(self, context: dict) -> Union[list, str]: if self.ssh_conn_id and not self.winrm_hook: self.log.info("Hook not found, creating...") self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id) @@ -81,7 +89,9 @@ class WinRMOperator(BaseOperator): # pylint: disable=too-many-nested-blocks try: self.log.info("Running command: '%s'...", self.command) - command_id = self.winrm_hook.winrm_protocol.run_command(winrm_client, self.command) + command_id = self.winrm_hook.winrm_protocol.run_command( # type: ignore[attr-defined] + winrm_client, self.command + ) # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py stdout_buffer = [] @@ -95,7 +105,9 @@ class WinRMOperator(BaseOperator): stderr, return_code, command_done, - ) = self.winrm_hook.winrm_protocol._raw_get_command_output(winrm_client, command_id) + ) = self.winrm_hook.winrm_protocol._raw_get_command_output( # type: ignore[attr-defined] + winrm_client, command_id + ) # Only buffer stdout if we need to so that we minimize memory usage. if self.do_xcom_push: @@ -111,8 +123,10 @@ class WinRMOperator(BaseOperator): # long-running process, just silently retry pass - self.winrm_hook.winrm_protocol.cleanup_command(winrm_client, command_id) - self.winrm_hook.winrm_protocol.close_shell(winrm_client) + self.winrm_hook.winrm_protocol.cleanup_command( # type: ignore[attr-defined] + winrm_client, command_id + ) + self.winrm_hook.winrm_protocol.close_shell(winrm_client) # type: ignore[attr-defined] except Exception as e: raise AirflowException("WinRM operator error: {0}".format(str(e)))