This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3a3adfb8e6 Fix typo in DataSyncHook boto3 methods for create location
in NFS and EFS (#28948)
3a3adfb8e6 is described below
commit 3a3adfb8e618a7cea376cb5d187fa3e486a9c9ad
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Jan 16 02:33:34 2023 +0400
Fix typo in DataSyncHook boto3 methods for create location in NFS and EFS
(#28948)
---
airflow/providers/amazon/aws/hooks/datasync.py | 17 ++++++------
tests/providers/amazon/aws/hooks/test_datasync.py | 34 ++++++++++++++++++++++-
2 files changed, 42 insertions(+), 9 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/datasync.py
b/airflow/providers/amazon/aws/hooks/datasync.py
index 3a1f6e2c6c..3e7c4e420a 100644
--- a/airflow/providers/amazon/aws/hooks/datasync.py
+++ b/airflow/providers/amazon/aws/hooks/datasync.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import time
+from urllib.parse import urlsplit
from airflow.exceptions import AirflowBadRequest, AirflowException,
AirflowTaskTimeout
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -74,17 +75,17 @@ class DataSyncHook(AwsBaseHook):
:return: LocationArn of the created Location.
:raises AirflowException: If location type (prefix from
``location_uri``) is invalid.
"""
- typ = location_uri.split(":")[0]
- if typ == "smb":
+ schema = urlsplit(location_uri).scheme
+ if schema == "smb":
location =
self.get_conn().create_location_smb(**create_location_kwargs)
- elif typ == "s3":
+ elif schema == "s3":
location =
self.get_conn().create_location_s3(**create_location_kwargs)
- elif typ == "nfs":
- location =
self.get_conn().create_loction_nfs(**create_location_kwargs)
- elif typ == "efs":
- location =
self.get_conn().create_loction_efs(**create_location_kwargs)
+ elif schema == "nfs":
+ location =
self.get_conn().create_location_nfs(**create_location_kwargs)
+ elif schema == "efs":
+ location =
self.get_conn().create_location_efs(**create_location_kwargs)
else:
- raise AirflowException(f"Invalid location type: {typ}")
+ raise AirflowException(f"Invalid/Unsupported location type:
{schema}")
self._refresh_locations()
return location["LocationArn"]
diff --git a/tests/providers/amazon/aws/hooks/test_datasync.py
b/tests/providers/amazon/aws/hooks/test_datasync.py
index eeb976e4e0..f7631a73b2 100644
--- a/tests/providers/amazon/aws/hooks/test_datasync.py
+++ b/tests/providers/amazon/aws/hooks/test_datasync.py
@@ -23,7 +23,7 @@ import boto3
import pytest
from moto import mock_datasync
-from airflow.exceptions import AirflowTaskTimeout
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
@@ -97,6 +97,38 @@ class TestDataSyncHookMocked:
assert not self.hook.tasks
assert self.hook.wait_interval_seconds == 0
+ @pytest.mark.parametrize(
+ "location_uri, expected_method",
+ [
+ pytest.param("smb://spam/egg/", "create_location_smb", id="smb"),
+ pytest.param("s3://foo/bar", "create_location_s3", id="s3"),
+ pytest.param("nfs://server:2049/path", "create_location_nfs",
id="nfs"),
+ pytest.param("efs://12345.efs.aws-region.amazonaws.com/path",
"create_location_efs", id="efs"),
+ ],
+ )
+ def test_create_location_method_mapping(self, mock_get_conn, location_uri,
expected_method):
+ """Test expected location URI and mapping with DataSync.Client
methods."""
+ mock_get_conn.return_value = self.client
+ assert hasattr(self.client, expected_method), f"{self.client} doesn't
have method {expected_method}"
+ with mock.patch.object(self.client, expected_method) as m:
+ self.hook.create_location(location_uri, foo="bar", spam="egg")
+ m.assert_called_once_with(foo="bar", spam="egg")
+
+ @pytest.mark.parametrize(
+ "location_uri",
+ [
+ pytest.param("hdfs://namenodehost1/path", id="hdfs"),
+ pytest.param("https://example.org/path", id="https"),
+ pytest.param("http://example.org/path", id="http"),
+ pytest.param("lustre://mount/path", id="lustre"),
+ ],
+ )
+ def test_create_location_unknown_type(self, mock_get_conn, location_uri):
+ """Test unsupported location URI."""
+ mock_get_conn.return_value = mock.MagicMock()
+ with pytest.raises(AirflowException, match="Invalid/Unsupported
location type: .*"):
+ self.hook.create_location(location_uri, foo="bar", spam="egg")
+
def test_create_location_smb(self, mock_get_conn):
# ### Configure mock:
mock_get_conn.return_value = self.client