This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 94f2ce9342 Add partition related methods to GlueCatalogHook: (#23857)
94f2ce9342 is described below
commit 94f2ce9342d995f1d2eb00e6a9444e57c90e4963
Author: Guilherme Martins Crocetti <[email protected]>
AuthorDate: Mon May 30 16:26:40 2022 -0300
Add partition related methods to GlueCatalogHook: (#23857)
* "get_partition" to retrieve a Partition
* "create_partition" to create a Partition
---
airflow/providers/amazon/aws/hooks/glue_catalog.py | 58 +++++++++++++++++++++-
.../amazon/aws/hooks/test_glue_catalog.py | 56 +++++++++++++++++++++
2 files changed, 113 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py
index fc9c353e08..e77916d09e 100644
--- a/airflow/providers/amazon/aws/hooks/glue_catalog.py
+++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -18,8 +18,11 @@
"""This module contains AWS Glue Catalog Hook"""
import warnings
-from typing import Optional, Set
+from typing import Dict, List, Optional, Set
+from botocore.exceptions import ClientError
+
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -123,6 +126,59 @@ class GlueCatalogHook(AwsBaseHook):
return table['StorageDescriptor']['Location']
+    def get_partition(self, database_name: str, table_name: str, partition_values: List[str]) -> Dict:
+        """
+        Gets a single Partition of a table.
+
+        :param database_name: Database name
+        :param table_name: Database's Table name
+        :param partition_values: List of utf-8 strings that define the partition.
+            Please see official AWS documentation for further information.
+            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartition
+        :return: The ``Partition`` element of the ``GetPartition`` API response.
+        :rtype: dict
+        :raises AirflowException: If the request raises a botocore ``ClientError``; the
+            client error is logged and chained as the exception's ``__cause__``.
+
+        >>> hook = GlueCatalogHook()
+        >>> partition = hook.get_partition('db', 'table', ['string'])
+        >>> partition['Values']
+        """
+        try:
+            response = self.get_conn().get_partition(
+                DatabaseName=database_name, TableName=table_name, PartitionValues=partition_values
+            )
+            return response["Partition"]
+        except ClientError as e:
+            self.log.error("Client error: %s", e)
+            raise AirflowException("AWS request failed, check logs for more info") from e
+
+    def create_partition(self, database_name: str, table_name: str, partition_input: Dict) -> Dict:
+        """
+        Creates a new Partition in a table.
+
+        :param database_name: Database name
+        :param table_name: Database's Table name
+        :param partition_input: Definition of how the partition is created.
+            Please see official AWS documentation for further information.
+            https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-CreatePartition
+        :return: The ``CreatePartition`` API response, returned unmodified.
+        :rtype: dict
+        :raises AirflowException: If the request raises a botocore ``ClientError``; the
+            client error is logged and chained as the exception's ``__cause__``.
+
+        >>> hook = GlueCatalogHook()
+        >>> partition_input = {"Values": []}
+        >>> response = hook.create_partition(database_name="db", table_name="table", partition_input=partition_input)
+        """
+        try:
+            return self.get_conn().create_partition(
+                DatabaseName=database_name, TableName=table_name, PartitionInput=partition_input
+            )
+        except ClientError as e:
+            self.log.error("Client error: %s", e)
+            raise AirflowException("AWS request failed, check logs for more info") from e
+
class AwsGlueCatalogHook(GlueCatalogHook):
"""
diff --git a/tests/providers/amazon/aws/hooks/test_glue_catalog.py b/tests/providers/amazon/aws/hooks/test_glue_catalog.py
index 29730a12e6..adbe3da293 100644
--- a/tests/providers/amazon/aws/hooks/test_glue_catalog.py
+++ b/tests/providers/amazon/aws/hooks/test_glue_catalog.py
@@ -21,7 +21,9 @@ from unittest import mock
import boto3
import pytest
+from botocore.exceptions import ClientError
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
try:
@@ -38,6 +40,9 @@ TABLE_INPUT = {
"Location": f"s3://mybucket/{DB_NAME}/{TABLE_NAME}",
},
}
+PARTITION_INPUT: dict = {  # minimal Glue PartitionInput shared by the partition tests below
+ "Values": [],
+}
@unittest.skipIf(mock_glue is None, "Skipping test because moto.mock_glue is
not available")
@@ -134,3 +139,54 @@ class TestGlueCatalogHook(unittest.TestCase):
result = self.hook.get_table_location(DB_NAME, TABLE_NAME)
assert result == TABLE_INPUT['StorageDescriptor']['Location']
+
+    @mock_glue
+    def test_get_partition(self):
+        expected_values = PARTITION_INPUT['Values']
+        self.client.create_database(DatabaseInput={'Name': DB_NAME})
+        self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
+        self.client.create_partition(DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT)
+
+        partition = self.hook.get_partition(DB_NAME, TABLE_NAME, expected_values)
+
+        # The hook hands back the Partition element itself, so its fields are top-level keys.
+        assert partition["DatabaseName"] == DB_NAME
+        assert partition["TableName"] == TABLE_INPUT["Name"]
+        assert partition["Values"] == expected_values
+
+    @mock_glue
+    @mock.patch.object(GlueCatalogHook, 'get_conn')
+    def test_get_partition_with_client_error(self, mocked_connection):
+        mocked_client = mock.Mock()
+        mocked_client.get_partition.side_effect = ClientError({}, "get_partition")
+        mocked_connection.return_value = mocked_client
+
+        # Pin the hook's wrapper message so an unrelated AirflowException cannot satisfy the test.
+        with pytest.raises(AirflowException, match="AWS request failed"):
+            self.hook.get_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT['Values'])
+
+        mocked_client.get_partition.assert_called_once_with(
+            DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionValues=PARTITION_INPUT['Values']
+        )
+
+    @mock_glue
+    def test_create_partition(self):
+        self.client.create_database(DatabaseInput={'Name': DB_NAME})
+        self.client.create_table(DatabaseName=DB_NAME, TableInput=TABLE_INPUT)
+
+        result = self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)
+
+        assert result
+        # A truthy response alone does not prove creation; read the partition back from the catalog.
+        stored = self.client.get_partition(
+            DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionValues=PARTITION_INPUT['Values']
+        )
+        assert stored["Partition"]["Values"] == PARTITION_INPUT['Values']
+
+    @mock_glue
+    @mock.patch.object(GlueCatalogHook, 'get_conn')
+    def test_create_partition_with_client_error(self, mocked_connection):
+        mocked_client = mock.Mock()
+        mocked_client.create_partition.side_effect = ClientError({}, "create_partition")
+        mocked_connection.return_value = mocked_client
+
+        # Pin the hook's wrapper message so an unrelated AirflowException cannot satisfy the test.
+        with pytest.raises(AirflowException, match="AWS request failed"):
+            self.hook.create_partition(DB_NAME, TABLE_NAME, PARTITION_INPUT)
+
+        mocked_client.create_partition.assert_called_once_with(
+            DatabaseName=DB_NAME, TableName=TABLE_NAME, PartitionInput=PARTITION_INPUT
+        )