hankehly commented on code in PR #27410: URL: https://github.com/apache/airflow/pull/27410#discussion_r1009678477
########## tests/providers/amazon/aws/hooks/test_rds.py: ########## @@ -17,14 +17,301 @@ # under the License. from __future__ import annotations +from unittest.mock import patch + +import pytest +from moto import mock_rds + +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.rds import RdsHook +@pytest.fixture +def rds_hook() -> RdsHook: + """Returns an RdsHook whose underlying connection is mocked with moto""" + with mock_rds(): + yield RdsHook(aws_conn_id="aws_default", region_name="us-east-1") + + +@pytest.fixture +def db_instance_id(rds_hook: RdsHook) -> str: + """Creates an RDS DB instance and returns its id""" + response = rds_hook.conn.create_db_instance( + DBInstanceIdentifier="testrdshook-db-instance", + DBInstanceClass="db.t4g.micro", + Engine="postgres", + AllocatedStorage=20, + MasterUsername="testrdshook", + MasterUserPassword="testrdshook", + ) + return response["DBInstance"]["DBInstanceIdentifier"] + + +@pytest.fixture +def db_cluster_id(rds_hook: RdsHook) -> str: + """Creates an RDS DB cluster and returns its id""" + response = rds_hook.conn.create_db_cluster( + DBClusterIdentifier="testrdshook-db-cluster", + Engine="postgres", + MasterUsername="testrdshook", + MasterUserPassword="testrdshook", + DBClusterInstanceClass="db.t4g.micro", + AllocatedStorage=20, + ) + return response["DBCluster"]["DBClusterIdentifier"] + + +@pytest.fixture +def db_snapshot(rds_hook: RdsHook, db_instance_id: str) -> dict: + """ + Creates a mock DB instance snapshot and returns the DBSnapshot dict from the boto response object.
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_snapshot + """ + response = rds_hook.conn.create_db_snapshot( + DBSnapshotIdentifier="testrdshook-db-instance-snapshot", DBInstanceIdentifier=db_instance_id + ) + return response["DBSnapshot"] + + +@pytest.fixture +def db_snapshot_id(db_snapshot: dict) -> str: + return db_snapshot["DBSnapshotIdentifier"] + + +@pytest.fixture +def db_snapshot_arn(db_snapshot: dict) -> str: + return db_snapshot["DBSnapshotArn"] + + +@pytest.fixture +def db_cluster_snapshot(rds_hook: RdsHook, db_cluster_id: str): + """ + Creates a mock DB cluster snapshot and returns the DBClusterSnapshot dict from the boto response object. + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_cluster_snapshot + """ + response = rds_hook.conn.create_db_cluster_snapshot( + DBClusterSnapshotIdentifier="testrdshook-db-cluster-snapshot", DBClusterIdentifier=db_cluster_id + ) + return response["DBClusterSnapshot"] + + +@pytest.fixture +def db_cluster_snapshot_id(db_cluster_snapshot) -> str: + return db_cluster_snapshot["DBClusterSnapshotIdentifier"] + + +@pytest.fixture +def export_task_id(rds_hook: RdsHook, db_snapshot_arn: str) -> str: + response = rds_hook.conn.start_export_task( + ExportTaskIdentifier="testrdshook-export-task", + SourceArn=db_snapshot_arn, + S3BucketName="test", + IamRoleArn="test", + KmsKeyId="test", + ) + return response["ExportTaskIdentifier"] + + +@pytest.fixture +def event_subscription_name(rds_hook: RdsHook, db_instance_id: str) -> str: + """Creates an mock RDS event subscription and returns its name""" + response = rds_hook.conn.create_event_subscription( + SubscriptionName="testrdshook-event-subscription", + SnsTopicArn="test", + SourceType="db-instance", + SourceIds=[db_instance_id], + Enabled=True, + ) + return response["EventSubscription"]["CustSubscriptionId"] + + class TestRdsHook: + # For testing, set the
delay between status checks to 0 so that we aren't sleeping during tests, + # and max_attempts to 1 so that we don't retry unless required. + waiter_args = {"check_interval": 0, "max_attempts": 1} + def test_conn_attribute(self): hook = RdsHook(aws_conn_id="aws_default", region_name="us-east-1") assert hasattr(hook, "conn") assert hook.conn.__class__.__name__ == "RDS" conn = hook.conn assert conn is hook.conn # Cached property assert conn is hook.get_conn() # Same object as returned by `conn` property + + def test_get_db_instance_state(self, rds_hook: RdsHook, db_instance_id: str): Review Comment: Added test cases for each method in `RdsHook` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
