This is an automated email from the ASF dual-hosted git repository.

kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 0aff69f  Add typing to ImapHook (#9887)
0aff69f is described below

commit 0aff69fbd2f5a09c51f5b503ebf1bb72a26d3290
Author: Darwin Yip <[email protected]>
AuthorDate: Sun Jul 26 20:31:59 2020 -0400

    Add typing to ImapHook (#9887)
---
 airflow/providers/imap/hooks/imap.py | 96 +++++++++++++++++++++---------------
 1 file changed, 57 insertions(+), 39 deletions(-)

diff --git a/airflow/providers/imap/hooks/imap.py b/airflow/providers/imap/hooks/imap.py
index b7197bb..60a46f6 100644
--- a/airflow/providers/imap/hooks/imap.py
+++ b/airflow/providers/imap/hooks/imap.py
@@ -24,6 +24,7 @@ import email
 import imaplib
 import os
 import re
+from typing import Any, Iterable, List, Optional, Tuple
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
@@ -41,10 +42,10 @@ class ImapHook(BaseHook):
     :type imap_conn_id: str
     """
 
-    def __init__(self, imap_conn_id='imap_default'):
+    def __init__(self, imap_conn_id: str = 'imap_default') -> None:
         super().__init__()
         self.imap_conn_id = imap_conn_id
-        self.mail_client = None
+        self.mail_client: Optional[imaplib.IMAP4_SSL] = None
 
     def __enter__(self):
         return self.get_conn()
@@ -52,7 +53,7 @@ class ImapHook(BaseHook):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.mail_client.logout()
 
-    def get_conn(self):
+    def get_conn(self) -> 'ImapHook':
         """
         Login to the mail server.
 
@@ -70,7 +71,12 @@ class ImapHook(BaseHook):
 
         return self
 
-    def has_mail_attachment(self, name, *, check_regex=False, mail_folder='INBOX', mail_filter='All'):
+    def has_mail_attachment(self,
+                            name: str,
+                            *,
+                            check_regex: bool = False,
+                            mail_folder: str = 'INBOX',
+                            mail_filter: str = 'All') -> bool:
         """
         Checks the mail folder for mails containing attachments with the given name.
 
@@ -94,13 +100,13 @@ class ImapHook(BaseHook):
         return len(mail_attachments) > 0
 
     def retrieve_mail_attachments(self,
-                                  name,
+                                  name: str,
                                   *,
-                                  check_regex=False,
-                                  latest_only=False,
-                                  mail_folder='INBOX',
-                                  mail_filter='All',
-                                  not_found_mode='raise'):
+                                  check_regex: bool = False,
+                                  latest_only: bool = False,
+                                  mail_folder: str = 'INBOX',
+                                  mail_filter: str = 'All',
+                                  not_found_mode: str = 'raise') -> List[Tuple]:
         """
         Retrieves mail's attachments in the mail folder by its name.
 
@@ -136,14 +142,14 @@ class ImapHook(BaseHook):
         return mail_attachments
 
     def download_mail_attachments(self,
-                                  name,
-                                  local_output_directory,
+                                  name: str,
+                                  local_output_directory: str,
                                   *,
-                                  check_regex=False,
-                                  latest_only=False,
-                                  mail_folder='INBOX',
-                                  mail_filter='All',
-                                  not_found_mode='raise'):
+                                  check_regex: bool = False,
+                                  latest_only: bool = False,
+                                  mail_folder: str = 'INBOX',
+                                  mail_filter: str = 'All',
+                                  not_found_mode: str = 'raise'):
         """
         Downloads mail's attachments in the mail folder by its name to the local directory.
 
@@ -179,7 +185,7 @@ class ImapHook(BaseHook):
 
         self._create_files(mail_attachments, local_output_directory)
 
-    def _handle_not_found_mode(self, not_found_mode):
+    def _handle_not_found_mode(self, not_found_mode: str):
         if not_found_mode == 'raise':
             raise AirflowException('No mail attachments found!')
         if not_found_mode == 'warn':
@@ -189,7 +195,11 @@ class ImapHook(BaseHook):
         else:
             self.log.error('Invalid "not_found_mode" %s', not_found_mode)
 
-    def _retrieve_mails_attachments_by_name(self, name, check_regex, latest_only, mail_folder, mail_filter):
+    def _retrieve_mails_attachments_by_name(self, name: str, check_regex: bool, latest_only: bool,
+                                            mail_folder: str, mail_filter: str) -> List:
+        if not self.mail_client:
+            raise Exception("The 'mail_client' should be initialized before!")
+
         all_matching_attachments = []
 
         self.mail_client.select(mail_folder)
@@ -207,24 +217,29 @@ class ImapHook(BaseHook):
 
         return all_matching_attachments
 
-    def _list_mail_ids_desc(self, mail_filter):
+    def _list_mail_ids_desc(self, mail_filter: str) -> Iterable[str]:
+        if not self.mail_client:
+            raise Exception("The 'mail_client' should be initialized before!")
         _, data = self.mail_client.search(None, mail_filter)
         mail_ids = data[0].split()
         return reversed(mail_ids)
 
-    def _fetch_mail_body(self, mail_id):
+    def _fetch_mail_body(self, mail_id: str) -> str:
+        if not self.mail_client:
+            raise Exception("The 'mail_client' should be initialized before!")
         _, data = self.mail_client.fetch(mail_id, '(RFC822)')
-        mail_body = data[0][1]  # The mail body is always in this specific location
-        mail_body_str = mail_body.decode('utf-8')
+        mail_body = data[0][1]  # type: ignore # The mail body is always in this specific location
+        mail_body_str = mail_body.decode('utf-8')  # type: ignore
         return mail_body_str
 
-    def _check_mail_body(self, response_mail_body, name, check_regex, latest_only):
+    def _check_mail_body(self, response_mail_body: str, name: str, check_regex: bool,
+                         latest_only: bool) -> List[Tuple[Any, Any]]:
         mail = Mail(response_mail_body)
         if mail.has_attachments():
             return mail.get_attachments_by_name(name, check_regex, find_first=latest_only)
         return []
 
-    def _create_files(self, mail_attachments, local_output_directory):
+    def _create_files(self, mail_attachments: List, local_output_directory: str):
         for name, payload in mail_attachments:
             if self._is_symlink(name):
                 self.log.error('Can not create file because it is a symlink!')
@@ -233,19 +248,19 @@ class ImapHook(BaseHook):
             else:
                 self._create_file(name, payload, local_output_directory)
 
-    def _is_symlink(self, name):
+    def _is_symlink(self, name: str):
         # IMPORTANT NOTE: os.path.islink is not working for windows symlinks
         # See: https://stackoverflow.com/a/11068434
         return os.path.islink(name)
 
-    def _is_escaping_current_directory(self, name):
+    def _is_escaping_current_directory(self, name: str):
         return '../' in name
 
-    def _correct_path(self, name, local_output_directory):
+    def _correct_path(self, name: str, local_output_directory: str):
         return local_output_directory + name if local_output_directory.endswith('/') \
             else local_output_directory + '/' + name
 
-    def _create_file(self, name, payload, local_output_directory):
+    def _create_file(self, name: str, payload: Any, local_output_directory: str):
         file_path = self._correct_path(name, local_output_directory)
 
         with open(file_path, 'wb') as file:
@@ -260,11 +275,11 @@ class Mail(LoggingMixin):
     :type mail_body: str
     """
 
-    def __init__(self, mail_body):
+    def __init__(self, mail_body: str) -> None:
         super().__init__()
         self.mail = email.message_from_string(mail_body)
 
-    def has_attachments(self):
+    def has_attachments(self) -> bool:
         """
         Checks the mail for a attachments.
 
@@ -273,7 +288,10 @@ class Mail(LoggingMixin):
         """
         return self.mail.get_content_maintype() == 'multipart'
 
-    def get_attachments_by_name(self, name, check_regex, find_first=False):
+    def get_attachments_by_name(self,
+                                name: str,
+                                check_regex: bool,
+                                find_first: bool = False) -> List[Tuple[Any, Any]]:
         """
         Gets all attachments by name for the mail.
 
@@ -301,7 +319,7 @@ class Mail(LoggingMixin):
 
         return attachments
 
-    def _iterate_attachments(self):
+    def _iterate_attachments(self) -> Iterable['MailPart']:
         for part in self.mail.walk():
             mail_part = MailPart(part)
             if mail_part.is_attachment():
@@ -316,10 +334,10 @@ class MailPart:
     :type part: any
     """
 
-    def __init__(self, part):
+    def __init__(self, part: Any) -> None:
         self.part = part
 
-    def is_attachment(self):
+    def is_attachment(self) -> bool:
         """
         Checks if the part is a valid mail attachment.
 
@@ -328,7 +346,7 @@ class MailPart:
         """
         return self.part.get_content_maintype() != 'multipart' and self.part.get('Content-Disposition')
 
-    def has_matching_name(self, name):
+    def has_matching_name(self, name: str) -> Optional[Tuple[Any, Any]]:
         """
         Checks if the given name matches the part's name.
 
@@ -337,9 +355,9 @@ class MailPart:
         :returns: True if it matches the name (including regular expression).
         :rtype: tuple
         """
-        return re.match(name, self.part.get_filename())
+        return re.match(name, self.part.get_filename())  # type: ignore
 
-    def has_equal_name(self, name):
+    def has_equal_name(self, name: str) -> bool:
         """
         Checks if the given name is equal to the part's name.
 
@@ -350,7 +368,7 @@ class MailPart:
         """
         return self.part.get_filename() == name
 
-    def get_file(self):
+    def get_file(self) -> Tuple:
         """
         Gets the file including name and payload.
 

Reply via email to