This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0854500769 Amazon provider docstring improvements (#31729)
0854500769 is described below
commit 0854500769a07f8251269caeab65c95a05d9c28a
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Jun 8 14:45:29 2023 +0800
Amazon provider docstring improvements (#31729)
---
airflow/providers/amazon/aws/hooks/athena.py | 71 +++---
airflow/providers/amazon/aws/hooks/base_aws.py | 66 +++---
airflow/providers/amazon/aws/hooks/glacier.py | 17 +-
.../providers/amazon/aws/hooks/redshift_cluster.py | 45 ++--
airflow/providers/amazon/aws/hooks/redshift_sql.py | 20 +-
airflow/providers/amazon/aws/hooks/sagemaker.py | 241 ++++++++++-----------
.../providers/amazon/aws/hooks/secrets_manager.py | 13 +-
airflow/providers/amazon/aws/operators/batch.py | 75 +++----
airflow/providers/amazon/aws/operators/glue.py | 13 +-
.../providers/amazon/aws/transfers/mongo_to_s3.py | 27 +--
.../providers/amazon/aws/transfers/s3_to_sql.py | 18 +-
.../amazon/aws/utils/connection_wrapper.py | 28 +--
airflow/providers/amazon/aws/utils/redshift.py | 6 +-
13 files changed, 299 insertions(+), 341 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/athena.py
b/airflow/providers/amazon/aws/hooks/athena.py
index 4d1511510a..f68eee9355 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -33,13 +33,15 @@ from airflow.providers.amazon.aws.hooks.base_aws import
AwsBaseHook
class AthenaHook(AwsBaseHook):
- """
- Interact with Amazon Athena.
- Provide thick wrapper around
:external+boto3:py:class:`boto3.client("athena") <Athena.Client>`.
+ """Interact with Amazon Athena.
+
+ Provide thick wrapper around
+ :external+boto3:py:class:`boto3.client("athena") <Athena.Client>`.
- :param sleep_time: Time (in seconds) to wait between two consecutive calls
to check query status on Athena
- :param log_query: Whether to log athena query and other execution params
when it's executed.
- Defaults to *True*.
+ :param sleep_time: Time (in seconds) to wait between two consecutive calls
+ to check query status on Athena.
+ :param log_query: Whether to log athena query and other execution params
+ when it's executed. Defaults to *True*.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
@@ -76,17 +78,20 @@ class AthenaHook(AwsBaseHook):
client_request_token: str | None = None,
workgroup: str = "primary",
) -> str:
- """
- Run Presto query on athena with provided config and return submitted
query_execution_id.
+ """Run a Presto query on Athena with provided config.
.. seealso::
- :external+boto3:py:meth:`Athena.Client.start_query_execution`
- :param query: Presto query to run
- :param query_context: Context in which query need to be run
- :param result_configuration: Dict with path to store results in and
config related to encryption
- :param client_request_token: Unique token created by user to avoid
multiple executions of same query
- :param workgroup: Athena workgroup name, when not specified, will be
'primary'
+ :param query: Presto query to run.
+ :param query_context: Context in which query need to be run.
+ :param result_configuration: Dict with path to store results in and
+ config related to encryption.
+ :param client_request_token: Unique token created by user to avoid
+ multiple executions of same query.
+ :param workgroup: Athena workgroup name, when not specified, will be
+ ``'primary'``.
+ :return: Submitted query execution ID.
"""
params = {
"QueryString": query,
@@ -104,13 +109,14 @@ class AthenaHook(AwsBaseHook):
return query_execution_id
def check_query_status(self, query_execution_id: str) -> str | None:
- """
- Fetch the status of submitted athena query. Returns None or one of
valid query states.
+ """Fetch the state of a submitted query.
.. seealso::
- :external+boto3:py:meth:`Athena.Client.get_query_execution`
:param query_execution_id: Id of submitted athena query
+ :return: One of valid query states, or *None* if the response is
+ malformed.
"""
response =
self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
state = None
@@ -151,10 +157,7 @@ class AthenaHook(AwsBaseHook):
def get_query_results(
self, query_execution_id: str, next_token_id: str | None = None,
max_results: int = 1000
) -> dict | None:
- """
- Fetch submitted athena query results.
-
- Returns none if query is in intermediate state or failed/cancelled
state else dict of query output.
+ """Fetch submitted query results.
.. seealso::
- :external+boto3:py:meth:`Athena.Client.get_query_results`
@@ -162,6 +165,8 @@ class AthenaHook(AwsBaseHook):
:param query_execution_id: Id of submitted athena query
:param next_token_id: The token that specifies where to start
pagination.
:param max_results: The maximum number of results (rows) to return in
this request.
+ :return: *None* if the query is in intermediate, failed, or cancelled
+ state. Otherwise a dict of query outputs.
"""
query_state = self.check_query_status(query_execution_id)
if query_state is None:
@@ -186,10 +191,7 @@ class AthenaHook(AwsBaseHook):
page_size: int | None = None,
starting_token: str | None = None,
) -> PageIterator | None:
- """
- Fetch submitted athena query results. returns none if query is in
intermediate state or
- failed/cancelled state else a paginator to iterate through pages of
results. If you
- wish to get all results at once, call build_full_result() on the
returned PageIterator.
+ """Fetch submitted Athena query results.
.. seealso::
- :external+boto3:py:class:`Athena.Paginator.GetQueryResults`
@@ -198,6 +200,11 @@ class AthenaHook(AwsBaseHook):
:param max_items: The total number of items to return.
:param page_size: The size of each page.
:param starting_token: A token to specify where to start paginating.
+ :return: *None* if the query is in intermediate, failed, or cancelled
+ state. Otherwise a paginator to iterate through pages of results.
+
+ Call :meth:`.build_full_result()` on the returned paginator to get all
+ results at once.
"""
query_state = self.check_query_status(query_execution_id)
if query_state is None:
@@ -226,13 +233,12 @@ class AthenaHook(AwsBaseHook):
query_execution_id: str,
max_polling_attempts: int | None = None,
) -> str | None:
- """
- Poll the status of submitted athena query until query state reaches
final state.
-
- Returns one of the final states.
+ """Poll the state of a submitted query until it reaches final state.
- :param query_execution_id: Id of submitted athena query
- :param max_polling_attempts: Number of times to poll for query state
before function exits
+ :param query_execution_id: ID of submitted athena query
+ :param max_polling_attempts: Number of times to poll for query state
+ before function exits
+ :return: One of the final states
"""
try_number = 1
final_query_state = None # Query state when query reaches final state
or max_polling_attempts reached
@@ -270,9 +276,7 @@ class AthenaHook(AwsBaseHook):
return final_query_state
def get_output_location(self, query_execution_id: str) -> str:
- """
- Function to get the output location of the query results
- in s3 uri format.
+ """Get the output location of the query results in S3 URI format.
.. seealso::
- :external+boto3:py:meth:`Athena.Client.get_query_execution`
@@ -299,8 +303,7 @@ class AthenaHook(AwsBaseHook):
return output_location
def stop_query(self, query_execution_id: str) -> dict:
- """
- Cancel the submitted athena query.
+ """Cancel the submitted query.
.. seealso::
- :external+boto3:py:meth:`Athena.Client.stop_query_execution`
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py
b/airflow/providers/amazon/aws/hooks/base_aws.py
index 33699e9d30..aaec87dd09 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -70,15 +70,17 @@ if TYPE_CHECKING:
class BaseSessionFactory(LoggingMixin):
- """
- Base AWS Session Factory class to handle synchronous and async boto
session creation.
- It can handle most of the AWS supported authentication methods.
+ """Base AWS Session Factory class.
+
+ This handles synchronous and async boto session creation. It can handle
most
+ of the AWS supported authentication methods.
User can also derive from this class to have full control of boto3 session
creation or to support custom federation.
- Note: Not all features implemented for synchronous sessions are available
for async
- sessions.
+ .. note::
+ Not all features implemented for synchronous sessions are available
+ for async sessions.
.. seealso::
- :ref:`howto/connection:aws:session-factory`
@@ -409,9 +411,9 @@ class BaseSessionFactory(LoggingMixin):
class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
- """
- Generic class for interact with AWS.
- This class provide a thin wrapper around the boto3 python library.
+ """Generic class for interact with AWS.
+
+ This class provide a thin wrapper around the boto3 Python library.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
@@ -473,10 +475,10 @@ class AwsGenericHook(BaseHook,
Generic[BaseAwsConnection]):
@staticmethod
def _find_class_name(target_function_name: str) -> str:
- """
- Given a frame off the stack, return the name of the class which made
the call.
- Note: This method may raise a ValueError or an IndexError, but the
calling
- method is catching and handling those.
+ """Given a frame off the stack, return the name of the class that made
the call.
+
+ This method may raise a ValueError or an IndexError. The caller is
+ responsible for catching and handling those.
"""
stack = inspect.stack()
# Find the index of the most recent frame which called the provided
function name.
@@ -504,7 +506,8 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
@staticmethod
def _generate_dag_key() -> str:
- """
+ """Generate a DAG key.
+
The Object Identifier (OID) namespace is used to salt the dag_id value.
That salted value is used to generate a SHA-1 hash which, by
definition,
can not (reasonably) be reversed. No personal data can be inferred or
@@ -711,9 +714,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
return creds
def expand_role(self, role: str, region_name: str | None = None) -> str:
- """
- If the IAM role is a role name, get the Amazon Resource Name (ARN) for
the role.
- If IAM role is already an IAM role ARN, no change is made.
+ """Get the Amazon Resource Name (ARN) for the role.
+
+ If IAM role is already an IAM role ARN, the value is returned
unchanged.
:param role: IAM role name or ARN
:param region_name: Optional region name to get credentials for
@@ -730,10 +733,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
@staticmethod
def retry(should_retry: Callable[[Exception], bool]):
- """
- A decorator that provides a mechanism to repeat requests in response
to exceeding a temporary quote
- limit.
- """
+ """Repeat requests in response to exceeding a temporary quote limit."""
def retry_decorator(fun: Callable):
@wraps(fun)
@@ -789,8 +789,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
}
def test_connection(self):
- """
- Tests the AWS connection by call AWS STS (Security Token Service)
GetCallerIdentity API.
+ """Test the AWS connection by call AWS STS (Security Token Service)
GetCallerIdentity API.
.. seealso::
https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
@@ -824,7 +823,8 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
deferrable: bool = False,
client=None,
) -> Waiter:
- """
+ """Get a waiter by name.
+
First checks if there is a custom waiter with the provided waiter_name
and
uses that if it exists, otherwise it will check the service client for
a
waiter that matches the name and pass that through.
@@ -835,9 +835,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
:param waiter_name: The name of the waiter. The name should exactly
match the
name of the key in the waiter model file (typically this is
CamelCase).
- :param parameters: will scan the waiter config for the keys of that
dict, and replace them with the
- corresponding value. If a custom waiter has such keys to be
expanded, they need to be provided
- here.
+ :param parameters: will scan the waiter config for the keys of that
dict,
+ and replace them with the corresponding value. If a custom waiter
has
+ such keys to be expanded, they need to be provided here.
:param deferrable: If True, the waiter is going to be an async custom
waiter.
An async client must be provided in that case.
:param client: The client to use for the waiter's operations
@@ -904,9 +904,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
- """
- Base class for interact with AWS.
- This class provide a thin wrapper around the boto3 python library.
+ """Base class for interact with AWS.
+
+ This class provide a thin wrapper around the boto3 Python library.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
@@ -979,9 +979,7 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
super().__init__(*args, **kwargs)
async def get_role_credentials(self) -> dict:
- """Get the role_arn, method credentials from connection details and
get the role credentials
- detail.
- """
+ """Get the role_arn, method credentials from connection and get the
role credentials."""
async with self._basic_session.create_client("sts",
region_name=self.region_name) as client:
response = await client.assume_role(
RoleArn=self.role_arn,
@@ -1009,7 +1007,6 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
return credentials
def _get_session_with_assume_role(self) -> AioSession:
-
assume_role_method = self.conn.assume_role_method
if assume_role_method != "assume_role":
raise
NotImplementedError(f"assume_role_method={assume_role_method} not expected")
@@ -1058,8 +1055,7 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
class AwsBaseAsyncHook(AwsBaseHook):
- """
- Interacts with AWS using aiobotocore asynchronously.
+ """Interacts with AWS using aiobotocore asynchronously.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default botocore behaviour is used.
If
diff --git a/airflow/providers/amazon/aws/hooks/glacier.py
b/airflow/providers/amazon/aws/hooks/glacier.py
index cc886a64b7..bd260000e7 100644
--- a/airflow/providers/amazon/aws/hooks/glacier.py
+++ b/airflow/providers/amazon/aws/hooks/glacier.py
@@ -23,9 +23,10 @@ from airflow.providers.amazon.aws.hooks.base_aws import
AwsBaseHook
class GlacierHook(AwsBaseHook):
- """
- Interact with Amazon Glacier.
- Provide thin wrapper around
:external+boto3:py:class:`boto3.client("glacier") <Glacier.Client>`.
+ """Interact with Amazon Glacier.
+
+ This is a thin wrapper around
+ :external+boto3:py:class:`boto3.client("glacier") <Glacier.Client>`.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
@@ -39,8 +40,7 @@ class GlacierHook(AwsBaseHook):
self.aws_conn_id = aws_conn_id
def retrieve_inventory(self, vault_name: str) -> dict[str, Any]:
- """
- Initiate an Amazon Glacier inventory-retrieval job.
+ """Initiate an Amazon Glacier inventory-retrieval job.
.. seealso::
- :external+boto3:py:meth:`Glacier.Client.initiate_job`
@@ -55,8 +55,7 @@ class GlacierHook(AwsBaseHook):
return response
def retrieve_inventory_results(self, vault_name: str, job_id: str) ->
dict[str, Any]:
- """
- Retrieve the results of an Amazon Glacier inventory-retrieval job.
+ """Retrieve the results of an Amazon Glacier inventory-retrieval job.
.. seealso::
- :external+boto3:py:meth:`Glacier.Client.get_job_output`
@@ -69,9 +68,7 @@ class GlacierHook(AwsBaseHook):
return response
def describe_job(self, vault_name: str, job_id: str) -> dict[str, Any]:
- """
- Retrieve the status of an Amazon S3 Glacier job, such as an
- inventory-retrieval job.
+ """Retrieve the status of an Amazon S3 Glacier job.
.. seealso::
- :external+boto3:py:meth:`Glacier.Client.describe_job`
diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
index 872c692732..7c44c9ec72 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -28,9 +28,10 @@ from airflow.providers.amazon.aws.hooks.base_aws import
AwsBaseAsyncHook, AwsBas
class RedshiftHook(AwsBaseHook):
- """
- Interact with Amazon Redshift.
- Provide thin wrapper around
:external+boto3:py:class:`boto3.client("redshift") <Redshift.Client>`.
+ """Interact with Amazon Redshift.
+
+ This is a thin wrapper around
+ :external+boto3:py:class:`boto3.client("redshift") <Redshift.Client>`.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
@@ -53,8 +54,7 @@ class RedshiftHook(AwsBaseHook):
master_user_password: str,
params: dict[str, Any],
) -> dict[str, Any]:
- """
- Creates a new cluster with the specified parameters.
+ """Create a new cluster with the specified parameters.
.. seealso::
- :external+boto3:py:meth:`Redshift.Client.create_cluster`
@@ -84,8 +84,7 @@ class RedshiftHook(AwsBaseHook):
# TODO: Wrap create_cluster_snapshot
def cluster_status(self, cluster_identifier: str) -> str:
- """
- Return status of a cluster.
+ """Get status of a cluster.
.. seealso::
- :external+boto3:py:meth:`Redshift.Client.describe_clusters`
@@ -106,8 +105,7 @@ class RedshiftHook(AwsBaseHook):
skip_final_cluster_snapshot: bool = True,
final_cluster_snapshot_identifier: str | None = None,
):
- """
- Delete a cluster and optionally create a snapshot.
+ """Delete a cluster and optionally create a snapshot.
.. seealso::
- :external+boto3:py:meth:`Redshift.Client.delete_cluster`
@@ -126,8 +124,7 @@ class RedshiftHook(AwsBaseHook):
return response["Cluster"] if response["Cluster"] else None
def describe_cluster_snapshots(self, cluster_identifier: str) -> list[str]
| None:
- """
- Gets a list of snapshots for a cluster.
+ """List snapshots for a cluster.
.. seealso::
-
:external+boto3:py:meth:`Redshift.Client.describe_cluster_snapshots`
@@ -143,8 +140,7 @@ class RedshiftHook(AwsBaseHook):
return snapshots
def restore_from_cluster_snapshot(self, cluster_identifier: str,
snapshot_identifier: str) -> str:
- """
- Restores a cluster from its snapshot.
+ """Restore a cluster from its snapshot.
.. seealso::
-
:external+boto3:py:meth:`Redshift.Client.restore_from_cluster_snapshot`
@@ -164,8 +160,7 @@ class RedshiftHook(AwsBaseHook):
retention_period: int = -1,
tags: list[Any] | None = None,
) -> str:
- """
- Creates a snapshot of a cluster.
+ """Create a snapshot of a cluster.
.. seealso::
- :external+boto3:py:meth:`Redshift.Client.create_cluster_snapshot`
@@ -187,8 +182,9 @@ class RedshiftHook(AwsBaseHook):
return response["Snapshot"] if response["Snapshot"] else None
def get_cluster_snapshot_status(self, snapshot_identifier: str):
- """
- Return Redshift cluster snapshot status. If cluster snapshot not found
return ``None``.
+ """Get Redshift cluster snapshot status.
+
+ If cluster snapshot not found, *None* is returned.
:param snapshot_identifier: A unique identifier for the snapshot that
you are requesting
"""
@@ -217,9 +213,7 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
super().__init__(*args, **kwargs)
async def cluster_status(self, cluster_identifier: str, delete_operation:
bool = False) -> dict[str, Any]:
- """
- Connects to the AWS redshift cluster via aiobotocore and get the status
- and returns the status of the cluster based on the cluster_identifier
passed.
+ """Get the cluster status.
:param cluster_identifier: unique identifier of a cluster
:param delete_operation: whether the method has been called as part of
delete cluster operation
@@ -237,9 +231,7 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
return {"status": "error", "message": str(error)}
async def pause_cluster(self, cluster_identifier: str, poll_interval:
float = 5.0) -> dict[str, Any]:
- """
- Connects to the AWS redshift cluster via aiobotocore and
- pause the cluster based on the cluster_identifier passed.
+ """Pause the cluster.
:param cluster_identifier: unique identifier of a cluster
:param poll_interval: polling period in seconds to check for the status
@@ -266,9 +258,7 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
cluster_identifier: str,
polling_period_seconds: float = 5.0,
) -> dict[str, Any]:
- """
- Connects to the AWS redshift cluster via aiobotocore and
- resume the cluster for the cluster_identifier passed.
+ """Resume the cluster.
:param cluster_identifier: unique identifier of a cluster
:param polling_period_seconds: polling period in seconds to check for
the status
@@ -297,8 +287,7 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
flag: asyncio.Event,
delete_operation: bool = False,
) -> dict[str, Any]:
- """
- check for expected Redshift cluster state.
+ """Check for expected Redshift cluster state.
:param cluster_identifier: unique identifier of a cluster
:param expected_state: expected_state example("available", "pausing",
"paused"")
diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py
b/airflow/providers/amazon/aws/hooks/redshift_sql.py
index afc2e797c9..ffba09b6fc 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -33,8 +33,7 @@ if TYPE_CHECKING:
class RedshiftSQLHook(DbApiHook):
- """
- Execute statements against Amazon Redshift, using redshift_connector.
+ """Execute statements against Amazon Redshift.
This hook requires the redshift_conn_id connection.
@@ -65,7 +64,7 @@ class RedshiftSQLHook(DbApiHook):
@staticmethod
def get_ui_field_behaviour() -> dict:
- """Returns custom field behavior."""
+ """Custom field behavior."""
return {
"hidden_fields": [],
"relabeling": {"login": "User", "schema": "Database"},
@@ -76,7 +75,7 @@ class RedshiftSQLHook(DbApiHook):
return self.get_connection(self.redshift_conn_id) # type:
ignore[attr-defined]
def _get_conn_params(self) -> dict[str, str | int]:
- """Helper method to retrieve connection args."""
+ """Retrieve connection parameters."""
conn = self.conn
conn_params: dict[str, str | int] = {}
@@ -98,8 +97,8 @@ class RedshiftSQLHook(DbApiHook):
return conn_params
def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
- """
- Uses AWSHook to retrieve a temporary password to connect to Redshift.
+ """Retrieve a temporary password to connect to Redshift.
+
Port is required. If none is provided, default is used for each
service.
"""
port = conn.port or 5439
@@ -124,7 +123,7 @@ class RedshiftSQLHook(DbApiHook):
return login, token, port
def get_uri(self) -> str:
- """Overrides DbApiHook get_uri to use redshift_connector sqlalchemy
dialect as driver name."""
+ """Overridden to use the Redshift dialect as driver name."""
conn_params = self._get_conn_params()
if "user" in conn_params:
@@ -136,7 +135,7 @@ class RedshiftSQLHook(DbApiHook):
return str(create_url(drivername="redshift+redshift_connector",
**conn_params))
def get_sqlalchemy_engine(self, engine_kwargs=None):
- """Overrides DbApiHook get_sqlalchemy_engine to pass
redshift_connector specific kwargs."""
+ """Overridden to pass Redshift-specific arguments."""
conn_kwargs = self.conn.extra_dejson
if engine_kwargs is None:
engine_kwargs = {}
@@ -149,8 +148,7 @@ class RedshiftSQLHook(DbApiHook):
return create_engine(self.get_uri(), **engine_kwargs)
def get_table_primary_key(self, table: str, schema: str | None = "public")
-> list[str] | None:
- """
- Helper method that returns the table primary key.
+ """Get the table's primary key.
:param table: Name of the target table
:param schema: Name of the target schema, public by default
@@ -171,7 +169,7 @@ class RedshiftSQLHook(DbApiHook):
return pk_columns or None
def get_conn(self) -> RedshiftConnection:
- """Returns a redshift_connector.Connection object."""
+ """Get a ``redshift_connector.Connection`` object."""
conn_params = self._get_conn_params()
conn_kwargs_dejson = self.conn.extra_dejson
conn_kwargs: dict = {**conn_params, **conn_kwargs_dejson}
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py
b/airflow/providers/amazon/aws/hooks/sagemaker.py
index b6e7dd487d..72354c9c8b 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -39,8 +39,7 @@ from airflow.utils import timezone
class LogState:
- """
- Enum-style class holding all possible states of CloudWatch log streams.
+ """Enum-style class holding all possible states of CloudWatch log streams.
https://sagemaker.readthedocs.io/en/stable/session.html#sagemaker.session.LogState
"""
@@ -58,7 +57,10 @@ Position = collections.namedtuple("Position", ["timestamp",
"skip"])
def argmin(arr, f: Callable) -> int | None:
- """Return the index, i, in arr that minimizes f(arr[i])."""
+ """Given callable ``f``, find index in ``arr`` to minimize ``f(arr[i])``.
+
+ None is returned if ``arr`` is empty.
+ """
min_value = None
min_idx = None
for idx, item in enumerate(arr):
@@ -70,8 +72,7 @@ def argmin(arr, f: Callable) -> int | None:
def secondary_training_status_changed(current_job_description: dict,
prev_job_description: dict) -> bool:
- """
- Returns true if training job's secondary status message has changed.
+ """Check if training job's secondary status message has changed.
:param current_job_description: Current job description, returned from
DescribeTrainingJob call.
:param prev_job_description: Previous job description, returned from
DescribeTrainingJob call.
@@ -101,8 +102,7 @@ def
secondary_training_status_changed(current_job_description: dict, prev_job_de
def secondary_training_status_message(
job_description: dict[str, list[Any]], prev_description: dict | None
) -> str:
- """
- Returns a string contains start time and the secondary training job status
message.
+ """Format string containing start time and the secondary training job
status message.
:param job_description: Returned response from DescribeTrainingJob call
:param prev_description: Previous job description from DescribeTrainingJob
call
@@ -136,9 +136,10 @@ def secondary_training_status_message(
class SageMakerHook(AwsBaseHook):
- """
- Interact with Amazon SageMaker.
- Provide thick wrapper around
:external+boto3:py:class:`boto3.client("sagemaker") <SageMaker.Client>`.
+ """Interact with Amazon SageMaker.
+
+ Provide thick wrapper around
+ :external+boto3:py:class:`boto3.client("sagemaker") <SageMaker.Client>`.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
@@ -158,13 +159,11 @@ class SageMakerHook(AwsBaseHook):
self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
def tar_and_s3_upload(self, path: str, key: str, bucket: str) -> None:
- """
- Tar the local file or directory and upload to s3.
+ """Tar the local file or directory and upload to s3.
:param path: local file or directory
:param key: s3 key
:param bucket: s3 bucket
- :return: None
"""
with tempfile.TemporaryFile() as temp_file:
if os.path.isdir(path):
@@ -178,8 +177,7 @@ class SageMakerHook(AwsBaseHook):
self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True)
def configure_s3_resources(self, config: dict) -> None:
- """
- Extract the S3 operations from the configuration and execute them.
+ """Extract the S3 operations from the configuration and execute them.
:param config: config of SageMaker operation
"""
@@ -197,8 +195,7 @@ class SageMakerHook(AwsBaseHook):
self.s3_hook.load_file(op["Path"], op["Key"], op["Bucket"])
def check_s3_url(self, s3url: str) -> bool:
- """
- Check if an S3 URL exists.
+ """Check if an S3 URL exists.
:param s3url: S3 url
"""
@@ -219,11 +216,9 @@ class SageMakerHook(AwsBaseHook):
return True
def check_training_config(self, training_config: dict) -> None:
- """
- Check if a training configuration is valid.
+ """Check if a training configuration is valid.
:param training_config: training_config
- :return: None
"""
if "InputDataConfig" in training_config:
for channel in training_config["InputDataConfig"]:
@@ -231,19 +226,18 @@ class SageMakerHook(AwsBaseHook):
self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
def check_tuning_config(self, tuning_config: dict) -> None:
- """
- Check if a tuning configuration is valid.
+ """Check if a tuning configuration is valid.
:param tuning_config: tuning_config
- :return: None
"""
for channel in
tuning_config["TrainingJobDefinition"]["InputDataConfig"]:
if "S3DataSource" in channel["DataSource"]:
self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
def multi_stream_iter(self, log_group: str, streams: list, positions=None)
-> Generator:
- """
- Iterate over the available events coming from a set of log streams in
a single log group
+ """Iterate over the available events.
+
+ The events coming from a set of log streams in a single log group
interleaving the events from each stream so they're yielded in
timestamp order.
:param log_group: The name of the log group.
@@ -284,9 +278,10 @@ class SageMakerHook(AwsBaseHook):
check_interval: int = 30,
max_ingestion_time: int | None = None,
):
- """
- Starts a model training job. After training completes, Amazon
SageMaker saves
- the resulting model artifacts to an Amazon S3 location that you
specify.
+ """Start a model training job.
+
+ After training completes, Amazon SageMaker saves the resulting model
+ artifacts to an Amazon S3 location that you specify.
:param config: the config for training
:param wait_for_completion: if the program should keep running until
job finishes
@@ -332,12 +327,13 @@ class SageMakerHook(AwsBaseHook):
check_interval: int = 30,
max_ingestion_time: int | None = None,
):
- """
- Starts a hyperparameter tuning job. A hyperparameter tuning job finds
the
- best version of a model by running many training jobs on your dataset
using
- the algorithm you choose and values for hyperparameters within ranges
that
- you specify. It then chooses the hyperparameter values that result in
a model
- that performs the best, as measured by an objective metric that you
choose.
+ """Start a hyperparameter tuning job.
+
+ A hyperparameter tuning job finds the best version of a model by
running
+ many training jobs on your dataset using the algorithm you choose and
+ values for hyperparameters within ranges that you specify. It then
+ chooses the hyperparameter values that result in a model that performs
+ the best, as measured by an objective metric that you choose.
:param config: the config for tuning
:param wait_for_completion: if the program should keep running until
job finishes
@@ -368,9 +364,10 @@ class SageMakerHook(AwsBaseHook):
check_interval: int = 30,
max_ingestion_time: int | None = None,
):
- """
- Starts a transform job. A transform job uses a trained model to get
inferences
- on a dataset and saves these results to an Amazon S3 location that you
specify.
+ """Start a transform job.
+
+ A transform job uses a trained model to get inferences on a dataset and
+ saves these results to an Amazon S3 location that you specify.
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.create_transform_job`
@@ -405,11 +402,12 @@ class SageMakerHook(AwsBaseHook):
check_interval: int = 30,
max_ingestion_time: int | None = None,
):
- """
- Use Amazon SageMaker Processing to analyze data and evaluate machine
learning
- models on Amazon SageMaker. With Processing, you can use a simplified,
managed
- experience on SageMaker to run your data processing workloads, such as
feature
- engineering, data validation, model evaluation, and model
interpretation.
+ """Use Amazon SageMaker Processing to analyze data and evaluate models.
+
+ With Processing, you can use a simplified, managed experience on
+ SageMaker to run your data processing workloads, such as feature
+ engineering, data validation, model evaluation, and model
+ interpretation.
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.create_processing_job`
@@ -435,11 +433,13 @@ class SageMakerHook(AwsBaseHook):
return response
def create_model(self, config: dict):
- """
- Creates a model in Amazon SageMaker. In the request, you name the
model and
- describe a primary container. For the primary container, you specify
the Docker
- image that contains inference code, artifacts (from prior training),
and a custom
- environment map that the inference code uses when you deploy the model
for predictions.
+ """Create a model in Amazon SageMaker.
+
+ In the request, you name the model and describe a primary container.
For
+ the primary container, you specify the Docker image that contains
+ inference code, artifacts (from prior training), and a custom
+ environment map that the inference code uses when you deploy the model
+ for predictions.
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.create_model`
@@ -450,11 +450,11 @@ class SageMakerHook(AwsBaseHook):
return self.get_conn().create_model(**config)
def create_endpoint_config(self, config: dict):
- """
- Creates an endpoint configuration that Amazon SageMaker hosting
- services uses to deploy models. In the configuration, you identify
- one or more models, created using the CreateModel API, to deploy and
- the resources that you want Amazon SageMaker to provision.
+ """Create an endpoint configuration to deploy models.
+
+ In the configuration, you identify one or more models, created using
the
+ CreateModel API, to deploy and the resources that you want Amazon
+ SageMaker to provision.
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.create_endpoint_config`
@@ -473,14 +473,13 @@ class SageMakerHook(AwsBaseHook):
check_interval: int = 30,
max_ingestion_time: int | None = None,
):
- """
+ """Create an endpoint from configuration.
+
When you create a serverless endpoint, SageMaker provisions and manages
the compute resources for you. Then, you can make inference requests to
the endpoint and receive model predictions in response. SageMaker
scales
the compute resources up and down as needed to handle your request
traffic.
- Requires an Endpoint Config.
-
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.create_endpoint`
-
:class:`airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.create_endpoint`
@@ -513,10 +512,10 @@ class SageMakerHook(AwsBaseHook):
check_interval: int = 30,
max_ingestion_time: int | None = None,
):
- """
- Deploys the new EndpointConfig specified in the request, switches to
using
- newly created endpoint, and then deletes resources provisioned for the
- endpoint using the previous EndpointConfig (there is no availability
loss).
+ """Deploy the config in the request and switch to using the new
endpoint.
+
+ Resources provisioned for the endpoint using the previous
EndpointConfig
+ are deleted (there is no availability loss).
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.update_endpoint`
@@ -543,8 +542,7 @@ class SageMakerHook(AwsBaseHook):
return response
def describe_training_job(self, name: str):
- """
- Return the training job info associated with the name.
+ """Get the training job info associated with the name.
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.describe_training_job`
@@ -564,7 +562,7 @@ class SageMakerHook(AwsBaseHook):
last_description: dict,
last_describe_job_call: float,
):
- """Return the training job info associated with job_name and print
CloudWatch logs."""
+ """Get the associated training job info and print CloudWatch logs."""
log_group = "/aws/sagemaker/TrainingJobs"
if len(stream_names) < instance_count:
@@ -616,8 +614,7 @@ class SageMakerHook(AwsBaseHook):
return state, last_description, last_describe_job_call
def describe_tuning_job(self, name: str) -> dict:
- """
- Return the tuning job info associated with the name.
+ """Get the tuning job info associated with the name.
.. seealso::
-
:external+boto3:py:meth:`SageMaker.Client.describe_hyper_parameter_tuning_job`
@@ -628,8 +625,7 @@ class SageMakerHook(AwsBaseHook):
return
self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
def describe_model(self, name: str) -> dict:
- """
- Return the SageMaker model info associated with the name.
+ """Get the SageMaker model info associated with the name.
:param name: the name of the SageMaker model
:return: A dict contains all the model info
@@ -637,8 +633,7 @@ class SageMakerHook(AwsBaseHook):
return self.get_conn().describe_model(ModelName=name)
def describe_transform_job(self, name: str) -> dict:
- """
- Return the transform job info associated with the name.
+ """Get the transform job info associated with the name.
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.describe_transform_job`
@@ -649,8 +644,7 @@ class SageMakerHook(AwsBaseHook):
return self.get_conn().describe_transform_job(TransformJobName=name)
def describe_processing_job(self, name: str) -> dict:
- """
- Return the processing job info associated with the name.
+ """Get the processing job info associated with the name.
.. seealso::
-
:external+boto3:py:meth:`SageMaker.Client.describe_processing_job`
@@ -661,8 +655,7 @@ class SageMakerHook(AwsBaseHook):
return self.get_conn().describe_processing_job(ProcessingJobName=name)
def describe_endpoint_config(self, name: str) -> dict:
- """
- Return the endpoint config info associated with the name.
+ """Get the endpoint config info associated with the name.
.. seealso::
-
:external+boto3:py:meth:`SageMaker.Client.describe_endpoint_config`
@@ -673,8 +666,7 @@ class SageMakerHook(AwsBaseHook):
return
self.get_conn().describe_endpoint_config(EndpointConfigName=name)
def describe_endpoint(self, name: str) -> dict:
- """
- Returns the description of an endpoint.
+ """Get the description of an endpoint.
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.describe_endpoint`
@@ -693,18 +685,18 @@ class SageMakerHook(AwsBaseHook):
max_ingestion_time: int | None = None,
non_terminal_states: set | None = None,
) -> dict:
- """
- Check status of a SageMaker resource.
+ """Check status of a SageMaker resource.
- :param job_name: name of the resource to check status, can be a job
but also pipeline for instance.
+ :param job_name: name of the resource to check status, can be a job but
+ also pipeline for instance.
:param key: the key of the response dict that points to the state
:param describe_function: the function used to retrieve the status
:param args: the arguments for the function
:param check_interval: the time interval in seconds which the operator
will check the status of any SageMaker resource
:param max_ingestion_time: the maximum ingestion time in seconds. Any
- SageMaker resources that run longer than this will fail. Setting
this to
- None implies no timeout for any SageMaker resource.
+ SageMaker resources that run longer than this will fail. Setting
+ this to None implies no timeout for any SageMaker resource.
:param non_terminal_states: the set of nonterminal states
:return: response of describe call after resource is done
"""
@@ -747,9 +739,9 @@ class SageMakerHook(AwsBaseHook):
check_interval: int,
max_ingestion_time: int | None = None,
):
- """
- Display the logs for a given training job, optionally tailing them
until the
- job is complete.
+ """Display logs for a given training job.
+
+ Optionally tail the logs until the job is complete.
:param job_name: name of the training job to check status and display
logs for
:param non_terminal_states: the set of non_terminal states
@@ -760,7 +752,6 @@ class SageMakerHook(AwsBaseHook):
:param max_ingestion_time: the maximum ingestion time in seconds. Any
SageMaker jobs that run longer than this will fail. Setting this to
None implies no timeout for any SageMaker job.
- :return: None
"""
sec = 0
description = self.describe_training_job(job_name)
@@ -831,10 +822,11 @@ class SageMakerHook(AwsBaseHook):
def list_training_jobs(
self, name_contains: str | None = None, max_results: int | None =
None, **kwargs
) -> list[dict]:
- """
- This method wraps boto3's `list_training_jobs`. The training job name
and max results are configurable
- via arguments. Other arguments are not, and should be provided via
kwargs. Note boto3 expects these in
- CamelCase format, for example.
+ """Call boto3's ``list_training_jobs``.
+
+ The training job name and max results are configurable via arguments.
+ Other arguments are not, and should be provided via kwargs. Note that
+ boto3 expects these in CamelCase, for example:
.. code-block:: python
@@ -858,11 +850,11 @@ class SageMakerHook(AwsBaseHook):
def list_transform_jobs(
self, name_contains: str | None = None, max_results: int | None =
None, **kwargs
) -> list[dict]:
- """
- This method wraps boto3's `list_transform_jobs`.
+ """Call boto3's ``list_transform_jobs``.
+
The transform job name and max results are configurable via arguments.
- Other arguments are not, and should be provided via kwargs. Note boto3
expects these in
- CamelCase format, for example.
+ Other arguments are not, and should be provided via kwargs. Note that
+ boto3 expects these in CamelCase, for example:
.. code-block:: python
@@ -871,10 +863,11 @@ class SageMakerHook(AwsBaseHook):
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.list_transform_jobs`
- :param name_contains: (optional) partial name to match
- :param max_results: (optional) maximum number of results to return.
None returns infinite results
- :param kwargs: (optional) kwargs to boto3's list_transform_jobs method
- :return: results of the list_transform_jobs request
+ :param name_contains: (optional) partial name to match.
+ :param max_results: (optional) maximum number of results to return.
+ None returns infinite results.
+ :param kwargs: (optional) kwargs to boto3's list_transform_jobs method.
+ :return: results of the list_transform_jobs request.
"""
config, max_results =
self._preprocess_list_request_args(name_contains, max_results, **kwargs)
list_transform_jobs_request =
partial(self.get_conn().list_transform_jobs, **config)
@@ -884,9 +877,10 @@ class SageMakerHook(AwsBaseHook):
return results
def list_processing_jobs(self, **kwargs) -> list[dict]:
- """
- This method wraps boto3's `list_processing_jobs`. All arguments should
be provided via kwargs.
- Note boto3 expects these in CamelCase format, for example.
+ """Call boto3's `list_processing_jobs`.
+
+ All arguments should be provided via kwargs. Note that boto3 expects
+ these in CamelCase, for example:
.. code-block:: python
@@ -907,15 +901,17 @@ class SageMakerHook(AwsBaseHook):
def _preprocess_list_request_args(
self, name_contains: str | None = None, max_results: int | None =
None, **kwargs
) -> tuple[dict[str, Any], int | None]:
- """
- This method preprocesses the arguments to the boto3's list_* methods.
- It will turn arguments name_contains and max_results as boto3
compliant CamelCase format.
- This method also makes sure that these two arguments are only set once.
+ """Preprocess arguments for boto3's ``list_*`` methods.
+
+ It will turn the arguments name_contains and max_results into boto3
+ compliant CamelCase format. This method also makes sure that these two
+ arguments are only set once.
:param name_contains: boto3 function with arguments
:param max_results: the result key to iterate over
:param kwargs: (optional) kwargs to boto3's list_* method
- :return: Tuple with config dict to be passed to boto3's list_* method
and max_results parameter
+ :return: Tuple with config dict to be passed to boto3's list_* method
+ and max_results parameter
"""
config = {}
@@ -938,14 +934,16 @@ class SageMakerHook(AwsBaseHook):
def _list_request(
self, partial_func: Callable, result_key: str, max_results: int | None
= None
) -> list[dict]:
- """
- All AWS boto3 list_* requests return results in batches (if the key
"NextToken" is contained in the
- result, there are more results to fetch). The default AWS batch size
is 10, and configurable up to
- 100. This function iteratively loads all results (or up to a given
maximum).
+ """Process a list request to produce results.
- Each boto3 list_* function returns the results in a list with a
different name. The key of this
- structure must be given to iterate over the results, e.g.
"TransformJobSummaries" for
- list_transform_jobs().
+ All AWS boto3 ``list_*`` requests return results in batches, and if the
+ key "NextToken" is contained in the result, there are more results to
+ fetch. The default AWS batch size is 10, and configurable up to 100.
+ This function iteratively loads all results (or up to a given maximum).
+
+ Each boto3 ``list_*`` function returns the results in a list with a
+ different name. The key of this structure must be given to iterate over
+ the results, e.g. "TransformJobSummaries" for
``list_transform_jobs()``.
:param partial_func: boto3 function with arguments
:param result_key: the result key to iterate over
@@ -993,8 +991,8 @@ class SageMakerHook(AwsBaseHook):
throttle_retry_delay: int = 2,
retries: int = 3,
) -> int:
- """
- Returns the number of processing jobs found with the provided name
prefix.
+ """Get the number of processing jobs found with the provided name
prefix.
+
:param processing_job_name: The prefix to look for.
:param job_name_suffix: The optional suffix which may be appended to
deduplicate an existing job name.
:param throttle_retry_delay: Seconds to wait if a ThrottlingException
is hit.
@@ -1023,8 +1021,7 @@ class SageMakerHook(AwsBaseHook):
raise
def delete_model(self, model_name: str):
- """
- Delete SageMaker model.
+ """Delete a SageMaker model.
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.delete_model`
@@ -1038,8 +1035,7 @@ class SageMakerHook(AwsBaseHook):
raise
def describe_pipeline_exec(self, pipeline_exec_arn: str, verbose: bool =
False):
- """
- Get info about a SageMaker pipeline execution.
+ """Get info about a SageMaker pipeline execution.
.. seealso::
-
:external+boto3:py:meth:`SageMaker.Client.describe_pipeline_execution`
@@ -1068,8 +1064,7 @@ class SageMakerHook(AwsBaseHook):
check_interval: int = 30,
verbose: bool = True,
) -> str:
- """
- Start a new execution for a SageMaker pipeline.
+ """Start a new execution for a SageMaker pipeline.
.. seealso::
-
:external+boto3:py:meth:`SageMaker.Client.start_pipeline_execution`
@@ -1182,8 +1177,7 @@ class SageMakerHook(AwsBaseHook):
return res["PipelineExecutionStatus"]
def create_model_package_group(self, package_group_name: str,
package_group_desc: str = "") -> bool:
- """
- Creates a Model Package Group if it does not already exist.
+ """Create a Model Package Group if it does not already exist.
.. seealso::
-
:external+boto3:py:meth:`SageMaker.Client.create_model_package_group`
@@ -1236,9 +1230,10 @@ class SageMakerHook(AwsBaseHook):
wait_for_completion: bool = True,
check_interval: int = 30,
) -> dict | None:
- """
- Creates an auto ML job, learning to predict the given column from the
data provided through S3.
- The learning output is written to the specified S3 location.
+ """Create an auto ML job to predict the given column.
+
+ The learning input is based on data provided through S3, and the
output
+ is written to the specified S3 location.
.. seealso::
- :external+boto3:py:meth:`SageMaker.Client.create_auto_ml_job`
diff --git a/airflow/providers/amazon/aws/hooks/secrets_manager.py
b/airflow/providers/amazon/aws/hooks/secrets_manager.py
index c82d543b0c..1a3d10c69a 100644
--- a/airflow/providers/amazon/aws/hooks/secrets_manager.py
+++ b/airflow/providers/amazon/aws/hooks/secrets_manager.py
@@ -24,8 +24,8 @@ from airflow.providers.amazon.aws.hooks.base_aws import
AwsBaseHook
class SecretsManagerHook(AwsBaseHook):
- """
- Interact with Amazon SecretsManager Service.
+ """Interact with Amazon SecretsManager Service.
+
Provide thin wrapper around
:external+boto3:py:class:`boto3.client("secretsmanager")
<SecretsManager.Client>`.
@@ -40,9 +40,9 @@ class SecretsManagerHook(AwsBaseHook):
super().__init__(client_type="secretsmanager", *args, **kwargs)
def get_secret(self, secret_name: str) -> str | bytes:
- """
- Retrieve secret value from AWS Secrets Manager as a str or bytes
- reflecting format it stored in the AWS Secrets Manager.
+ """Retrieve secret value from AWS Secrets Manager as a str or bytes.
+
+ The value reflects the format it is stored in within AWS Secrets Manager.
.. seealso::
- :external+boto3:py:meth:`SecretsManager.Client.get_secret_value`
@@ -60,8 +60,7 @@ class SecretsManagerHook(AwsBaseHook):
return secret
def get_secret_as_dict(self, secret_name: str) -> dict:
- """
- Retrieve secret value from AWS Secrets Manager in a dict
representation.
+ """Retrieve secret value from AWS Secrets Manager as a dict.
:param secret_name: name of the secrets.
:return: dict with the information about the secrets
diff --git a/airflow/providers/amazon/aws/operators/batch.py
b/airflow/providers/amazon/aws/operators/batch.py
index 8a127dbd67..2825ed5a01 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-An Airflow operator for AWS Batch services.
+"""AWS Batch services.
.. seealso::
@@ -46,8 +45,7 @@ if TYPE_CHECKING:
class BatchOperator(BaseOperator):
- """
- Execute a job on AWS Batch.
+ """Execute a job on AWS Batch.
.. seealso::
For more information on how to use this operator, take a look at the
guide:
@@ -148,8 +146,7 @@ class BatchOperator(BaseOperator):
deferrable: bool = False,
poll_interval: int = 30,
**kwargs,
- ):
-
+ ) -> None:
BaseOperator.__init__(self, **kwargs)
self.job_id = job_id
self.job_name = job_name
@@ -199,8 +196,7 @@ class BatchOperator(BaseOperator):
)
def execute(self, context: Context):
- """
- Submit and monitor an AWS Batch job.
+ """Submit and monitor an AWS Batch job.
:raises: AirflowException
"""
@@ -236,8 +232,7 @@ class BatchOperator(BaseOperator):
self.log.info("AWS Batch job (%s) terminated: %s", self.job_id,
response)
def submit_job(self, context: Context):
- """
- Submit an AWS Batch job.
+ """Submit an AWS Batch job.
:raises: AirflowException
"""
@@ -288,13 +283,10 @@ class BatchOperator(BaseOperator):
)
def monitor_job(self, context: Context):
- """
- Monitor an AWS Batch job
- monitor_job can raise an exception or an AirflowTaskTimeout can be
raised if execution_timeout
- is given while creating the task. These exceptions should be handled
in taskinstance.py
- instead of here like it was previously done.
+ """Monitor an AWS Batch job.
- :raises: AirflowException
+ This can raise an exception or an AirflowTaskTimeout if the task was
+ created with ``execution_timeout``.
"""
if not self.job_id:
raise AirflowException("AWS Batch job - job_id was not found")
@@ -357,43 +349,34 @@ class BatchOperator(BaseOperator):
class BatchCreateComputeEnvironmentOperator(BaseOperator):
- """
- Create an AWS Batch compute environment.
+ """Create an AWS Batch compute environment.
.. seealso::
For more information on how to use this operator, take a look at the
guide:
:ref:`howto/operator:BatchCreateComputeEnvironmentOperator`
- :param compute_environment_name: the name of the AWS batch compute
environment (templated)
-
- :param environment_type: the type of the compute-environment
-
- :param state: the state of the compute-environment
-
- :param compute_resources: details about the resources managed by the
compute-environment (templated).
- See more details here
+ :param compute_environment_name: Name of the AWS batch compute
+ environment (templated).
+ :param environment_type: Type of the compute-environment.
+ :param state: State of the compute-environment.
+ :param compute_resources: Details about the resources managed by the
+ compute-environment (templated). More details:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html#Batch.Client.create_compute_environment
-
- :param unmanaged_v_cpus: the maximum number of vCPU for an unmanaged
compute environment.
- This parameter is only supported when the ``type`` parameter is set to
``UNMANAGED``.
-
- :param service_role: the IAM role that allows Batch to make calls to other
AWS services on your behalf
- (templated)
-
- :param tags: the tags that you apply to the compute-environment to help
you categorize and organize your
- resources
-
- :param max_retries: exponential back-off retries, 4200 = 48 hours;
- polling is only used when waiters is None
-
- :param status_retries: number of HTTP retries to get job status, 10;
- polling is only used when waiters is None
-
- :param aws_conn_id: connection id of AWS credentials / region name. If
None,
+ :param unmanaged_v_cpus: Maximum number of vCPU for an unmanaged compute
+ environment. This parameter is only supported when the ``type``
+ parameter is set to ``UNMANAGED``.
+ :param service_role: IAM role that allows Batch to make calls to other AWS
+ services on your behalf (templated).
+ :param tags: Tags that you apply to the compute-environment to help you
+ categorize and organize your resources.
+ :param max_retries: Exponential back-off retries, 4200 = 48 hours; polling
+ is only used when waiters is None.
+ :param status_retries: Number of HTTP retries to get job status, 10;
polling
+ is only used when waiters is None.
+ :param aws_conn_id: Connection ID of AWS credentials / region name. If
None,
credential boto3 strategy will be used.
-
- :param region_name: region name to use in AWS Hook.
- Override the region_name in connection (if provided)
+ :param region_name: Region name to use in AWS Hook. Overrides the
+ ``region_name`` in connection if provided.
"""
template_fields: Sequence[str] = (
diff --git a/airflow/providers/amazon/aws/operators/glue.py
b/airflow/providers/amazon/aws/operators/glue.py
index 5134dbbe70..37010b6fd8 100644
--- a/airflow/providers/amazon/aws/operators/glue.py
+++ b/airflow/providers/amazon/aws/operators/glue.py
@@ -33,10 +33,10 @@ if TYPE_CHECKING:
class GlueJobOperator(BaseOperator):
- """
- Creates an AWS Glue Job. AWS Glue is a serverless Spark
- ETL service for running Spark Jobs on the AWS cloud.
- Language support: Python and Scala.
+ """Create an AWS Glue Job.
+
+ AWS Glue is a serverless Spark ETL service for running Spark Jobs on the
AWS
+ cloud. Language support: Python and Scala.
.. seealso::
For more information on how to use this operator, take a look at the
guide:
@@ -123,10 +123,9 @@ class GlueJobOperator(BaseOperator):
self.deferrable = deferrable
def execute(self, context: Context):
- """
- Executes AWS Glue Job from Airflow.
+ """Execute AWS Glue Job from Airflow.
- :return: the id of the current glue job.
+ :return: the current Glue job ID.
"""
if self.script_location is None:
s3_script_location = None
diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
index d7432c3959..dcfd0d6c70 100644
--- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
class MongoToS3Operator(BaseOperator):
- """Operator meant to move data from mongo via pymongo to s3 via boto.
+ """Move data from MongoDB to S3.
.. seealso::
For more information on how to use this operator, take a look at the
guide:
@@ -127,23 +127,24 @@ class MongoToS3Operator(BaseOperator):
@staticmethod
def _stringify(iterable: Iterable, joinable: str = "\n") -> str:
+ """Stringify an iterable of dicts.
+
+ This dumps each dict with JSON, and joins them with ``joinable``.
"""
- Takes an iterable (pymongo Cursor or Array) containing dictionaries and
- returns a stringified version using python join.
- """
- return joinable.join([json.dumps(doc, default=json_util.default) for
doc in iterable])
+ return joinable.join(json.dumps(doc, default=json_util.default) for
doc in iterable)
@staticmethod
def transform(docs: Any) -> Any:
- """This method is meant to be extended by child classes
- to perform transformations unique to those operators needs.
- Processes pyMongo cursor and returns an iterable with each element
being
- a JSON serializable dictionary.
+ """Transform the data for transfer.
+
+ This method is meant to be extended by child classes to perform
+ transformations unique to those operators' needs. Processes a pyMongo
+ cursor and returns an iterable with each element being a JSON
+ serializable dictionary.
- Base transform() assumes no processing is needed
- ie. docs is a pyMongo cursor of documents and cursor just
- needs to be passed through
+ The default implementation assumes no processing is needed, i.e. input
+ is a pyMongo cursor of documents and just needs to be passed through.
- Override this method for custom transformations
+ Override this method for custom transformations.
"""
return docs
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_sql.py
b/airflow/providers/amazon/aws/transfers/s3_to_sql.py
index 916ba183d7..8e0613ea6d 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_sql.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_sql.py
@@ -30,10 +30,10 @@ if TYPE_CHECKING:
class S3ToSqlOperator(BaseOperator):
- """
- Loads Data from S3 into a SQL Database.
- You need to provide a parser function that takes a filename as an input
- and returns an iterable of rows.
+ """Load Data from S3 into a SQL Database.
+
+ You need to provide a parser function that takes a filename as an input
+ and returns an iterable of rows.
.. seealso::
For more information on how to use this operator, take a look at the
guide:
@@ -52,12 +52,13 @@ class S3ToSqlOperator(BaseOperator):
e.g. to use a CSV parser that yields rows line-by-line, pass the
following
function:
- def parse_csv(filepath):
- import csv
+ .. code-block:: python
- with open(filepath, newline="") as file:
- yield from csv.reader(file)
+ def parse_csv(filepath):
+ import csv
+ with open(filepath, newline="") as file:
+ yield from csv.reader(file)
"""
template_fields: Sequence[str] = (
@@ -97,7 +98,6 @@ class S3ToSqlOperator(BaseOperator):
self.parser = parser
def execute(self, context: Context) -> None:
-
self.log.info("Loading %s to SQL table %s...", self.s3_key, self.table)
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py
b/airflow/providers/amazon/aws/utils/connection_wrapper.py
index 1bbc873e9c..1520c3e1fc 100644
--- a/airflow/providers/amazon/aws/utils/connection_wrapper.py
+++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py
@@ -75,11 +75,11 @@ class _ConnectionMetadata:
@dataclass
class AwsConnectionWrapper(LoggingMixin):
- """
- AWS Connection Wrapper class helper.
+ """AWS Connection Wrapper class helper.
+
Use for validate and resolve AWS Connection parameters.
- ``conn`` reference to Airflow Connection object or AwsConnectionWrapper
+ ``conn`` references an Airflow Connection object or AwsConnectionWrapper
if it is set to ``None`` then default values are used.
The precedence rules for ``region_name``
@@ -319,17 +319,17 @@ class AwsConnectionWrapper(LoggingMixin):
session_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> tuple[str | None, str | None, str | None]:
- """
- Get AWS credentials from connection login/password and extra.
+ """Get AWS credentials from connection login/password and extra.
- ``aws_access_key_id`` and ``aws_secret_access_key`` order
- 1. From Connection login and password
- 2. From Connection extra['aws_access_key_id'] and
extra['aws_access_key_id']
- 3. (deprecated) Form Connection extra['session_kwargs']
- 4. (deprecated) From local credentials file
+ ``aws_access_key_id`` and ``aws_secret_access_key`` order:
- Get ``aws_session_token`` from extra['aws_access_key_id']
+ 1. From Connection login and password
+ 2. From Connection ``extra['aws_access_key_id']`` and
+ ``extra['aws_secret_access_key']``
+ 3. (deprecated) From Connection ``extra['session_kwargs']``
+ 4. (deprecated) From a local credentials file
+ Get ``aws_session_token`` from ``extra['aws_session_token']``.
"""
session_kwargs = session_kwargs or {}
session_aws_access_key_id = session_kwargs.get("aws_access_key_id")
@@ -427,9 +427,9 @@ class AwsConnectionWrapper(LoggingMixin):
def _parse_s3_config(
config_file_name: str, config_format: str | None = "boto", profile: str |
None = None
) -> tuple[str | None, str | None]:
- """
- Parses a config file for s3 credentials. Can currently
- parse boto, s3cmd.conf and AWS SDK config formats.
+ """Parse a config file for S3 credentials.
+
+ Can currently parse boto, s3cmd.conf and AWS SDK config formats.
:param config_file_name: path to the config file
:param config_format: config type. One of "boto", "s3cmd" or "aws".
diff --git a/airflow/providers/amazon/aws/utils/redshift.py
b/airflow/providers/amazon/aws/utils/redshift.py
index 1ef490422d..d91858718f 100644
--- a/airflow/providers/amazon/aws/utils/redshift.py
+++ b/airflow/providers/amazon/aws/utils/redshift.py
@@ -24,14 +24,12 @@ log = logging.getLogger(__name__)
def build_credentials_block(credentials: ReadOnlyCredentials) -> str:
- """
- Generate AWS credentials block for Redshift COPY and UNLOAD
- commands, as noted in AWS docs.
+ """Generate AWS credentials block for Redshift COPY and UNLOAD commands.
+ See AWS docs for details:
https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials
:param credentials: ReadOnlyCredentials object from `botocore`
- :return: str
"""
if credentials.token:
log.debug("STS token found in credentials, including it in the
command")