uranusjr commented on a change in pull request #17273:
URL: https://github.com/apache/airflow/pull/17273#discussion_r678038843
##########
File path: airflow/providers/samba/hooks/samba.py
##########
@@ -16,67 +16,231 @@
# specific language governing permissions and limitations
# under the License.
-import os
+from functools import wraps
+from shutil import copyfileobj
+from typing import Optional
-from smbclient import SambaClient
+import smbclient
from airflow.hooks.base import BaseHook
class SambaHook(BaseHook):
- """Allows for interaction with an samba server."""
+ """Allows for interaction with a Samba server.
+
+ :param samba_conn_id: The connection id reference.
+ :type samba_conn_id: str
+ :param share:
+ An optional share name. If this is unset then the "schema" field of
+ the connection is used in its place.
+ :type share: str
+ """
conn_name_attr = 'samba_conn_id'
default_conn_name = 'samba_default'
conn_type = 'samba'
hook_name = 'Samba'
- def __init__(self, samba_conn_id: str = default_conn_name) -> None:
+ def __init__(self, samba_conn_id: str = default_conn_name, share:
Optional[str] = None) -> None:
super().__init__()
- self.conn = self.get_connection(samba_conn_id)
+ conn = self.get_connection(samba_conn_id)
+
+ if not conn.login:
+ self.log.info("Login not provided")
+
+ if not conn.password:
+ self.log.info("Password not provided")
+
+ self._host = conn.host
+ self._share = share or conn.schema
+ self._connection_cache = connection_cache = {}
+ self._conn_kwargs = {
+ "username": conn.login,
+ "password": conn.password,
+ "port": conn.port or 445,
+ "connection_cache": connection_cache,
+ }
+
+ def __enter__(self):
+ # This immediately connects to the host (which can be
+ # perceived as a benefit), but also help work around an issue:
+ #
+ # https://github.com/jborean93/smbprotocol/issues/109.
+ smbclient.register_session(self._host, **self._conn_kwargs)
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ for host, connection in self._connection_cache.items():
+ self.log.info("Disconnecting from %s", host)
+ connection.disconnect()
+ self._connection_cache.clear()
+
+ @property
+ def _base_url(self):
+ return f"//{self._host}/{self._share}"
+
+ @wraps(smbclient.link)
+ def link(self, src, dst, follow_symlinks=True):
+ return smbclient.link(
+ self._base_url + "/" + src,
Review comment:
Probably better to use `posixpath.join()` so users don’t need to deal
with double slash issues.
It may be even better to change `_base_url` to something like
```python
def _format_url(self, endpoint):
return f"//{posixpath.join(self._host, self._share, endpoint)}"
```
since the value of `_base_url` is never used alone but always used to
generate a full path (if I didn’t miss anything).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]