This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 94f2ce9342 Add partition related methods to GlueCatalogHook: (#23857)
94f2ce9342 is described below

commit 94f2ce9342d995f1d2eb00e6a9444e57c90e4963
Author: Guilherme Martins Crocetti <[email protected]>
AuthorDate: Mon May 30 16:26:40 2022 -0300

    Add partition related methods to GlueCatalogHook: (#23857)
    
    * "get_partition" to retrieve a Partition
    * "create_partition" to create a Partition
---
 airflow/providers/amazon/aws/hooks/glue_catalog.py | 58 +++++++++++++++++++++-
 .../amazon/aws/hooks/test_glue_catalog.py          | 56 +++++++++++++++++++++
 2 files changed, 113 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py
index fc9c353e08..e77916d09e 100644
--- a/airflow/providers/amazon/aws/hooks/glue_catalog.py
+++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -18,8 +18,11 @@
 
 """This module contains AWS Glue Catalog Hook"""
 import warnings
-from typing import Optional, Set
+from typing import Dict, List, Optional, Set
 
+from botocore.exceptions import ClientError
+
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 
@@ -123,6 +126,59 @@ class GlueCatalogHook(AwsBaseHook):
 
         return table['StorageDescriptor']['Location']
 
+    def get_partition(self, database_name: str, table_name: str, partition_values: List[str]) -> Dict:
+        """
+        Gets a Partition
+
+        :param database_name: Database name
+        :param table_name: Database's Table name
+        :param partition_values: List of utf-8 strings that define the partition
+            Please see official AWS documentation for further information.
+            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartition
+
+        :rtype: dict
+
+        :raises: AirflowException
+
+        >>> hook = GlueCatalogHook()
+        >>> partition = hook.get_partition('db', 'table', ['string'])
+        >>> partition['Values']
+        """
+        try:
+            response = self.get_conn().get_partition(
+                DatabaseName=database_name, TableName=table_name, PartitionValues=partition_values
+            )
+            return response["Partition"]
+        except ClientError as e:
+            self.log.error("Client error: %s", e)
+            raise AirflowException("AWS request failed, check logs for more info")
+
+    def create_partition(self, database_name: str, table_name: str, partition_input: Dict) -> Dict:
+        """
+        Creates a new Partition
+
+        :param database_name: Database name
+        :param table_name: Database's Table name
+        :param partition_input: Definition of how the partition is created
+            Please see official AWS documentation for further information.
+            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-CreatePartition
+
+        :rtype: dict
+
+        :raises: AirflowException
+
+        >>> hook = GlueCatalogHook()
+        >>> partition_input = {"Values": []}
+        >>> hook.create_partition(database_name="db", table_name="table", partition_input=partition_input)
+        """
+        try:
+            return self.get_conn().create_partition(
+                DatabaseName=database_name, TableName=table_name, PartitionInput=partition_input
+            )
+        except ClientError as e:
+            self.log.error("Client error: %s", e)
+            raise AirflowException("AWS request failed, check logs for more info")
+
 
 class AwsGlueCatalogHook(GlueCatalogHook):
     """
diff --git a/tests/providers/amazon/aws/hooks/test_glue_catalog.py b/tests/providers/amazon/aws/hooks/test_glue_catalog.py
index 29730a12e6..adbe3da293 100644
--- a/tests/providers/amazon/aws/hooks/test_glue_catalog.py
+++ b/tests/providers/amazon/aws/hooks/test_glue_catalog.py
@@ -21,7 +21,9 @@ from unittest import mock
 
 import boto3
 import pytest
+from botocore.exceptions import ClientError
 
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
 
 try:
@@ -38,6 +40,9 @@ TABLE_INPUT = {
         "Location": f"s3://mybucket/{DB_NAME}/{TABLE_NAME}",
     },
 }
+PARTITION_INPUT: dict = {
+    "Values": [],
+}
 
 
 @unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is not available")
@@ -134,3 +139,54 @@ class TestGlueCatalogHook(unittest.TestCase):
 
         result = self.hook.get_table_location(DB_NAME, TABLE_NAME)
         assert result == TABLE_INPUT['StorageDescriptor']['Location']
+
+    @mock_glue
+    def test_get_partition(self):
+        self.client.create_database(DatabaseInput={'Name': DB_NAME})
+        self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
+        self.client.create_partition(
+            DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT
+        )
+
+        result = self.hook.get_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT['Values'])
+
+        assert result["Values"] == PARTITION_INPUT['Values']
+        assert result["DatabaseName"] == DB_NAME
+        assert result["TableName"] == TABLE_INPUT["Name"]
+
+    @mock_glue
+    @mock.patch.object(GlueCatalogHook, 'get_conn')
+    def test_get_partition_with_client_error(self, mocked_connection):
+        mocked_client = mock.Mock()
+        mocked_client.get_partition.side_effect = ClientError({}, "get_partition")
+        mocked_connection.return_value = mocked_client
+
+        with pytest.raises(AirflowException):
+            self.hook.get_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT['Values'])
+
+        mocked_client.get_partition.assert_called_once_with(
+            DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionValues=PARTITION_INPUT['Values']
+        )
+
+    @mock_glue
+    def test_create_partition(self):
+        self.client.create_database(DatabaseInput={'Name': DB_NAME})
+        self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
+
+        result = self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)
+
+        assert result
+
+    @mock_glue
+    @mock.patch.object(GlueCatalogHook, 'get_conn')
+    def test_create_partition_with_client_error(self, mocked_connection):
+        mocked_client = mock.Mock()
+        mocked_client.create_partition.side_effect = ClientError({}, "create_partition")
+        mocked_connection.return_value = mocked_client
+
+        with pytest.raises(AirflowException):
+            self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)
+
+        mocked_client.create_partition.assert_called_once_with(
+            DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT
+        )

Reply via email to