malthe commented on a change in pull request #17273:
URL: https://github.com/apache/airflow/pull/17273#discussion_r678042850
##########
File path: airflow/providers/samba/hooks/samba.py
##########
@@ -16,67 +16,231 @@
# specific language governing permissions and limitations
# under the License.
-import os
+from functools import wraps
+from shutil import copyfileobj
+from typing import Optional
-from smbclient import SambaClient
+import smbclient
from airflow.hooks.base import BaseHook
class SambaHook(BaseHook):
- """Allows for interaction with an samba server."""
+ """Allows for interaction with a Samba server.
+
+ :param samba_conn_id: The connection id reference.
+ :type samba_conn_id: str
+ :param share:
+ An optional share name. If this is unset then the "schema" field of
+ the connection is used in its place.
+ :type share: str
+ """
conn_name_attr = 'samba_conn_id'
default_conn_name = 'samba_default'
conn_type = 'samba'
hook_name = 'Samba'
- def __init__(self, samba_conn_id: str = default_conn_name) -> None:
+ def __init__(self, samba_conn_id: str = default_conn_name, share: Optional[str] = None) -> None:
super().__init__()
- self.conn = self.get_connection(samba_conn_id)
+ conn = self.get_connection(samba_conn_id)
+
+ if not conn.login:
+ self.log.info("Login not provided")
+
+ if not conn.password:
+ self.log.info("Password not provided")
+
+ self._host = conn.host
+ self._share = share or conn.schema
+ self._connection_cache = connection_cache = {}
+ self._conn_kwargs = {
+ "username": conn.login,
+ "password": conn.password,
+ "port": conn.port or 445,
+ "connection_cache": connection_cache,
+ }
+
+ def __enter__(self):
+ # This immediately connects to the host (which can be
+ # perceived as a benefit), but also helps work around an issue:
+ #
+ # https://github.com/jborean93/smbprotocol/issues/109.
+ smbclient.register_session(self._host, **self._conn_kwargs)
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ for host, connection in self._connection_cache.items():
+ self.log.info("Disconnecting from %s", host)
+ connection.disconnect()
+ self._connection_cache.clear()
+
+ @property
+ def _base_url(self):
+ return f"//{self._host}/{self._share}"
+
+ @wraps(smbclient.link)
+ def link(self, src, dst, follow_symlinks=True):
+ return smbclient.link(
+ self._base_url + "/" + src,
Review comment:
Sounds good – I did exactly that.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]