This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new b9d677c Add type hints to aws provider (#11531)
b9d677c is described below
commit b9d677cdd660e0be8278a64658e73359276a9682
Author: Katsunori Kanda <[email protected]>
AuthorDate: Thu Oct 22 09:49:22 2020 +0900
Add type hints to aws provider (#11531)
* Added type hints to aws provider
* Update airflow/providers/amazon/aws/log/s3_task_handler.py
* Fix expectation for submit_job
* Fix documentation
Co-authored-by: Kamil Breguła <[email protected]>
---
.../aws/hooks/elasticache_replication_group.py | 45 +++++-----
airflow/providers/amazon/aws/hooks/glue.py | 4 +-
airflow/providers/amazon/aws/hooks/sagemaker.py | 66 +++++++--------
.../amazon/aws/log/cloudwatch_task_handler.py | 4 +-
.../providers/amazon/aws/log/s3_task_handler.py | 11 ++-
airflow/providers/amazon/aws/operators/batch.py | 39 +++++----
.../amazon/aws/operators/cloud_formation.py | 9 +-
airflow/providers/amazon/aws/operators/datasync.py | 97 +++++++++++++---------
airflow/providers/amazon/aws/operators/ecs.py | 77 +++++++++--------
airflow/providers/amazon/aws/operators/glue.py | 29 +++----
.../providers/amazon/aws/operators/s3_bucket.py | 4 +-
.../amazon/aws/operators/s3_copy_object.py | 15 ++--
.../amazon/aws/operators/s3_delete_objects.py | 12 ++-
airflow/providers/amazon/aws/operators/s3_list.py | 13 ++-
.../amazon/aws/operators/sagemaker_base.py | 11 ++-
.../amazon/aws/operators/sagemaker_endpoint.py | 17 ++--
.../aws/operators/sagemaker_endpoint_config.py | 4 +-
.../amazon/aws/operators/sagemaker_model.py | 4 +-
.../amazon/aws/operators/sagemaker_processing.py | 19 +++--
.../amazon/aws/operators/sagemaker_training.py | 15 ++--
.../amazon/aws/operators/sagemaker_transform.py | 17 ++--
.../amazon/aws/operators/sagemaker_tuning.py | 13 ++-
airflow/providers/amazon/aws/operators/sns.py | 11 +--
airflow/providers/amazon/aws/operators/sqs.py | 11 +--
.../step_function_get_execution_output.py | 10 ++-
.../aws/operators/step_function_start_execution.py | 4 +-
.../amazon/aws/sensors/cloud_formation.py | 21 +++--
airflow/providers/amazon/aws/sensors/emr_base.py | 18 ++--
airflow/providers/amazon/aws/sensors/glue.py | 2 +-
.../amazon/aws/sensors/glue_catalog_partition.py | 23 ++---
airflow/providers/amazon/aws/sensors/redshift.py | 20 +++--
airflow/providers/amazon/aws/sensors/s3_key.py | 23 ++---
airflow/providers/amazon/aws/sensors/s3_prefix.py | 20 +++--
.../providers/amazon/aws/sensors/sagemaker_base.py | 24 +++---
.../amazon/aws/sensors/sagemaker_training.py | 7 +-
.../amazon/aws/sensors/sagemaker_transform.py | 2 +-
.../amazon/aws/sensors/sagemaker_tuning.py | 2 +-
airflow/providers/amazon/aws/sensors/sqs.py | 19 +++--
.../amazon/aws/sensors/step_function_execution.py | 20 +++--
tests/providers/amazon/aws/operators/test_batch.py | 6 +-
40 files changed, 459 insertions(+), 309 deletions(-)
diff --git
a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
index d1c7409..54305d5 100644
--- a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
+++ b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional
from time import sleep
@@ -40,15 +41,21 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook):
TERMINAL_STATES = frozenset({"available", "create-failed", "deleting"})
def __init__(
- self, max_retries=10, exponential_back_off_factor=1,
initial_poke_interval=60, *args, **kwargs
+ self,
+ max_retries: int = 10,
+ exponential_back_off_factor: float = 1,
+ initial_poke_interval: float = 60,
+ *args,
+ **kwargs,
):
self.max_retries = max_retries
self.exponential_back_off_factor = exponential_back_off_factor
self.initial_poke_interval = initial_poke_interval
- super().__init__(client_type='elasticache', *args, **kwargs)
+ kwargs["client_type"] = "elasticache"
+ super().__init__(*args, **kwargs)
- def create_replication_group(self, config):
+ def create_replication_group(self, config: dict) -> dict:
"""
Call ElastiCache API for creating a replication group
@@ -59,7 +66,7 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook):
"""
return self.conn.create_replication_group(**config)
- def delete_replication_group(self, replication_group_id):
+ def delete_replication_group(self, replication_group_id: str) -> dict:
"""
Call ElastiCache API for deleting a replication group
@@ -70,7 +77,7 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook):
"""
return
self.conn.delete_replication_group(ReplicationGroupId=replication_group_id)
- def describe_replication_group(self, replication_group_id):
+ def describe_replication_group(self, replication_group_id: str) -> dict:
"""
Call ElastiCache API for describing a replication group
@@ -81,7 +88,7 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook):
"""
return
self.conn.describe_replication_groups(ReplicationGroupId=replication_group_id)
- def get_replication_group_status(self, replication_group_id):
+ def get_replication_group_status(self, replication_group_id: str) -> str:
"""
Get current status of replication group
@@ -92,7 +99,7 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook):
"""
return
self.describe_replication_group(replication_group_id)['ReplicationGroups'][0]['Status']
- def is_replication_group_available(self, replication_group_id):
+ def is_replication_group_available(self, replication_group_id: str) ->
bool:
"""
Helper for checking if replication group is available or not
@@ -105,10 +112,10 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook):
def wait_for_availability(
self,
- replication_group_id,
- initial_sleep_time=None,
- exponential_back_off_factor=None,
- max_retries=None,
+ replication_group_id: str,
+ initial_sleep_time: Optional[float] = None,
+ exponential_back_off_factor: Optional[float] = None,
+ max_retries: Optional[int] = None,
):
"""
Check if replication group is available or not by performing a
describe over it
@@ -164,10 +171,10 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook):
def wait_for_deletion(
self,
- replication_group_id,
- initial_sleep_time=None,
- exponential_back_off_factor=None,
- max_retries=None,
+ replication_group_id: str,
+ initial_sleep_time: Optional[float] = None,
+ exponential_back_off_factor: Optional[float] = None,
+ max_retries: Optional[int] = None,
):
"""
Helper for deleting a replication group ensuring it is either deleted
or can't be deleted
@@ -244,10 +251,10 @@ class ElastiCacheReplicationGroupHook(AwsBaseHook):
def ensure_delete_replication_group(
self,
- replication_group_id,
- initial_sleep_time=None,
- exponential_back_off_factor=None,
- max_retries=None,
+ replication_group_id: str,
+ initial_sleep_time: Optional[float] = None,
+ exponential_back_off_factor: Optional[float] = None,
+ max_retries: Optional[int] = None,
):
"""
Delete a replication group ensuring it is either deleted or can't be
deleted
diff --git a/airflow/providers/amazon/aws/hooks/glue.py
b/airflow/providers/amazon/aws/hooks/glue.py
index 8bc2e72..dbc8707 100644
--- a/airflow/providers/amazon/aws/hooks/glue.py
+++ b/airflow/providers/amazon/aws/hooks/glue.py
@@ -93,14 +93,14 @@ class AwsGlueJobHook(AwsBaseHook):
self.log.error("Failed to create aws glue job, error: %s",
general_error)
raise
- def initialize_job(self, script_arguments: Optional[List] = None) ->
Dict[str, str]:
+ def initialize_job(self, script_arguments: Optional[dict] = None) ->
Dict[str, str]:
"""
Initializes connection with AWS Glue
to run job
:return:
"""
glue_client = self.get_conn()
- script_arguments = script_arguments or []
+ script_arguments = script_arguments or {}
try:
job_name = self.get_or_create_glue_job()
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py
b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 9009967..af2733d 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -22,7 +22,7 @@ import tempfile
import time
import warnings
from functools import partial
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, Any, Callable, Generator
from botocore.exceptions import ClientError
@@ -51,7 +51,7 @@ class LogState:
Position = collections.namedtuple('Position', ['timestamp', 'skip'])
-def argmin(arr, f) -> Optional[int]:
+def argmin(arr, f: Callable) -> Optional[int]:
"""Return the index, i, in arr that minimizes f(arr[i])"""
min_value = None
min_idx = None
@@ -94,7 +94,9 @@ def
secondary_training_status_changed(current_job_description: dict, prev_job_de
return message != last_message
-def secondary_training_status_message(job_description, prev_description):
+def secondary_training_status_message(
+ job_description: Dict[str, List[dict]], prev_description: Optional[dict]
+) -> str:
"""
Returns a string contains start time and the secondary training job status
message.
@@ -105,22 +107,14 @@ def secondary_training_status_message(job_description,
prev_description):
:return: Job status string to be printed.
"""
- if (
- job_description is None
- or job_description.get('SecondaryStatusTransitions') is None
- or len(job_description.get('SecondaryStatusTransitions')) == 0
- ):
+ current_transitions = job_description.get('SecondaryStatusTransitions')
+ if current_transitions is None or len(current_transitions) == 0:
return ''
- prev_description_secondary_transitions = (
- prev_description.get('SecondaryStatusTransitions') if prev_description
is not None else None
- )
- prev_transitions_num = (
- len(prev_description['SecondaryStatusTransitions'])
- if prev_description_secondary_transitions is not None
- else 0
- )
- current_transitions = job_description['SecondaryStatusTransitions']
+ prev_transitions_num = 0
+ if prev_description is not None:
+ if prev_description.get('SecondaryStatusTransitions') is not None:
+ prev_transitions_num =
len(prev_description['SecondaryStatusTransitions'])
transitions_to_print = (
current_transitions[-1:]
@@ -278,7 +272,7 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
return self.logs_hook.get_log_events(log_group, stream_name,
start_time, skip)
- def multi_stream_iter(self, log_group, streams, positions=None):
+ def multi_stream_iter(self, log_group: str, streams: list, positions=None)
-> Generator:
"""
Iterate over the available events coming from a set of log streams in
a single log group
interleaving the events from each stream so they're yielded in
timestamp order.
@@ -298,7 +292,7 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
self.logs_hook.get_log_events(log_group, s,
positions[s].timestamp, positions[s].skip)
for s in streams
]
- events = []
+ events: List[Optional[Any]] = []
for event_stream in event_iters:
if not event_stream:
events.append(None)
@@ -309,8 +303,8 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
events.append(None)
while any(events):
- i = argmin(events, lambda x: x['timestamp'] if x else 9999999999)
- yield (i, events[i])
+ i = argmin(events, lambda x: x['timestamp'] if x else 9999999999)
or 0
+ yield i, events[i]
try:
events[i] = next(event_iters[i])
except StopIteration:
@@ -576,13 +570,13 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
def describe_training_job_with_log(
self,
- job_name,
+ job_name: str,
positions,
- stream_names,
- instance_count,
- state,
- last_description,
- last_describe_job_call,
+ stream_names: list,
+ instance_count: int,
+ state: int,
+ last_description: dict,
+ last_describe_job_call: float,
):
"""Return the training job info associated with job_name and print
CloudWatch logs"""
log_group = '/aws/sagemaker/TrainingJobs'
@@ -635,7 +629,7 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
state = LogState.JOB_COMPLETE
return state, last_description, last_describe_job_call
- def describe_tuning_job(self, name: str):
+ def describe_tuning_job(self, name: str) -> dict:
"""
Return the tuning job info associated with the name
@@ -645,7 +639,7 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
"""
return
self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
- def describe_model(self, name: str):
+ def describe_model(self, name: str) -> dict:
"""
Return the SageMaker model info associated with the name
@@ -655,7 +649,7 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
"""
return self.get_conn().describe_model(ModelName=name)
- def describe_transform_job(self, name: str):
+ def describe_transform_job(self, name: str) -> dict:
"""
Return the transform job info associated with the name
@@ -665,7 +659,7 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
"""
return self.get_conn().describe_transform_job(TransformJobName=name)
- def describe_processing_job(self, name: str):
+ def describe_processing_job(self, name: str) -> dict:
"""
Return the processing job info associated with the name
@@ -675,7 +669,7 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
"""
return self.get_conn().describe_processing_job(ProcessingJobName=name)
- def describe_endpoint_config(self, name: str):
+ def describe_endpoint_config(self, name: str) -> dict:
"""
Return the endpoint config info associated with the name
@@ -685,7 +679,7 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
"""
return
self.get_conn().describe_endpoint_config(EndpointConfigName=name)
- def describe_endpoint(self, name: str):
+ def describe_endpoint(self, name: str) -> dict:
"""
:param name: the name of the endpoint
:type name: str
@@ -697,7 +691,7 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
self,
job_name: str,
key: str,
- describe_function,
+ describe_function: Callable,
check_interval: int,
max_ingestion_time: Optional[int] = None,
non_terminal_states: Optional[Set] = None,
@@ -916,7 +910,9 @@ class SageMakerHook(AwsBaseHook): # pylint:
disable=too-many-public-methods
)
return results
- def _list_request(self, partial_func, result_key: str, max_results:
Optional[int] = None) -> List[Dict]:
+ def _list_request(
+ self, partial_func: Callable, result_key: str, max_results:
Optional[int] = None
+ ) -> List[Dict]:
"""
All AWS boto3 list_* requests return results in batches (if the key
"NextToken" is contained in the
result, there are more results to fetch). The default AWS batch size
is 10, and configurable up to
diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
index 5305784..fdd8154 100644
--- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -39,7 +39,7 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
:type filename_template: str
"""
- def __init__(self, base_log_folder, log_group_arn, filename_template):
+ def __init__(self, base_log_folder: str, log_group_arn: str,
filename_template: str):
super().__init__(base_log_folder, filename_template)
split_arn = log_group_arn.split(':')
@@ -99,7 +99,7 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
{'end_of_log': True},
)
- def get_cloudwatch_logs(self, stream_name):
+ def get_cloudwatch_logs(self, stream_name: str) -> str:
"""
Return all logs from the given log stream.
diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py
b/airflow/providers/amazon/aws/log/s3_task_handler.py
index 8b32a2f..922e9ec 100644
--- a/airflow/providers/amazon/aws/log/s3_task_handler.py
+++ b/airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -31,7 +31,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
uploads to and reads from S3 remote storage.
"""
- def __init__(self, base_log_folder, s3_log_folder, filename_template):
+ def __init__(self, base_log_folder: str, s3_log_folder: str,
filename_template: str):
super().__init__(base_log_folder, filename_template)
self.remote_base = s3_log_folder
self.log_relative_path = ''
@@ -119,11 +119,12 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
else:
return super()._read(ti, try_number)
- def s3_log_exists(self, remote_log_location):
+ def s3_log_exists(self, remote_log_location: str) -> bool:
"""
Check if remote_log_location exists in remote storage
:param remote_log_location: log's location in remote storage
+ :type remote_log_location: str
:return: True if location exists else False
"""
try:
@@ -132,7 +133,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
pass
return False
- def s3_read(self, remote_log_location, return_error=False):
+ def s3_read(self, remote_log_location: str, return_error: bool = False) ->
str:
"""
Returns the log found at the remote_log_location. Returns '' if no
logs are found or there is an error.
@@ -142,6 +143,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
:param return_error: if True, returns a string error message if an
error occurs. Otherwise returns '' when an error occurs.
:type return_error: bool
+ :return: the log found at the remote_log_location
"""
try:
return self.hook.read_key(remote_log_location)
@@ -151,8 +153,9 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
# return error if needed
if return_error:
return msg
+ return ''
- def s3_write(self, log, remote_log_location, append=True):
+ def s3_write(self, log: str, remote_log_location: str, append: bool =
True):
"""
Writes the log to the remote_log_location. Fails silently if no hook
was created.
diff --git a/airflow/providers/amazon/aws/operators/batch.py
b/airflow/providers/amazon/aws/operators/batch.py
index eb0c4a4..82868e6 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -26,7 +26,7 @@ An Airflow operator for AWS Batch services
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html
- https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
"""
-from typing import Dict, Optional
+from typing import Dict, Optional, Any
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -48,13 +48,13 @@ class AwsBatchOperator(BaseOperator):
:type job_queue: str
:param overrides: the `containerOverrides` parameter for boto3 (templated)
- :type overrides: Dict
+ :type overrides: Optional[dict]
:param array_properties: the `arrayProperties` parameter for boto3
- :type array_properties: Dict
+ :type array_properties: Optional[dict]
:param parameters: the `parameters` for boto3 (templated)
- :type parameters: Dict
+ :type parameters: Optional[dict]
:param job_id: the job ID, usually unknown (None) until the
submit_job operation gets the jobId defined by AWS Batch
@@ -101,18 +101,18 @@ class AwsBatchOperator(BaseOperator):
def __init__(
self,
*,
- job_name,
- job_definition,
- job_queue,
- overrides,
- array_properties=None,
- parameters=None,
- job_id=None,
- waiters=None,
- max_retries=None,
- status_retries=None,
- aws_conn_id=None,
- region_name=None,
+ job_name: str,
+ job_definition: str,
+ job_queue: str,
+ overrides: dict,
+ array_properties: Optional[dict] = None,
+ parameters: Optional[dict] = None,
+ job_id: Optional[str] = None,
+ waiters: Optional[Any] = None,
+ max_retries: Optional[int] = None,
+ status_retries: Optional[int] = None,
+ aws_conn_id: Optional[str] = None,
+ region_name: Optional[str] = None,
**kwargs,
): # pylint: disable=too-many-arguments
@@ -121,9 +121,9 @@ class AwsBatchOperator(BaseOperator):
self.job_name = job_name
self.job_definition = job_definition
self.job_queue = job_queue
- self.overrides = overrides
+ self.overrides = overrides or {}
self.array_properties = array_properties or {}
- self.parameters = parameters
+ self.parameters = parameters or {}
self.waiters = waiters
self.hook = AwsBatchClientHook(
max_retries=max_retries,
@@ -181,6 +181,9 @@ class AwsBatchOperator(BaseOperator):
:raises: AirflowException
"""
+ if not self.job_id:
+ raise AirflowException('AWS Batch job - job_id was not found')
+
try:
if self.waiters:
self.waiters.wait_for_job(self.job_id)
diff --git a/airflow/providers/amazon/aws/operators/cloud_formation.py
b/airflow/providers/amazon/aws/operators/cloud_formation.py
index d6fa654..0f511c3 100644
--- a/airflow/providers/amazon/aws/operators/cloud_formation.py
+++ b/airflow/providers/amazon/aws/operators/cloud_formation.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains CloudFormation create/delete stack operators."""
-from typing import List
+from typing import List, Optional
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.cloud_formation import
AWSCloudFormationHook
@@ -43,7 +43,7 @@ class CloudFormationCreateStackOperator(BaseOperator):
ui_color = '#6b9659'
@apply_defaults
- def __init__(self, *, stack_name, params, aws_conn_id='aws_default',
**kwargs):
+ def __init__(self, *, stack_name: str, params: dict, aws_conn_id: str =
'aws_default', **kwargs):
super().__init__(**kwargs)
self.stack_name = stack_name
self.params = params
@@ -77,11 +77,12 @@ class CloudFormationDeleteStackOperator(BaseOperator):
ui_fgcolor = '#FFF'
@apply_defaults
- def __init__(self, *, stack_name, params=None, aws_conn_id='aws_default',
**kwargs):
+ def __init__(
+ self, *, stack_name: str, params: Optional[dict] = None, aws_conn_id:
str = 'aws_default', **kwargs
+ ):
super().__init__(**kwargs)
self.params = params or {}
self.stack_name = stack_name
- self.params = params
self.aws_conn_id = aws_conn_id
def execute(self, context):
diff --git a/airflow/providers/amazon/aws/operators/datasync.py
b/airflow/providers/amazon/aws/operators/datasync.py
index 9f2e9b3..b70ed96 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -19,6 +19,7 @@
import logging
import random
+from typing import Optional, List
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -108,19 +109,19 @@ class AWSDataSyncOperator(BaseOperator):
def __init__(
self,
*,
- aws_conn_id="aws_default",
- wait_interval_seconds=5,
- task_arn=None,
- source_location_uri=None,
- destination_location_uri=None,
- allow_random_task_choice=False,
- allow_random_location_choice=False,
- create_task_kwargs=None,
- create_source_location_kwargs=None,
- create_destination_location_kwargs=None,
- update_task_kwargs=None,
- task_execution_kwargs=None,
- delete_task_after_execution=False,
+ aws_conn_id: str = "aws_default",
+ wait_interval_seconds: int = 5,
+ task_arn: Optional[str] = None,
+ source_location_uri: Optional[str] = None,
+ destination_location_uri: Optional[str] = None,
+ allow_random_task_choice: bool = False,
+ allow_random_location_choice: bool = False,
+ create_task_kwargs: Optional[dict] = None,
+ create_source_location_kwargs: Optional[dict] = None,
+ create_destination_location_kwargs: Optional[dict] = None,
+ update_task_kwargs: Optional[dict] = None,
+ task_execution_kwargs: Optional[dict] = None,
+ delete_task_after_execution: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -163,27 +164,29 @@ class AWSDataSyncOperator(BaseOperator):
)
# Others
- self.hook = None
+ self.hook: Optional[AWSDataSyncHook] = None
# Candidates - these are found in AWS as possible things
# for us to use
- self.candidate_source_location_arns = None
- self.candidate_destination_location_arns = None
- self.candidate_task_arns = None
+ self.candidate_source_location_arns: Optional[List[str]] = None
+ self.candidate_destination_location_arns: Optional[List[str]] = None
+ self.candidate_task_arns: Optional[List[str]] = None
# Actuals
- self.source_location_arn = None
- self.destination_location_arn = None
- self.task_execution_arn = None
+ self.source_location_arn: Optional[str] = None
+ self.destination_location_arn: Optional[str] = None
+ self.task_execution_arn: Optional[str] = None
- def get_hook(self):
+ def get_hook(self) -> AWSDataSyncHook:
"""Create and return AWSDataSyncHook.
:return AWSDataSyncHook: An AWSDataSyncHook instance.
"""
- if not self.hook:
- self.hook = AWSDataSyncHook(
- aws_conn_id=self.aws_conn_id,
- wait_interval_seconds=self.wait_interval_seconds,
- )
+ if self.hook:
+ return self.hook
+
+ self.hook = AWSDataSyncHook(
+ aws_conn_id=self.aws_conn_id,
+ wait_interval_seconds=self.wait_interval_seconds,
+ )
return self.hook
def execute(self, context):
@@ -221,7 +224,7 @@ class AWSDataSyncOperator(BaseOperator):
return {"TaskArn": self.task_arn, "TaskExecutionArn":
self.task_execution_arn}
- def _get_tasks_and_locations(self):
+ def _get_tasks_and_locations(self) -> None:
"""Find existing DataSync Task based on source and dest Locations."""
hook = self.get_hook()
@@ -244,7 +247,7 @@ class AWSDataSyncOperator(BaseOperator):
)
self.log.info("Found candidate DataSync TaskArns %s",
self.candidate_task_arns)
- def choose_task(self, task_arn_list):
+ def choose_task(self, task_arn_list: list) -> Optional[str]:
"""Select 1 DataSync TaskArn from a list"""
if not task_arn_list:
return None
@@ -258,7 +261,7 @@ class AWSDataSyncOperator(BaseOperator):
return random.choice(task_arn_list)
raise AirflowException("Unable to choose a Task from
{}".format(task_arn_list))
- def choose_location(self, location_arn_list):
+ def choose_location(self, location_arn_list: List[str]) -> Optional[str]:
"""Select 1 DataSync LocationArn from a list"""
if not location_arn_list:
return None
@@ -272,12 +275,15 @@ class AWSDataSyncOperator(BaseOperator):
return random.choice(location_arn_list)
raise AirflowException("Unable to choose a Location from
{}".format(location_arn_list))
- def _create_datasync_task(self):
+ def _create_datasync_task(self) -> None:
"""Create a AWS DataSyncTask."""
+ if not self.candidate_source_location_arns or not
self.candidate_destination_location_arns:
+ return
+
hook = self.get_hook()
self.source_location_arn =
self.choose_location(self.candidate_source_location_arns)
- if not self.source_location_arn and self.create_source_location_kwargs:
+ if not self.source_location_arn and self.source_location_uri and
self.create_source_location_kwargs:
self.log.info('Attempting to create source Location')
self.source_location_arn = hook.create_location(
self.source_location_uri, **self.create_source_location_kwargs
@@ -288,7 +294,11 @@ class AWSDataSyncOperator(BaseOperator):
)
self.destination_location_arn =
self.choose_location(self.candidate_destination_location_arns)
- if not self.destination_location_arn and
self.create_destination_location_kwargs:
+ if (
+ not self.destination_location_arn
+ and self.destination_location_uri
+ and self.create_destination_location_kwargs
+ ):
self.log.info('Attempting to create destination Location')
self.destination_location_arn = hook.create_location(
self.destination_location_uri,
**self.create_destination_location_kwargs
@@ -305,18 +315,22 @@ class AWSDataSyncOperator(BaseOperator):
if not self.task_arn:
raise AirflowException("Task could not be created")
self.log.info("Created a Task with TaskArn %s", self.task_arn)
- return self.task_arn
- def _update_datasync_task(self):
+ def _update_datasync_task(self) -> None:
"""Update a AWS DataSyncTask."""
+ if not self.task_arn:
+ return
+
hook = self.get_hook()
self.log.info("Updating TaskArn %s", self.task_arn)
hook.update_task(self.task_arn, **self.update_task_kwargs)
self.log.info("Updated TaskArn %s", self.task_arn)
- return self.task_arn
- def _execute_datasync_task(self):
+ def _execute_datasync_task(self) -> None:
"""Create and monitor an AWSDataSync TaskExecution for a Task."""
+ if not self.task_arn:
+ raise AirflowException("Missing TaskArn")
+
hook = self.get_hook()
# Create a task execution:
@@ -340,9 +354,8 @@ class AWSDataSyncOperator(BaseOperator):
if not result:
raise AirflowException("Failed TaskExecutionArn %s" %
self.task_execution_arn)
- return self.task_execution_arn
- def on_kill(self):
+ def on_kill(self) -> None:
"""Cancel the submitted DataSync task."""
hook = self.get_hook()
if self.task_execution_arn:
@@ -350,16 +363,18 @@ class AWSDataSyncOperator(BaseOperator):
hook.cancel_task_execution(task_execution_arn=self.task_execution_arn)
self.log.info("Cancelled TaskExecutionArn %s",
self.task_execution_arn)
- def _delete_datasync_task(self):
+ def _delete_datasync_task(self) -> None:
"""Deletes an AWS DataSync Task."""
+ if not self.task_arn:
+ return
+
hook = self.get_hook()
# Delete task:
self.log.info("Deleting Task with TaskArn %s", self.task_arn)
hook.delete_task(self.task_arn)
self.log.info("Task Deleted")
- return self.task_arn
- def _get_location_arns(self, location_uri):
+ def _get_location_arns(self, location_uri) -> List[str]:
location_arns = self.get_hook().get_location_arns(location_uri)
self.log.info("Found LocationArns %s for LocationUri %s",
location_arns, location_uri)
return location_arns
diff --git a/airflow/providers/amazon/aws/operators/ecs.py
b/airflow/providers/amazon/aws/operators/ecs.py
index e2e85d9..e7f5abb 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -18,7 +18,9 @@
import re
import sys
from datetime import datetime
-from typing import Dict, Optional
+from typing import Optional
+
+from botocore.waiter import Waiter
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -40,19 +42,19 @@ class ECSProtocol(Protocol):
-
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html
"""
- def run_task(self, **kwargs):
+ def run_task(self, **kwargs) -> dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task"""
# noqa: E501 # pylint: disable=line-too-long
...
- def get_waiter(self, x: str):
+ def get_waiter(self, x: str) -> Waiter:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.get_waiter"""
# noqa: E501 # pylint: disable=line-too-long
...
- def describe_tasks(self, cluster: str, tasks) -> Dict:
+ def describe_tasks(self, cluster: str, tasks) -> dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_tasks"""
# noqa: E501 # pylint: disable=line-too-long
...
- def stop_task(self, cluster, task, reason: str) -> Dict:
+ def stop_task(self, cluster, task, reason: str) -> dict:
"""https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.stop_task"""
# noqa: E501 # pylint: disable=line-too-long
...
@@ -111,30 +113,28 @@ class ECSOperator(BaseOperator): # pylint:
disable=too-many-instance-attributes
"""
ui_color = '#f0ede4'
- client = None # type: Optional[ECSProtocol]
- arn = None # type: Optional[str]
template_fields = ('overrides',)
@apply_defaults
def __init__(
self,
*,
- task_definition,
- cluster,
- overrides, # pylint: disable=too-many-arguments
- aws_conn_id=None,
- region_name=None,
- launch_type='EC2',
- group=None,
- placement_constraints=None,
- placement_strategy=None,
- platform_version='LATEST',
- network_configuration=None,
- tags=None,
- awslogs_group=None,
- awslogs_region=None,
- awslogs_stream_prefix=None,
- propagate_tags=None,
+ task_definition: str,
+ cluster: str,
+ overrides: dict, # pylint: disable=too-many-arguments
+ aws_conn_id: Optional[str] = None,
+ region_name: Optional[str] = None,
+ launch_type: str = 'EC2',
+ group: Optional[str] = None,
+ placement_constraints: Optional[list] = None,
+ placement_strategy: Optional[list] = None,
+ platform_version: str = 'LATEST',
+ network_configuration: Optional[dict] = None,
+ tags: Optional[dict] = None,
+ awslogs_group: Optional[str] = None,
+ awslogs_region: Optional[str] = None,
+ awslogs_stream_prefix: Optional[str] = None,
+ propagate_tags: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -160,7 +160,9 @@ class ECSOperator(BaseOperator): # pylint:
disable=too-many-instance-attributes
if self.awslogs_region is None:
self.awslogs_region = region_name
- self.hook = None
+ self.hook: Optional[AwsBaseHook] = None
+ self.client: Optional[ECSProtocol] = None
+ self.arn: Optional[str] = None
def execute(self, context):
self.log.info(
@@ -207,12 +209,18 @@ class ECSOperator(BaseOperator): # pylint:
disable=too-many-instance-attributes
self._check_success_task()
self.log.info('ECS Task has been successfully executed: %s', response)
- def _wait_for_task_ended(self):
+ def _wait_for_task_ended(self) -> None:
+ if not self.client or not self.arn:
+ return
+
waiter = self.client.get_waiter('tasks_stopped')
waiter.config.max_attempts = sys.maxsize # timeout is managed by
airflow
waiter.wait(cluster=self.cluster, tasks=[self.arn])
- def _check_success_task(self):
+ def _check_success_task(self) -> None:
+ if not self.client or not self.arn:
+ return
+
response = self.client.describe_tasks(cluster=self.cluster,
tasks=[self.arn])
self.log.info('ECS Task stopped, check status: %s', response)
@@ -252,19 +260,22 @@ class ECSOperator(BaseOperator): # pylint:
disable=too-many-instance-attributes
)
)
- def get_hook(self):
+ def get_hook(self) -> AwsBaseHook:
"""Create and return an AwsHook."""
- if not self.hook:
- self.hook = AwsBaseHook(
- aws_conn_id=self.aws_conn_id, client_type='ecs',
region_name=self.region_name
- )
+ if self.hook:
+ return self.hook
+
+ self.hook = AwsBaseHook(aws_conn_id=self.aws_conn_id,
client_type='ecs', region_name=self.region_name)
return self.hook
- def get_logs_hook(self):
+ def get_logs_hook(self) -> AwsLogsHook:
"""Create and return an AwsLogsHook."""
return AwsLogsHook(aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region)
- def on_kill(self):
+ def on_kill(self) -> None:
+ if not self.client or not self.arn:
+ return
+
response = self.client.stop_task(
cluster=self.cluster, task=self.arn, reason='Task killed by the
user'
)
diff --git a/airflow/providers/amazon/aws/operators/glue.py
b/airflow/providers/amazon/aws/operators/glue.py
index 991135f..48d5661 100644
--- a/airflow/providers/amazon/aws/operators/glue.py
+++ b/airflow/providers/amazon/aws/operators/glue.py
@@ -18,6 +18,7 @@
from __future__ import unicode_literals
import os.path
+from typing import Optional
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glue import AwsGlueJobHook
@@ -61,24 +62,24 @@ class AwsGlueJobOperator(BaseOperator):
def __init__(
self,
*,
- job_name='aws_glue_default_job',
- job_desc='AWS Glue Job with Airflow',
- script_location=None,
- concurrent_run_limit=None,
- script_args=None,
- retry_limit=None,
- num_of_dpus=6,
- aws_conn_id='aws_default',
- region_name=None,
- s3_bucket=None,
- iam_role_name=None,
+ job_name: str = 'aws_glue_default_job',
+ job_desc: str = 'AWS Glue Job with Airflow',
+ script_location: Optional[str] = None,
+ concurrent_run_limit: Optional[int] = None,
+ script_args: Optional[dict] = None,
+ retry_limit: Optional[int] = None,
+ num_of_dpus: int = 6,
+ aws_conn_id: str = 'aws_default',
+ region_name: Optional[str] = None,
+ s3_bucket: Optional[str] = None,
+ iam_role_name: Optional[str] = None,
**kwargs,
): # pylint: disable=too-many-arguments
super(AwsGlueJobOperator, self).__init__(**kwargs)
self.job_name = job_name
self.job_desc = job_desc
self.script_location = script_location
- self.concurrent_run_limit = concurrent_run_limit
+ self.concurrent_run_limit = concurrent_run_limit or 1
self.script_args = script_args or {}
self.retry_limit = retry_limit
self.num_of_dpus = num_of_dpus
@@ -87,7 +88,7 @@ class AwsGlueJobOperator(BaseOperator):
self.s3_bucket = s3_bucket
self.iam_role_name = iam_role_name
self.s3_protocol = "s3://"
- self.s3_artifcats_prefix = 'artifacts/glue-scripts/'
+ self.s3_artifacts_prefix = 'artifacts/glue-scripts/'
def execute(self, context):
"""
@@ -98,7 +99,7 @@ class AwsGlueJobOperator(BaseOperator):
if self.script_location and not
self.script_location.startswith(self.s3_protocol):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
script_name = os.path.basename(self.script_location)
- s3_hook.load_file(self.script_location, self.s3_bucket,
self.s3_artifcats_prefix + script_name)
+ s3_hook.load_file(self.script_location, self.s3_bucket,
self.s3_artifacts_prefix + script_name)
glue_job = AwsGlueJobHook(
job_name=self.job_name,
desc=self.job_desc,
diff --git a/airflow/providers/amazon/aws/operators/s3_bucket.py
b/airflow/providers/amazon/aws/operators/s3_bucket.py
index 14d25cc..bb15baa 100644
--- a/airflow/providers/amazon/aws/operators/s3_bucket.py
+++ b/airflow/providers/amazon/aws/operators/s3_bucket.py
@@ -43,7 +43,7 @@ class S3CreateBucketOperator(BaseOperator):
def __init__(
self,
*,
- bucket_name,
+ bucket_name: str,
aws_conn_id: Optional[str] = "aws_default",
region_name: Optional[str] = None,
**kwargs,
@@ -81,7 +81,7 @@ class S3DeleteBucketOperator(BaseOperator):
def __init__(
self,
- bucket_name,
+ bucket_name: str,
force_delete: bool = False,
aws_conn_id: Optional[str] = "aws_default",
**kwargs,
diff --git a/airflow/providers/amazon/aws/operators/s3_copy_object.py
b/airflow/providers/amazon/aws/operators/s3_copy_object.py
index 4b2d290..052c9ad 100644
--- a/airflow/providers/amazon/aws/operators/s3_copy_object.py
+++ b/airflow/providers/amazon/aws/operators/s3_copy_object.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional, Union
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -70,13 +71,13 @@ class S3CopyObjectOperator(BaseOperator):
def __init__(
self,
*,
- source_bucket_key,
- dest_bucket_key,
- source_bucket_name=None,
- dest_bucket_name=None,
- source_version_id=None,
- aws_conn_id='aws_default',
- verify=None,
+ source_bucket_key: str,
+ dest_bucket_key: str,
+ source_bucket_name: Optional[str] = None,
+ dest_bucket_name: Optional[str] = None,
+ source_version_id: Optional[str] = None,
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[str, bool]] = None,
**kwargs,
):
super().__init__(**kwargs)
diff --git a/airflow/providers/amazon/aws/operators/s3_delete_objects.py
b/airflow/providers/amazon/aws/operators/s3_delete_objects.py
index b6d267b..96c9e14 100644
--- a/airflow/providers/amazon/aws/operators/s3_delete_objects.py
+++ b/airflow/providers/amazon/aws/operators/s3_delete_objects.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional, Union
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -62,7 +63,16 @@ class S3DeleteObjectsOperator(BaseOperator):
template_fields = ('keys', 'bucket', 'prefix')
@apply_defaults
- def __init__(self, *, bucket, keys=None, prefix=None,
aws_conn_id='aws_default', verify=None, **kwargs):
+ def __init__(
+ self,
+ *,
+ bucket: str,
+ keys: Optional[Union[str, list]] = None,
+ prefix: Optional[str] = None,
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[str, bool]] = None,
+ **kwargs,
+ ):
if not bool(keys) ^ bool(prefix):
raise ValueError("Either keys or prefix should be set.")
diff --git a/airflow/providers/amazon/aws/operators/s3_list.py
b/airflow/providers/amazon/aws/operators/s3_list.py
index 4c25e99..58d599d 100644
--- a/airflow/providers/amazon/aws/operators/s3_list.py
+++ b/airflow/providers/amazon/aws/operators/s3_list.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterable
+from typing import Iterable, Optional, Union
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -70,7 +70,16 @@ class S3ListOperator(BaseOperator):
ui_color = '#ffd700'
@apply_defaults
- def __init__(self, *, bucket, prefix='', delimiter='',
aws_conn_id='aws_default', verify=None, **kwargs):
+ def __init__(
+ self,
+ *,
+ bucket: str,
+ prefix: str = '',
+ delimiter: str = '',
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[str, bool]] = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.bucket = bucket
self.prefix = prefix
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_base.py
b/airflow/providers/amazon/aws/operators/sagemaker_base.py
index 19fb921..3fa1b2e 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_base.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_base.py
@@ -19,6 +19,8 @@
import json
from typing import Iterable
+from cached_property import cached_property
+
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.utils.decorators import apply_defaults
@@ -41,12 +43,11 @@ class SageMakerBaseOperator(BaseOperator):
integer_fields = [] # type: Iterable[Iterable[str]]
@apply_defaults
- def __init__(self, *, config, aws_conn_id='aws_default', **kwargs):
+ def __init__(self, *, config: dict, aws_conn_id: str = 'aws_default',
**kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.config = config
- self.hook = None
def parse_integer(self, config, field):
"""Recursive method for parsing string fields holding integer values
to integers."""
@@ -84,7 +85,6 @@ class SageMakerBaseOperator(BaseOperator):
def preprocess_config(self):
"""Process the config into a usable form."""
self.log.info('Preprocessing the config and doing required
s3_operations')
- self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
self.hook.configure_s3_resources(self.config)
self.parse_config_integers()
@@ -97,3 +97,8 @@ class SageMakerBaseOperator(BaseOperator):
def execute(self, context):
raise NotImplementedError('Please implement execute() in sub class!')
+
+ @cached_property
+ def hook(self):
+ """Return SageMakerHook"""
+ return SageMakerHook(aws_conn_id=self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
index c7a89f2..53cfd93 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional
from botocore.exceptions import ClientError
@@ -74,11 +75,11 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
def __init__(
self,
*,
- config,
- wait_for_completion=True,
- check_interval=30,
- max_ingestion_time=None,
- operation='create',
+ config: dict,
+ wait_for_completion: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
+ operation: str = 'create',
**kwargs,
):
super().__init__(config=config, **kwargs)
@@ -92,12 +93,12 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
raise ValueError('Invalid value! Argument operation has to be one
of "create" and "update"')
self.create_integer_fields()
- def create_integer_fields(self):
+ def create_integer_fields(self) -> None:
"""Set fields which should be casted to integers."""
if 'EndpointConfig' in self.config:
self.integer_fields = [['EndpointConfig', 'ProductionVariants',
'InitialInstanceCount']]
- def expand_role(self):
+ def expand_role(self) -> None:
if 'Model' not in self.config:
return
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
@@ -105,7 +106,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
if 'ExecutionRoleArn' in config:
config['ExecutionRoleArn'] =
hook.expand_role(config['ExecutionRoleArn'])
- def execute(self, context):
+ def execute(self, context) -> dict:
self.preprocess_config()
model_info = self.config.get('Model')
diff --git
a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
index 9bde451..bbf2be1 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
@@ -38,12 +38,12 @@ class
SageMakerEndpointConfigOperator(SageMakerBaseOperator):
integer_fields = [['ProductionVariants', 'InitialInstanceCount']]
@apply_defaults
- def __init__(self, *, config, **kwargs):
+ def __init__(self, *, config: dict, **kwargs):
super().__init__(config=config, **kwargs)
self.config = config
- def execute(self, context):
+ def execute(self, context) -> dict:
self.preprocess_config()
self.log.info('Creating SageMaker Endpoint Config %s.',
self.config['EndpointConfigName'])
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_model.py
b/airflow/providers/amazon/aws/operators/sagemaker_model.py
index 122ceee..25730ea 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_model.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_model.py
@@ -42,12 +42,12 @@ class SageMakerModelOperator(SageMakerBaseOperator):
self.config = config
- def expand_role(self):
+ def expand_role(self) -> None:
if 'ExecutionRoleArn' in self.config:
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
self.config['ExecutionRoleArn'] =
hook.expand_role(self.config['ExecutionRoleArn'])
- def execute(self, context):
+ def execute(self, context) -> dict:
self.preprocess_config()
self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_processing.py
b/airflow/providers/amazon/aws/operators/sagemaker_processing.py
index c1bcac7..e56a987 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_processing.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_processing.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -55,12 +56,12 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
def __init__(
self,
*,
- config,
- aws_conn_id,
- wait_for_completion=True,
- print_log=True,
- check_interval=30,
- max_ingestion_time=None,
+ config: dict,
+ aws_conn_id: str,
+ wait_for_completion: bool = True,
+ print_log: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
action_if_job_exists: str = "increment", # TODO use typing.Literal
for this in Python 3.8
**kwargs,
):
@@ -78,7 +79,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
self.max_ingestion_time = max_ingestion_time
self._create_integer_fields()
- def _create_integer_fields(self):
+ def _create_integer_fields(self) -> None:
"""Set fields which should be casted to integers."""
self.integer_fields = [
['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
@@ -87,12 +88,12 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
if 'StoppingCondition' in self.config:
self.integer_fields += [['StoppingCondition',
'MaxRuntimeInSeconds']]
- def expand_role(self):
+ def expand_role(self) -> None:
if 'RoleArn' in self.config:
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
- def execute(self, context):
+ def execute(self, context) -> dict:
self.preprocess_config()
processing_job_name = self.config["ProcessingJobName"]
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_training.py
b/airflow/providers/amazon/aws/operators/sagemaker_training.py
index 6175a61..29c34f6 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_training.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_training.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -61,11 +62,11 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
def __init__(
self,
*,
- config,
- wait_for_completion=True,
- print_log=True,
- check_interval=30,
- max_ingestion_time=None,
+ config: dict,
+ wait_for_completion: bool = True,
+ print_log: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
action_if_job_exists: str = "increment", # TODO use typing.Literal
for this in Python 3.8
**kwargs,
):
@@ -84,12 +85,12 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
f"Provided value: '{action_if_job_exists}'."
)
- def expand_role(self):
+ def expand_role(self) -> None:
if 'RoleArn' in self.config:
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
- def execute(self, context):
+ def execute(self, context) -> dict:
self.preprocess_config()
training_job_name = self.config["TrainingJobName"]
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_transform.py
b/airflow/providers/amazon/aws/operators/sagemaker_transform.py
index 7ae8f3a..1dadb3d 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_transform.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_transform.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional, List
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -63,7 +64,13 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
@apply_defaults
def __init__(
- self, *, config, wait_for_completion=True, check_interval=30,
max_ingestion_time=None, **kwargs
+ self,
+ *,
+ config: dict,
+ wait_for_completion: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
+ **kwargs,
):
super().__init__(config=config, **kwargs)
self.config = config
@@ -72,9 +79,9 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
self.max_ingestion_time = max_ingestion_time
self.create_integer_fields()
- def create_integer_fields(self):
+ def create_integer_fields(self) -> None:
"""Set fields which should be casted to integers."""
- self.integer_fields = [
+ self.integer_fields: List[List[str]] = [
['Transform', 'TransformResources', 'InstanceCount'],
['Transform', 'MaxConcurrentTransforms'],
['Transform', 'MaxPayloadInMB'],
@@ -83,7 +90,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
for field in self.integer_fields:
field.pop(0)
- def expand_role(self):
+ def expand_role(self) -> None:
if 'Model' not in self.config:
return
config = self.config['Model']
@@ -91,7 +98,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
config['ExecutionRoleArn'] =
hook.expand_role(config['ExecutionRoleArn'])
- def execute(self, context):
+ def execute(self, context) -> dict:
self.preprocess_config()
model_config = self.config.get('Model')
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
index 483e541..f8df36a 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -56,7 +57,13 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
@apply_defaults
def __init__(
- self, *, config, wait_for_completion=True, check_interval=30,
max_ingestion_time=None, **kwargs
+ self,
+ *,
+ config: dict,
+ wait_for_completion: bool = True,
+ check_interval: int = 30,
+ max_ingestion_time: Optional[int] = None,
+ **kwargs,
):
super().__init__(config=config, **kwargs)
self.config = config
@@ -64,14 +71,14 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
- def expand_role(self):
+ def expand_role(self) -> None:
if 'TrainingJobDefinition' in self.config:
config = self.config['TrainingJobDefinition']
if 'RoleArn' in config:
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
config['RoleArn'] = hook.expand_role(config['RoleArn'])
- def execute(self, context):
+ def execute(self, context) -> dict:
self.preprocess_config()
self.log.info(
diff --git a/airflow/providers/amazon/aws/operators/sns.py
b/airflow/providers/amazon/aws/operators/sns.py
index 8917dfe..1e88913 100644
--- a/airflow/providers/amazon/aws/operators/sns.py
+++ b/airflow/providers/amazon/aws/operators/sns.py
@@ -17,6 +17,7 @@
# under the License.
"""Publish message to SNS queue"""
+from typing import Optional
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sns import AwsSnsHook
@@ -47,11 +48,11 @@ class SnsPublishOperator(BaseOperator):
def __init__(
self,
*,
- target_arn,
- message,
- aws_conn_id='aws_default',
- subject=None,
- message_attributes=None,
+ target_arn: str,
+ message: str,
+ aws_conn_id: str = 'aws_default',
+ subject: Optional[str] = None,
+ message_attributes: Optional[dict] = None,
**kwargs,
):
super().__init__(**kwargs)
diff --git a/airflow/providers/amazon/aws/operators/sqs.py
b/airflow/providers/amazon/aws/operators/sqs.py
index 6005195..afc50e4 100644
--- a/airflow/providers/amazon/aws/operators/sqs.py
+++ b/airflow/providers/amazon/aws/operators/sqs.py
@@ -16,6 +16,7 @@
# under the License.
"""Publish message to SQS queue"""
+from typing import Optional
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sqs import SQSHook
@@ -46,11 +47,11 @@ class SQSPublishOperator(BaseOperator):
def __init__(
self,
*,
- sqs_queue,
- message_content,
- message_attributes=None,
- delay_seconds=0,
- aws_conn_id='aws_default',
+ sqs_queue: str,
+ message_content: str,
+ message_attributes: Optional[dict] = None,
+ delay_seconds: int = 0,
+ aws_conn_id: str = 'aws_default',
**kwargs,
):
super().__init__(**kwargs)
diff --git
a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
index 2eaa2c4..769f06c 100644
---
a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
+++
b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
@@ -16,6 +16,7 @@
# under the License.
import json
+from typing import Optional
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
@@ -42,7 +43,14 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator):
ui_color = '#f9c915'
@apply_defaults
- def __init__(self, *, execution_arn: str, aws_conn_id='aws_default',
region_name=None, **kwargs):
+ def __init__(
+ self,
+ *,
+ execution_arn: str,
+ aws_conn_id: str = 'aws_default',
+ region_name: Optional[str] = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.execution_arn = execution_arn
self.aws_conn_id = aws_conn_id
diff --git
a/airflow/providers/amazon/aws/operators/step_function_start_execution.py
b/airflow/providers/amazon/aws/operators/step_function_start_execution.py
index 0d8f446..b364ba5 100644
--- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py
+++ b/airflow/providers/amazon/aws/operators/step_function_start_execution.py
@@ -55,8 +55,8 @@ class StepFunctionStartExecutionOperator(BaseOperator):
state_machine_arn: str,
name: Optional[str] = None,
state_machine_input: Union[dict, str, None] = None,
- aws_conn_id='aws_default',
- region_name=None,
+ aws_conn_id: str = 'aws_default',
+ region_name: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py
b/airflow/providers/amazon/aws/sensors/cloud_formation.py
index 5c66a0c..2da691b 100644
--- a/airflow/providers/amazon/aws/sensors/cloud_formation.py
+++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains sensors for AWS CloudFormation."""
+from typing import Optional
+
from airflow.providers.amazon.aws.hooks.cloud_formation import
AWSCloudFormationHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
@@ -69,12 +71,19 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
ui_color = '#C5CAE9'
@apply_defaults
- def __init__(self, *, stack_name, aws_conn_id='aws_default',
region_name=None, **kwargs):
+ def __init__(
+ self,
+ *,
+ stack_name: str,
+ aws_conn_id: str = 'aws_default',
+ region_name: Optional[str] = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.stack_name = stack_name
- self.hook = None
+ self.hook: Optional[AWSCloudFormationHook] = None
def poke(self, context):
stack_status = self.get_hook().get_stack_status(self.stack_name)
@@ -84,8 +93,10 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
return False
raise ValueError(f'Stack {self.stack_name} in bad state:
{stack_status}')
- def get_hook(self):
+ def get_hook(self) -> AWSCloudFormationHook:
"""Create and return an AWSCloudFormationHook"""
- if not self.hook:
- self.hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
+ if self.hook:
+ return self.hook
+
+ self.hook = AWSCloudFormationHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
return self.hook
diff --git a/airflow/providers/amazon/aws/sensors/emr_base.py
b/airflow/providers/amazon/aws/sensors/emr_base.py
index f05197b..d862c6b 100644
--- a/airflow/providers/amazon/aws/sensors/emr_base.py
+++ b/airflow/providers/amazon/aws/sensors/emr_base.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Iterable
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrHook
@@ -42,17 +42,19 @@ class EmrBaseSensor(BaseSensorOperator):
ui_color = '#66c3ff'
@apply_defaults
- def __init__(self, *, aws_conn_id='aws_default', **kwargs):
+ def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
- self.target_states = None # will be set in subclasses
- self.failed_states = None # will be set in subclasses
- self.hook = None
+ self.target_states: Optional[Iterable[str]] = None # will be set in
subclasses
+ self.failed_states: Optional[Iterable[str]] = None # will be set in
subclasses
+ self.hook: Optional[EmrHook] = None
- def get_hook(self):
+ def get_hook(self) -> EmrHook:
"""Get EmrHook"""
- if not self.hook:
- self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
+ if self.hook:
+ return self.hook
+
+ self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
return self.hook
def poke(self, context):
diff --git a/airflow/providers/amazon/aws/sensors/glue.py
b/airflow/providers/amazon/aws/sensors/glue.py
index 7b2ce30..92876c3 100644
--- a/airflow/providers/amazon/aws/sensors/glue.py
+++ b/airflow/providers/amazon/aws/sensors/glue.py
@@ -36,7 +36,7 @@ class AwsGlueJobSensor(BaseSensorOperator):
template_fields = ('job_name', 'run_id')
@apply_defaults
- def __init__(self, *, job_name, run_id, aws_conn_id='aws_default',
**kwargs):
+ def __init__(self, *, job_name: str, run_id: str, aws_conn_id: str =
'aws_default', **kwargs):
super().__init__(**kwargs)
self.job_name = job_name
self.run_id = run_id
diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
index 7292626..3849094 100644
--- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
+++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional
from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
@@ -59,12 +60,12 @@ class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
def __init__(
self,
*,
- table_name,
- expression="ds='{{ ds }}'",
- aws_conn_id='aws_default',
- region_name=None,
- database_name='default',
- poke_interval=60 * 3,
+ table_name: str,
+ expression: str = "ds='{{ ds }}'",
+ aws_conn_id: str = 'aws_default',
+ region_name: Optional[str] = None,
+ database_name: str = 'default',
+ poke_interval: int = 60 * 3,
**kwargs,
):
super().__init__(poke_interval=poke_interval, **kwargs)
@@ -73,7 +74,7 @@ class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
self.table_name = table_name
self.expression = expression
self.database_name = database_name
- self.hook = None
+ self.hook: Optional[AwsGlueCatalogHook] = None
def poke(self, context):
"""Checks for existence of the partition in the AWS Glue Catalog
table"""
@@ -85,8 +86,10 @@ class AwsGlueCatalogPartitionSensor(BaseSensorOperator):
return self.get_hook().check_for_partition(self.database_name,
self.table_name, self.expression)
- def get_hook(self):
+ def get_hook(self) -> AwsGlueCatalogHook:
"""Gets the AwsGlueCatalogHook"""
- if not self.hook:
- self.hook = AwsGlueCatalogHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
+ if self.hook:
+ return self.hook
+
+ self.hook = AwsGlueCatalogHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
return self.hook
diff --git a/airflow/providers/amazon/aws/sensors/redshift.py
b/airflow/providers/amazon/aws/sensors/redshift.py
index 37f3521..106801a 100644
--- a/airflow/providers/amazon/aws/sensors/redshift.py
+++ b/airflow/providers/amazon/aws/sensors/redshift.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional
from airflow.providers.amazon.aws.hooks.redshift import RedshiftHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
@@ -34,19 +35,28 @@ class AwsRedshiftClusterSensor(BaseSensorOperator):
template_fields = ('cluster_identifier', 'target_status')
@apply_defaults
- def __init__(self, *, cluster_identifier, target_status='available',
aws_conn_id='aws_default', **kwargs):
+ def __init__(
+ self,
+ *,
+ cluster_identifier: str,
+ target_status: str = 'available',
+ aws_conn_id: str = 'aws_default',
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.target_status = target_status
self.aws_conn_id = aws_conn_id
- self.hook = None
+ self.hook: Optional[RedshiftHook] = None
def poke(self, context):
self.log.info('Poking for status : %s\nfor cluster %s',
self.target_status, self.cluster_identifier)
return self.get_hook().cluster_status(self.cluster_identifier) ==
self.target_status
- def get_hook(self):
+ def get_hook(self) -> RedshiftHook:
"""Create and return a RedshiftHook"""
- if not self.hook:
- self.hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+ if self.hook:
+ return self.hook
+
+ self.hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
return self.hook
diff --git a/airflow/providers/amazon/aws/sensors/s3_key.py
b/airflow/providers/amazon/aws/sensors/s3_key.py
index 0c0f6e3..9eab08e 100644
--- a/airflow/providers/amazon/aws/sensors/s3_key.py
+++ b/airflow/providers/amazon/aws/sensors/s3_key.py
@@ -15,8 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-
+from typing import Optional, Union
from urllib.parse import urlparse
from airflow.exceptions import AirflowException
@@ -62,11 +61,11 @@ class S3KeySensor(BaseSensorOperator):
def __init__(
self,
*,
- bucket_key,
- bucket_name=None,
- wildcard_match=False,
- aws_conn_id='aws_default',
- verify=None,
+ bucket_key: str,
+ bucket_name: Optional[str] = None,
+ wildcard_match: bool = False,
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[str, bool]] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -91,7 +90,7 @@ class S3KeySensor(BaseSensorOperator):
self.wildcard_match = wildcard_match
self.aws_conn_id = aws_conn_id
self.verify = verify
- self.hook = None
+ self.hook: Optional[S3Hook] = None
def poke(self, context):
self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key)
@@ -99,8 +98,10 @@ class S3KeySensor(BaseSensorOperator):
return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name)
return self.get_hook().check_for_key(self.bucket_key, self.bucket_name)
- def get_hook(self):
+ def get_hook(self) -> S3Hook:
"""Create and return an S3Hook"""
- if not self.hook:
- self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+ if self.hook:
+ return self.hook
+
+ self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
return self.hook
diff --git a/airflow/providers/amazon/aws/sensors/s3_prefix.py b/airflow/providers/amazon/aws/sensors/s3_prefix.py
index 4dc4900..13fb37c 100644
--- a/airflow/providers/amazon/aws/sensors/s3_prefix.py
+++ b/airflow/providers/amazon/aws/sensors/s3_prefix.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional, Union
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
@@ -56,7 +57,14 @@ class S3PrefixSensor(BaseSensorOperator):
@apply_defaults
def __init__(
- self, *, bucket_name, prefix, delimiter='/', aws_conn_id='aws_default', verify=None, **kwargs
+ self,
+ *,
+ bucket_name: str,
+ prefix: str,
+ delimiter: str = '/',
+ aws_conn_id: str = 'aws_default',
+ verify: Optional[Union[str, bool]] = None,
+ **kwargs,
):
super().__init__(**kwargs)
# Parse
@@ -66,7 +74,7 @@ class S3PrefixSensor(BaseSensorOperator):
self.full_url = "s3://" + bucket_name + '/' + prefix
self.aws_conn_id = aws_conn_id
self.verify = verify
- self.hook = None
+ self.hook: Optional[S3Hook] = None
def poke(self, context):
self.log.info('Poking for prefix : %s in bucket s3://%s', self.prefix, self.bucket_name)
@@ -74,8 +82,10 @@ class S3PrefixSensor(BaseSensorOperator):
prefix=self.prefix, delimiter=self.delimiter,
bucket_name=self.bucket_name
)
- def get_hook(self):
+ def get_hook(self) -> S3Hook:
"""Create and return an S3Hook"""
- if not self.hook:
- self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+ if self.hook:
+ return self.hook
+
+ self.hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
return self.hook
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_base.py b/airflow/providers/amazon/aws/sensors/sagemaker_base.py
index 6704b1a..f55b6cc 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_base.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_base.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional, Set
+
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
@@ -32,15 +34,17 @@ class SageMakerBaseSensor(BaseSensorOperator):
ui_color = '#ededed'
@apply_defaults
- def __init__(self, *, aws_conn_id='aws_default', **kwargs):
+ def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
- self.hook = None
+ self.hook: Optional[SageMakerHook] = None
- def get_hook(self):
+ def get_hook(self) -> SageMakerHook:
"""Get SageMakerHook"""
- if not self.hook:
- self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
+ if self.hook:
+ return self.hook
+
+ self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
return self.hook
def poke(self, context):
@@ -62,22 +66,22 @@ class SageMakerBaseSensor(BaseSensorOperator):
raise AirflowException('Sagemaker job failed for the following reason: %s' % failed_reason)
return True
- def non_terminal_states(self):
+ def non_terminal_states(self) -> Set[str]:
"""Placeholder for returning states with should not terminate."""
raise NotImplementedError('Please implement non_terminal_states() in subclass')
- def failed_states(self):
+ def failed_states(self) -> Set[str]:
"""Placeholder for returning states with are considered failed."""
raise NotImplementedError('Please implement failed_states() in subclass')
- def get_sagemaker_response(self):
+ def get_sagemaker_response(self) -> Optional[dict]:
"""Placeholder for checking status of a SageMaker task."""
raise NotImplementedError('Please implement get_sagemaker_response() in subclass')
- def get_failed_reason_from_response(self, response): # pylint: disable=unused-argument
+ def get_failed_reason_from_response(self, response: dict) -> str: # pylint: disable=unused-argument
"""Placeholder for extracting the reason for failure from an AWS response."""
return 'Unknown'
- def state_from_response(self, response):
+ def state_from_response(self, response: dict) -> str:
"""Placeholder for extracting the state from an AWS response."""
raise NotImplementedError('Please implement state_from_response() in subclass')
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_training.py b/airflow/providers/amazon/aws/sensors/sagemaker_training.py
index 36403b8..9cd7668 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_training.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_training.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Optional
import time
@@ -44,13 +45,13 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
self.print_log = print_log
self.positions = {}
self.stream_names = []
- self.instance_count = None
- self.state = None
+ self.instance_count: Optional[int] = None
+ self.state: Optional[int] = None
self.last_description = None
self.last_describe_job_call = None
self.log_resource_inited = False
- def init_log_resource(self, hook):
+ def init_log_resource(self, hook: SageMakerHook) -> None:
"""Set tailing LogState for associated training job."""
description = hook.describe_training_job(self.job_name)
self.instance_count = description['ResourceConfig']['InstanceCount']
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
index 4108c98..a751e56 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
@@ -35,7 +35,7 @@ class SageMakerTransformSensor(SageMakerBaseSensor):
template_ext = ()
@apply_defaults
- def __init__(self, *, job_name, **kwargs):
+ def __init__(self, *, job_name: str, **kwargs):
super().__init__(**kwargs)
self.job_name = job_name
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
index 794695b..96080e0 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
@@ -35,7 +35,7 @@ class SageMakerTuningSensor(SageMakerBaseSensor):
template_ext = ()
@apply_defaults
- def __init__(self, *, job_name, **kwargs):
+ def __init__(self, *, job_name: str, **kwargs):
super().__init__(**kwargs)
self.job_name = job_name
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index b6c467d..e7a250d 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""Reads and then deletes the message from SQS queue"""
+from typing import Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sqs import SQSHook
@@ -43,14 +44,20 @@ class SQSSensor(BaseSensorOperator):
@apply_defaults
def __init__(
- self, *, sqs_queue, aws_conn_id='aws_default', max_messages=5, wait_time_seconds=1, **kwargs
+ self,
+ *,
+ sqs_queue,
+ aws_conn_id: str = 'aws_default',
+ max_messages: int = 5,
+ wait_time_seconds: int = 1,
+ **kwargs,
):
super().__init__(**kwargs)
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
self.max_messages = max_messages
self.wait_time_seconds = wait_time_seconds
- self.hook = None
+ self.hook: Optional[SQSHook] = None
def poke(self, context):
"""
@@ -90,8 +97,10 @@ class SQSSensor(BaseSensorOperator):
return False
- def get_hook(self):
+ def get_hook(self) -> SQSHook:
"""Create and return an SQSHook"""
- if not self.hook:
- self.hook = SQSHook(aws_conn_id=self.aws_conn_id)
+ if self.hook:
+ return self.hook
+
+ self.hook = SQSHook(aws_conn_id=self.aws_conn_id)
return self.hook
diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py
index 6126670..75c7e8b 100644
--- a/airflow/providers/amazon/aws/sensors/step_function_execution.py
+++ b/airflow/providers/amazon/aws/sensors/step_function_execution.py
@@ -16,6 +16,7 @@
# under the License.
import json
+from typing import Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
@@ -51,12 +52,19 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
ui_color = '#66c3ff'
@apply_defaults
- def __init__(self, *, execution_arn: str, aws_conn_id='aws_default', region_name=None, **kwargs):
+ def __init__(
+ self,
+ *,
+ execution_arn: str,
+ aws_conn_id: str = 'aws_default',
+ region_name: Optional[str] = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.execution_arn = execution_arn
self.aws_conn_id = aws_conn_id
self.region_name = region_name
- self.hook = None
+ self.hook: Optional[StepFunctionHook] = None
def poke(self, context):
execution_status = self.get_hook().describe_execution(self.execution_arn)
@@ -73,8 +81,10 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
self.xcom_push(context, 'output', output)
return True
- def get_hook(self):
+ def get_hook(self) -> StepFunctionHook:
"""Create and return a StepFunctionHook"""
- if not self.hook:
- self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+ if self.hook:
+ return self.hook
+
+ self.hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
return self.hook
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index acc87a1..505a71a 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -86,7 +86,7 @@ class TestAwsBatchOperator(unittest.TestCase):
self.assertEqual(self.batch.waiters, None)
self.assertEqual(self.batch.hook.max_retries, self.MAX_RETRIES)
self.assertEqual(self.batch.hook.status_retries, self.STATUS_RETRIES)
- self.assertEqual(self.batch.parameters, None)
+ self.assertEqual(self.batch.parameters, {})
self.assertEqual(self.batch.overrides, {})
self.assertEqual(self.batch.array_properties, {})
self.assertEqual(self.batch.hook.region_name, "eu-west-1")
@@ -121,7 +121,7 @@ class TestAwsBatchOperator(unittest.TestCase):
containerOverrides={},
jobDefinition="hello-world",
arrayProperties={},
- parameters=None,
+ parameters={},
)
self.assertEqual(self.batch.job_id, JOB_ID)
@@ -140,7 +140,7 @@ class TestAwsBatchOperator(unittest.TestCase):
containerOverrides={},
jobDefinition="hello-world",
arrayProperties={},
- parameters=None,
+ parameters={},
)
@mock.patch.object(AwsBatchClientHook, "check_job_success")