Fokko closed pull request #2372: [AIRFLOW-393] Add callback for FTP downloads
URL: https://github.com/apache/incubator-airflow/pull/2372
This PR was merged from a forked repository. Because GitHub hides the
original diff of a foreign (fork) pull request once it is merged, the diff
is reproduced below for the sake of provenance:
diff --git a/airflow/contrib/hooks/ftp_hook.py b/airflow/contrib/hooks/ftp_hook.py
index 8beefb3729..03849012a3 100644
--- a/airflow/contrib/hooks/ftp_hook.py
+++ b/airflow/contrib/hooks/ftp_hook.py
@@ -148,7 +148,11 @@ def delete_directory(self, path):
conn = self.get_conn()
conn.rmd(path)
- def retrieve_file(self, remote_full_path, local_full_path_or_buffer):
+ def retrieve_file(
+ self,
+ remote_full_path,
+ local_full_path_or_buffer,
+ callback=None):
"""
Transfers the remote file to a local location.
@@ -161,23 +165,59 @@ def retrieve_file(self, remote_full_path, local_full_path_or_buffer):
:param local_full_path_or_buffer: full path to the local file or a
file-like buffer
:type local_full_path_or_buffer: str or file-like buffer
+ :param callback: callback which is called each time a block of data
+ is read. if you do not use a callback, these blocks will be written
+ to the file or buffer passed in. if you do pass in a callback, note
+ that writing to a file or buffer will need to be handled inside the
+ callback.
+ [default: output_handle.write()]
+ :type callback: callable
+
+ Example::
+ hook = FTPHook(ftp_conn_id='my_conn')
+
+ remote_path = '/path/to/remote/file'
+ local_path = '/path/to/local/file'
+
+ # with a custom callback (in this case displaying progress on each read)
+ def print_progress(percent_progress):
+ self.log.info('Percent Downloaded: %s%%' % percent_progress)
+
+ total_downloaded = 0
+ total_file_size = hook.get_size(remote_path)
+ output_handle = open(local_path, 'wb')
+ def write_to_file_with_progress(data):
+ total_downloaded += len(data)
+ output_handle.write(data)
+ percent_progress = (total_downloaded / total_file_size) * 100
+ print_progress(percent_progress)
+ hook.retrieve_file(remote_path, None, callback=write_to_file_with_progress)
+
+ # without a custom callback data is written to the local_path
+ hook.retrieve_file(remote_path, local_path)
"""
conn = self.get_conn()
is_path = isinstance(local_full_path_or_buffer, basestring)
- if is_path:
- output_handle = open(local_full_path_or_buffer, 'wb')
+ # without a callback, default to writing to a user-provided file or
+ # file-like buffer
+ if not callback:
+ if is_path:
+ output_handle = open(local_full_path_or_buffer, 'wb')
+ else:
+ output_handle = local_full_path_or_buffer
+ callback = output_handle.write
else:
- output_handle = local_full_path_or_buffer
+ output_handle = None
remote_path, remote_file_name = os.path.split(remote_full_path)
conn.cwd(remote_path)
self.log.info('Retrieving file from FTP: %s', remote_full_path)
- conn.retrbinary('RETR %s' % remote_file_name, output_handle.write)
+ conn.retrbinary('RETR %s' % remote_file_name, callback)
self.log.info('Finished retrieving file from FTP: %s',
remote_full_path)
- if is_path:
+ if is_path and output_handle:
output_handle.close()
def store_file(self, remote_full_path, local_full_path_or_buffer):
@@ -230,6 +270,12 @@ def rename(self, from_name, to_name):
return conn.rename(from_name, to_name)
def get_mod_time(self, path):
+ """
+ Returns a datetime object representing the last time the file was modified
+
+ :param path: remote file path
+ :type path: string
+ """
conn = self.get_conn()
ftp_mdtm = conn.sendcmd('MDTM ' + path)
time_val = ftp_mdtm[4:]
@@ -239,6 +285,16 @@ def get_mod_time(self, path):
except ValueError:
return datetime.datetime.strptime(time_val, '%Y%m%d%H%M%S')
+ def get_size(self, path):
+ """
+ Returns the size of a file (in bytes)
+
+ :param path: remote file path
+ :type path: string
+ """
+ conn = self.get_conn()
+ return conn.size(path)
+
class FTPSHook(FTPHook):
diff --git a/tests/contrib/hooks/test_ftp_hook.py b/tests/contrib/hooks/test_ftp_hook.py
index 8b9ae2cd59..1274990827 100644
--- a/tests/contrib/hooks/test_ftp_hook.py
+++ b/tests/contrib/hooks/test_ftp_hook.py
@@ -19,6 +19,7 @@
#
import mock
+import six
import unittest
from airflow.contrib.hooks import ftp_hook as fh
@@ -101,6 +102,28 @@ def test_mod_time_micro(self):
self.conn_mock.sendcmd.assert_called_once_with('MDTM ' + path)
+ def test_get_size(self):
+ self.conn_mock.size.return_value = 1942
+
+ path = '/path/file'
+ with fh.FTPHook() as ftp_hook:
+ ftp_hook.get_size(path)
+
+ self.conn_mock.size.assert_called_once_with(path)
+
+ def test_retrieve_file(self):
+ _buffer = six.StringIO('buffer')
+ with fh.FTPHook() as ftp_hook:
+ ftp_hook.retrieve_file(self.path, _buffer)
+ self.conn_mock.retrbinary.assert_called_once_with('RETR path',
_buffer.write)
+
+ def test_retrieve_file_with_callback(self):
+ func = mock.Mock()
+ _buffer = six.StringIO('buffer')
+ with fh.FTPHook() as ftp_hook:
+ ftp_hook.retrieve_file(self.path, _buffer, callback=func)
+ self.conn_mock.retrbinary.assert_called_once_with('RETR path', func)
+
if __name__ == '__main__':
unittest.main()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services