This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6e51608  Fix mypy for providers: elasticsearch, oracle, yandex (#20344)
6e51608 is described below

commit 6e51608f28f4c769c019624ea0caaa0c6e671f80
Author: Dmytro Kazanzhy <[email protected]>
AuthorDate: Thu Dec 16 22:57:03 2021 +0200

    Fix mypy for providers: elasticsearch, oracle, yandex (#20344)
---
 airflow/providers/elasticsearch/log/es_task_handler.py | 16 +++++++++-------
 airflow/providers/oracle/hooks/oracle.py               | 10 ++++------
 airflow/providers/oracle/operators/oracle.py           |  2 +-
 airflow/providers/yandex/hooks/yandex.py               |  3 ++-
 4 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py
index cd08971..c2b041e 100644
--- a/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -22,7 +22,7 @@ from collections import defaultdict
 from datetime import datetime
 from operator import attrgetter
 from time import time
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 from urllib.parse import quote
 
 # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test.
@@ -97,9 +97,11 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
         self.json_fields = [label.strip() for label in json_fields.split(",")]
         self.host_field = host_field
         self.offset_field = offset_field
-        self.handler = None
         self.context_set = False
 
+        self.formatter: logging.Formatter
+        self.handler: Union[logging.FileHandler, logging.StreamHandler]  # 
type: ignore[assignment]
+
     def _render_log_id(self, ti: TaskInstance, try_number: int) -> str:
         dag_run = ti.dag_run
 
@@ -294,9 +296,9 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
                 # already been initialized
                 return
 
-            self.handler = logging.StreamHandler(stream=sys.__stdout__)  # 
type: ignore
-            self.handler.setLevel(self.level)  # type: ignore
-            self.handler.setFormatter(self.formatter)  # type: ignore
+            self.handler = logging.StreamHandler(stream=sys.__stdout__)
+            self.handler.setLevel(self.level)
+            self.handler.setFormatter(self.formatter)
         else:
             super().set_context(ti)
         self.context_set = True
@@ -320,8 +322,8 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
 
         # Reopen the file stream, because FileHandler.close() would be called
         # first in logging.shutdown() and the stream in it would be set to None.
-        if self.handler.stream is None or self.handler.stream.closed:
-            self.handler.stream = self.handler._open()
+        if self.handler.stream is None or self.handler.stream.closed:  # type: 
ignore[attr-defined]
+            self.handler.stream = self.handler._open()  # type: 
ignore[union-attr]
 
         # Mark the end of file using end of log mark,
         # so we know where to stop while auto-tailing.
diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py
index f079197..b5dd453 100644
--- a/airflow/providers/oracle/hooks/oracle.py
+++ b/airflow/providers/oracle/hooks/oracle.py
@@ -17,7 +17,7 @@
 # under the License.
 
 from datetime import datetime
-from typing import Dict, List, Optional, TypeVar
+from typing import Dict, List, Optional, Union
 
 import cx_Oracle
 import numpy
@@ -26,8 +26,6 @@ from airflow.hooks.dbapi import DbApiHook
 
 PARAM_TYPES = {bool, float, int, str}
 
-ParameterType = TypeVar('ParameterType', Dict, List, None)
-
 
 def _map_param(value):
     if value in PARAM_TYPES:
@@ -284,8 +282,8 @@ class OracleHook(DbApiHook):
         self,
         identifier: str,
         autocommit: bool = False,
-        parameters: ParameterType = None,
-    ) -> ParameterType:
+        parameters: Optional[Union[List, Dict]] = None,
+    ) -> Optional[Union[List, Dict]]:
         """
         Call the stored procedure identified by the provided string.
 
@@ -301,7 +299,7 @@ class OracleHook(DbApiHook):
         for further reference.
         """
         if parameters is None:
-            parameters = ()
+            parameters = []
 
         args = ",".join(
             f":{name}"
diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py
index b80d570..175859e 100644
--- a/airflow/providers/oracle/operators/oracle.py
+++ b/airflow/providers/oracle/operators/oracle.py
@@ -95,7 +95,7 @@ class OracleStoredProcedureOperator(BaseOperator):
         self.procedure = procedure
         self.parameters = parameters
 
-    def execute(self, context) -> None:
+    def execute(self, context) -> Optional[Union[List, Dict]]:
         self.log.info('Executing: %s', self.procedure)
         hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
         return hook.callproc(self.procedure, autocommit=True, 
parameters=self.parameters)
diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py
index aac8d23..89d4f24 100644
--- a/airflow/providers/yandex/hooks/yandex.py
+++ b/airflow/providers/yandex/hooks/yandex.py
@@ -88,11 +88,12 @@ class YandexCloudBaseHook(BaseHook):
 
         try:
             manager = ProvidersManager()
-            provider_name = manager.hooks[cls.conn_type].package_name
+            provider_name = manager.hooks[cls.conn_type].package_name  # type: 
ignore[union-attr]
             provider = manager.providers[provider_name]
             return f'apache-airflow/{airflow.__version__} 
{provider_name}/{provider.version}'
         except KeyError:
             warnings.warn(f"Hook '{cls.hook_name}' info is not initialized in 
airflow.ProviderManager")
+            return None
 
     @staticmethod
     def get_ui_field_behaviour() -> Dict:

Reply via email to