This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 0aff69f Add typing to ImapHook (#9887)
0aff69f is described below
commit 0aff69fbd2f5a09c51f5b503ebf1bb72a26d3290
Author: Darwin Yip <[email protected]>
AuthorDate: Sun Jul 26 20:31:59 2020 -0400
Add typing to ImapHook (#9887)
---
airflow/providers/imap/hooks/imap.py | 96 +++++++++++++++++++++---------------
1 file changed, 57 insertions(+), 39 deletions(-)
diff --git a/airflow/providers/imap/hooks/imap.py b/airflow/providers/imap/hooks/imap.py
index b7197bb..60a46f6 100644
--- a/airflow/providers/imap/hooks/imap.py
+++ b/airflow/providers/imap/hooks/imap.py
@@ -24,6 +24,7 @@ import email
import imaplib
import os
import re
+from typing import Any, Iterable, List, Optional, Tuple
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
@@ -41,10 +42,10 @@ class ImapHook(BaseHook):
:type imap_conn_id: str
"""
- def __init__(self, imap_conn_id='imap_default'):
+ def __init__(self, imap_conn_id: str = 'imap_default') -> None:
super().__init__()
self.imap_conn_id = imap_conn_id
- self.mail_client = None
+ self.mail_client: Optional[imaplib.IMAP4_SSL] = None
def __enter__(self):
return self.get_conn()
@@ -52,7 +53,7 @@ class ImapHook(BaseHook):
def __exit__(self, exc_type, exc_val, exc_tb):
self.mail_client.logout()
- def get_conn(self):
+ def get_conn(self) -> 'ImapHook':
"""
Login to the mail server.
@@ -70,7 +71,12 @@ class ImapHook(BaseHook):
return self
- def has_mail_attachment(self, name, *, check_regex=False, mail_folder='INBOX', mail_filter='All'):
+ def has_mail_attachment(self,
+ name: str,
+ *,
+ check_regex: bool = False,
+ mail_folder: str = 'INBOX',
+ mail_filter: str = 'All') -> bool:
"""
Checks the mail folder for mails containing attachments with the given
name.
@@ -94,13 +100,13 @@ class ImapHook(BaseHook):
return len(mail_attachments) > 0
def retrieve_mail_attachments(self,
- name,
+ name: str,
*,
- check_regex=False,
- latest_only=False,
- mail_folder='INBOX',
- mail_filter='All',
- not_found_mode='raise'):
+ check_regex: bool = False,
+ latest_only: bool = False,
+ mail_folder: str = 'INBOX',
+ mail_filter: str = 'All',
+ not_found_mode: str = 'raise') -> List[Tuple]:
"""
Retrieves mail's attachments in the mail folder by its name.
@@ -136,14 +142,14 @@ class ImapHook(BaseHook):
return mail_attachments
def download_mail_attachments(self,
- name,
- local_output_directory,
+ name: str,
+ local_output_directory: str,
*,
- check_regex=False,
- latest_only=False,
- mail_folder='INBOX',
- mail_filter='All',
- not_found_mode='raise'):
+ check_regex: bool = False,
+ latest_only: bool = False,
+ mail_folder: str = 'INBOX',
+ mail_filter: str = 'All',
+ not_found_mode: str = 'raise'):
"""
Downloads mail's attachments in the mail folder by its name to the
local directory.
@@ -179,7 +185,7 @@ class ImapHook(BaseHook):
self._create_files(mail_attachments, local_output_directory)
- def _handle_not_found_mode(self, not_found_mode):
+ def _handle_not_found_mode(self, not_found_mode: str):
if not_found_mode == 'raise':
raise AirflowException('No mail attachments found!')
if not_found_mode == 'warn':
@@ -189,7 +195,11 @@ class ImapHook(BaseHook):
else:
self.log.error('Invalid "not_found_mode" %s', not_found_mode)
- def _retrieve_mails_attachments_by_name(self, name, check_regex, latest_only, mail_folder, mail_filter):
+ def _retrieve_mails_attachments_by_name(self, name: str, check_regex: bool, latest_only: bool,
+ mail_folder: str, mail_filter: str) -> List:
+ if not self.mail_client:
+ raise Exception("The 'mail_client' should be initialized before!")
+
all_matching_attachments = []
self.mail_client.select(mail_folder)
@@ -207,24 +217,29 @@ class ImapHook(BaseHook):
return all_matching_attachments
- def _list_mail_ids_desc(self, mail_filter):
+ def _list_mail_ids_desc(self, mail_filter: str) -> Iterable[str]:
+ if not self.mail_client:
+ raise Exception("The 'mail_client' should be initialized before!")
_, data = self.mail_client.search(None, mail_filter)
mail_ids = data[0].split()
return reversed(mail_ids)
- def _fetch_mail_body(self, mail_id):
+ def _fetch_mail_body(self, mail_id: str) -> str:
+ if not self.mail_client:
+ raise Exception("The 'mail_client' should be initialized before!")
_, data = self.mail_client.fetch(mail_id, '(RFC822)')
- mail_body = data[0][1] # The mail body is always in this specific location
- mail_body_str = mail_body.decode('utf-8')
+ mail_body = data[0][1] # type: ignore # The mail body is always in this specific location
+ mail_body_str = mail_body.decode('utf-8') # type: ignore
return mail_body_str
- def _check_mail_body(self, response_mail_body, name, check_regex, latest_only):
+ def _check_mail_body(self, response_mail_body: str, name: str, check_regex: bool,
+ latest_only: bool) -> List[Tuple[Any, Any]]:
mail = Mail(response_mail_body)
if mail.has_attachments():
return mail.get_attachments_by_name(name, check_regex, find_first=latest_only)
return []
- def _create_files(self, mail_attachments, local_output_directory):
+ def _create_files(self, mail_attachments: List, local_output_directory: str):
for name, payload in mail_attachments:
if self._is_symlink(name):
self.log.error('Can not create file because it is a symlink!')
@@ -233,19 +248,19 @@ class ImapHook(BaseHook):
else:
self._create_file(name, payload, local_output_directory)
- def _is_symlink(self, name):
+ def _is_symlink(self, name: str):
# IMPORTANT NOTE: os.path.islink is not working for windows symlinks
# See: https://stackoverflow.com/a/11068434
return os.path.islink(name)
- def _is_escaping_current_directory(self, name):
+ def _is_escaping_current_directory(self, name: str):
return '../' in name
- def _correct_path(self, name, local_output_directory):
+ def _correct_path(self, name: str, local_output_directory: str):
return local_output_directory + name if local_output_directory.endswith('/') \
else local_output_directory + '/' + name
- def _create_file(self, name, payload, local_output_directory):
+ def _create_file(self, name: str, payload: Any, local_output_directory: str):
file_path = self._correct_path(name, local_output_directory)
with open(file_path, 'wb') as file:
@@ -260,11 +275,11 @@ class Mail(LoggingMixin):
:type mail_body: str
"""
- def __init__(self, mail_body):
+ def __init__(self, mail_body: str) -> None:
super().__init__()
self.mail = email.message_from_string(mail_body)
- def has_attachments(self):
+ def has_attachments(self) -> bool:
"""
Checks the mail for a attachments.
@@ -273,7 +288,10 @@ class Mail(LoggingMixin):
"""
return self.mail.get_content_maintype() == 'multipart'
- def get_attachments_by_name(self, name, check_regex, find_first=False):
+ def get_attachments_by_name(self,
+ name: str,
+ check_regex: bool,
+ find_first: bool = False) -> List[Tuple[Any, Any]]:
"""
Gets all attachments by name for the mail.
@@ -301,7 +319,7 @@ class Mail(LoggingMixin):
return attachments
- def _iterate_attachments(self):
+ def _iterate_attachments(self) -> Iterable['MailPart']:
for part in self.mail.walk():
mail_part = MailPart(part)
if mail_part.is_attachment():
@@ -316,10 +334,10 @@ class MailPart:
:type part: any
"""
- def __init__(self, part):
+ def __init__(self, part: Any) -> None:
self.part = part
- def is_attachment(self):
+ def is_attachment(self) -> bool:
"""
Checks if the part is a valid mail attachment.
@@ -328,7 +346,7 @@ class MailPart:
"""
return self.part.get_content_maintype() != 'multipart' and self.part.get('Content-Disposition')
- def has_matching_name(self, name):
+ def has_matching_name(self, name: str) -> Optional[Tuple[Any, Any]]:
"""
Checks if the given name matches the part's name.
@@ -337,9 +355,9 @@ class MailPart:
:returns: True if it matches the name (including regular expression).
:rtype: tuple
"""
- return re.match(name, self.part.get_filename())
+ return re.match(name, self.part.get_filename()) # type: ignore
- def has_equal_name(self, name):
+ def has_equal_name(self, name: str) -> bool:
"""
Checks if the given name is equal to the part's name.
@@ -350,7 +368,7 @@ class MailPart:
"""
return self.part.get_filename() == name
- def get_file(self):
+ def get_file(self) -> Tuple:
"""
Gets the file including name and payload.