This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3a3adfb8e6 Fix typo in DataSyncHook boto3 methods for create location in NFS and EFS (#28948)
3a3adfb8e6 is described below

commit 3a3adfb8e618a7cea376cb5d187fa3e486a9c9ad
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Jan 16 02:33:34 2023 +0400

    Fix typo in DataSyncHook boto3 methods for create location in NFS and EFS (#28948)
---
 airflow/providers/amazon/aws/hooks/datasync.py    | 17 ++++++------
 tests/providers/amazon/aws/hooks/test_datasync.py | 34 ++++++++++++++++++++++-
 2 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/datasync.py b/airflow/providers/amazon/aws/hooks/datasync.py
index 3a1f6e2c6c..3e7c4e420a 100644
--- a/airflow/providers/amazon/aws/hooks/datasync.py
+++ b/airflow/providers/amazon/aws/hooks/datasync.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import time
+from urllib.parse import urlsplit
 
 from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -74,17 +75,17 @@ class DataSyncHook(AwsBaseHook):
         :return: LocationArn of the created Location.
         :raises AirflowException: If location type (prefix from ``location_uri``) is invalid.
         """
-        typ = location_uri.split(":")[0]
-        if typ == "smb":
+        schema = urlsplit(location_uri).scheme
+        if schema == "smb":
             location = self.get_conn().create_location_smb(**create_location_kwargs)
-        elif typ == "s3":
+        elif schema == "s3":
             location = self.get_conn().create_location_s3(**create_location_kwargs)
-        elif typ == "nfs":
-            location = self.get_conn().create_loction_nfs(**create_location_kwargs)
-        elif typ == "efs":
-            location = self.get_conn().create_loction_efs(**create_location_kwargs)
+        elif schema == "nfs":
+            location = self.get_conn().create_location_nfs(**create_location_kwargs)
+        elif schema == "efs":
+            location = self.get_conn().create_location_efs(**create_location_kwargs)
         else:
-            raise AirflowException(f"Invalid location type: {typ}")
+            raise AirflowException(f"Invalid/Unsupported location type: {schema}")
         self._refresh_locations()
         return location["LocationArn"]
 
diff --git a/tests/providers/amazon/aws/hooks/test_datasync.py b/tests/providers/amazon/aws/hooks/test_datasync.py
index eeb976e4e0..f7631a73b2 100644
--- a/tests/providers/amazon/aws/hooks/test_datasync.py
+++ b/tests/providers/amazon/aws/hooks/test_datasync.py
@@ -23,7 +23,7 @@ import boto3
 import pytest
 from moto import mock_datasync
 
-from airflow.exceptions import AirflowTaskTimeout
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
 
 
@@ -97,6 +97,38 @@ class TestDataSyncHookMocked:
         assert not self.hook.tasks
         assert self.hook.wait_interval_seconds == 0
 
+    @pytest.mark.parametrize(
+        "location_uri, expected_method",
+        [
+            pytest.param("smb://spam/egg/", "create_location_smb", id="smb"),
+            pytest.param("s3://foo/bar", "create_location_s3", id="s3"),
+            pytest.param("nfs://server:2049/path", "create_location_nfs", id="nfs"),
+            pytest.param("efs://12345.efs.aws-region.amazonaws.com/path", "create_location_efs", id="efs"),
+        ],
+    )
+    def test_create_location_method_mapping(self, mock_get_conn, location_uri, expected_method):
+        """Test expected location URI and mapping with DataSync.Client methods."""
+        mock_get_conn.return_value = self.client
+        assert hasattr(self.client, expected_method), f"{self.client} doesn't have method {expected_method}"
+        with mock.patch.object(self.client, expected_method) as m:
+            self.hook.create_location(location_uri, foo="bar", spam="egg")
+            m.assert_called_once_with(foo="bar", spam="egg")
+
+    @pytest.mark.parametrize(
+        "location_uri",
+        [
+            pytest.param("hdfs://namenodehost1/path", id="hdfs"),
+            pytest.param("https://example.org/path", id="https"),
+            pytest.param("http://example.org/path", id="http"),
+            pytest.param("lustre://mount/path", id="lustre"),
+        ],
+    )
+    def test_create_location_unknown_type(self, mock_get_conn, location_uri):
+        """Test unsupported location URI."""
+        mock_get_conn.return_value = mock.MagicMock()
+        with pytest.raises(AirflowException, match="Invalid/Unsupported location type: .*"):
+            self.hook.create_location(location_uri, foo="bar", spam="egg")
+
     def test_create_location_smb(self, mock_get_conn):
         # ### Configure mock:
         mock_get_conn.return_value = self.client

Reply via email to