This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 4eddce2 Add typing for grpc provider (#9884)
4eddce2 is described below
commit 4eddce22a3e0eb605f5661204a005262bbaa54cd
Author: chipmyersjr <[email protected]>
AuthorDate: Tue Jul 21 10:13:14 2020 -0700
Add typing for grpc provider (#9884)
---
airflow/providers/grpc/hooks/grpc.py | 21 +++++++++++++--------
airflow/providers/grpc/operators/grpc.py | 28 +++++++++++++++-------------
2 files changed, 28 insertions(+), 21 deletions(-)
diff --git a/airflow/providers/grpc/hooks/grpc.py b/airflow/providers/grpc/hooks/grpc.py
index 8697c9d..0737e6f 100644
--- a/airflow/providers/grpc/hooks/grpc.py
+++ b/airflow/providers/grpc/hooks/grpc.py
@@ -16,6 +16,7 @@
# under the License.
"""GRPC Hook"""
+from typing import Callable, Generator, List, Optional
import grpc
from google import auth as google_auth
@@ -45,7 +46,10 @@ class GrpcHook(BaseHook):
its only arg. Could be partial or lambda.
"""
- def __init__(self, grpc_conn_id, interceptors=None,
custom_connection_func=None):
+ def __init__(self,
+ grpc_conn_id: str,
+ interceptors: Optional[List[Callable]] = None,
+ custom_connection_func: Optional[Callable] = None) -> None:
super().__init__()
self.grpc_conn_id = grpc_conn_id
self.conn = self.get_connection(self.grpc_conn_id)
@@ -53,7 +57,7 @@ class GrpcHook(BaseHook):
self.interceptors = interceptors if interceptors else []
self.custom_connection_func = custom_connection_func
- def get_conn(self):
+ def get_conn(self) -> grpc.Channel:
base_url = self.conn.host
if self.conn.port:
@@ -96,7 +100,11 @@ class GrpcHook(BaseHook):
return channel
- def run(self, stub_class, call_func, streaming=False, data=None):
+ def run(self,
+ stub_class: Callable,
+ call_func: str,
+ streaming: bool = False,
+ data: Optional[dict] = None) -> Generator:
"""
Call gRPC function and yield response to caller
"""
@@ -123,7 +131,7 @@ class GrpcHook(BaseHook):
)
raise ex
- def _get_field(self, field_name, default=None):
+ def _get_field(self, field_name: str) -> str:
"""
Fetches a field from extras, and returns it. This is some Airflow
magic. The grpc hook type adds custom UI elements
@@ -131,7 +139,4 @@ class GrpcHook(BaseHook):
They get formatted as shown below.
"""
full_field_name = 'extra__grpc__{}'.format(field_name)
- if full_field_name in self.extras:
- return self.extras[full_field_name]
- else:
- return default
+ return self.extras[full_field_name]
diff --git a/airflow/providers/grpc/operators/grpc.py b/airflow/providers/grpc/operators/grpc.py
index ad6374e..107efbc 100644
--- a/airflow/providers/grpc/operators/grpc.py
+++ b/airflow/providers/grpc/operators/grpc.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
+from typing import Any, Callable, Dict, List, Optional
+
from airflow.models import BaseOperator
from airflow.providers.grpc.hooks.grpc import GrpcHook
from airflow.utils.decorators import apply_defaults
@@ -52,16 +54,16 @@ class GrpcOperator(BaseOperator):
@apply_defaults
def __init__(self,
- stub_class,
- call_func,
- grpc_conn_id="grpc_default",
- data=None,
- interceptors=None,
- custom_connection_func=None,
- streaming=False,
- response_callback=None,
- log_response=False,
- *args, **kwargs):
+ stub_class: Callable,
+ call_func: str,
+ grpc_conn_id: str = "grpc_default",
+ data: Optional[dict] = None,
+ interceptors: Optional[List[Callable]] = None,
+ custom_connection_func: Optional[Callable] = None,
+ streaming: bool = False,
+ response_callback: Optional[Callable] = None,
+ log_response: bool = False,
+ *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.stub_class = stub_class
self.call_func = call_func
@@ -73,14 +75,14 @@ class GrpcOperator(BaseOperator):
self.log_response = log_response
self.response_callback = response_callback
- def _get_grpc_hook(self):
+ def _get_grpc_hook(self) -> GrpcHook:
return GrpcHook(
self.grpc_conn_id,
interceptors=self.interceptors,
custom_connection_func=self.custom_connection_func
)
- def execute(self, context):
+ def execute(self, context: Dict) -> None:
hook = self._get_grpc_hook()
self.log.info("Calling gRPC service")
@@ -90,7 +92,7 @@ class GrpcOperator(BaseOperator):
for response in responses:
self._handle_response(response, context)
- def _handle_response(self, response, context):
+ def _handle_response(self, response: Any, context: Dict) -> None:
if self.log_response:
self.log.info(repr(response))
if self.response_callback: