This is an automated email from the ASF dual-hosted git repository.
turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new c8ee455 Refactor DataprocCreateCluster operator to use simpler
interface (#10403)
c8ee455 is described below
commit c8ee4556851c36b3b6e644a7746a49583dd53db1
Author: Tomek Urbaszek <[email protected]>
AuthorDate: Mon Sep 7 12:21:00 2020 +0200
Refactor DataprocCreateCluster operator to use simpler interface (#10403)
DataprocCreateCluster now requires:
- cluster config
- cluster name
- project id
In this way users don't have to pass project_id twice
(in the cluster definition and as a parameter). The cluster object
is built in the hook's create_cluster method
---
.../google/cloud/example_dags/example_dataproc.py | 42 ++--
airflow/providers/google/cloud/hooks/dataproc.py | 29 ++-
.../providers/google/cloud/operators/dataproc.py | 216 ++++++++-----------
.../providers/google/cloud/hooks/test_dataproc.py | 45 ++--
.../google/cloud/operators/test_dataproc.py | 239 ++++++++++-----------
.../google/cloud/operators/test_dataproc_system.py | 13 +-
tests/test_utils/gcp_system_helpers.py | 2 +-
7 files changed, 276 insertions(+), 310 deletions(-)
diff --git a/airflow/providers/google/cloud/example_dags/example_dataproc.py
b/airflow/providers/google/cloud/example_dags/example_dataproc.py
index 494844c..abf94ba 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataproc.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataproc.py
@@ -35,7 +35,7 @@ from airflow.utils.dates import days_ago
PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id")
CLUSTER_NAME = os.environ.get("GCP_DATAPROC_CLUSTER_NAME", "example-project")
REGION = os.environ.get("GCP_LOCATION", "europe-west1")
-ZONE = os.environ.get("GCP_REGION", "europe-west-1b")
+ZONE = os.environ.get("GCP_REGION", "europe-west1-b")
BUCKET = os.environ.get("GCP_DATAPROC_BUCKET", "dataproc-system-tests")
OUTPUT_FOLDER = "wordcount"
OUTPUT_PATH = "gs://{}/{}/".format(BUCKET, OUTPUT_FOLDER)
@@ -47,20 +47,16 @@ SPARKR_URI = "gs://{}/{}".format(BUCKET, SPARKR_MAIN)
# Cluster definition
# [START how_to_cloud_dataproc_create_cluster]
-CLUSTER = {
- "project_id": PROJECT_ID,
- "cluster_name": CLUSTER_NAME,
- "config": {
- "master_config": {
- "num_instances": 1,
- "machine_type_uri": "n1-standard-4",
- "disk_config": {"boot_disk_type": "pd-standard",
"boot_disk_size_gb": 1024},
- },
- "worker_config": {
- "num_instances": 2,
- "machine_type_uri": "n1-standard-4",
- "disk_config": {"boot_disk_type": "pd-standard",
"boot_disk_size_gb": 1024},
- },
+CLUSTER_CONFIG = {
+ "master_config": {
+ "num_instances": 1,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
1024},
+ },
+ "worker_config": {
+ "num_instances": 2,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
1024},
},
}
@@ -69,10 +65,10 @@ CLUSTER = {
# Update options
# [START how_to_cloud_dataproc_updatemask_cluster_operator]
CLUSTER_UPDATE = {
- "config": {"worker_config": {"num_instances": 3},
"secondary_worker_config": {"num_instances": 3},}
+ "config": {"worker_config": {"num_instances": 3},
"secondary_worker_config": {"num_instances": 3}}
}
UPDATE_MASK = {
- "paths": ["config.worker_config.num_instances",
"config.secondary_worker_config.num_instances",]
+ "paths": ["config.worker_config.num_instances",
"config.secondary_worker_config.num_instances"]
}
# [END how_to_cloud_dataproc_updatemask_cluster_operator]
@@ -141,10 +137,14 @@ HADOOP_JOB = {
}
# [END how_to_cloud_dataproc_hadoop_config]
-with models.DAG("example_gcp_dataproc", start_date=days_ago(1),
schedule_interval=None,) as dag:
+with models.DAG("example_gcp_dataproc", start_date=days_ago(1),
schedule_interval=None) as dag:
# [START how_to_cloud_dataproc_create_cluster_operator]
create_cluster = DataprocCreateClusterOperator(
- task_id="create_cluster", project_id=PROJECT_ID, cluster=CLUSTER,
region=REGION
+ task_id="create_cluster",
+ project_id=PROJECT_ID,
+ cluster_config=CLUSTER_CONFIG,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
)
# [END how_to_cloud_dataproc_create_cluster_operator]
@@ -164,7 +164,7 @@ with models.DAG("example_gcp_dataproc",
start_date=days_ago(1), schedule_interva
task_id="pig_task", job=PIG_JOB, location=REGION, project_id=PROJECT_ID
)
spark_sql_task = DataprocSubmitJobOperator(
- task_id="spark_sql_task", job=SPARK_SQL_JOB, location=REGION,
project_id=PROJECT_ID,
+ task_id="spark_sql_task", job=SPARK_SQL_JOB, location=REGION,
project_id=PROJECT_ID
)
spark_task = DataprocSubmitJobOperator(
@@ -205,7 +205,7 @@ with models.DAG("example_gcp_dataproc",
start_date=days_ago(1), schedule_interva
# [START how_to_cloud_dataproc_delete_cluster_operator]
delete_cluster = DataprocDeleteClusterOperator(
- task_id="delete_cluster", project_id=PROJECT_ID,
cluster_name=CLUSTER_NAME, region=REGION,
+ task_id="delete_cluster", project_id=PROJECT_ID,
cluster_name=CLUSTER_NAME, region=REGION
)
# [END how_to_cloud_dataproc_delete_cluster_operator]
diff --git a/airflow/providers/google/cloud/hooks/dataproc.py
b/airflow/providers/google/cloud/hooks/dataproc.py
index 39aec10..0f9685b 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -63,7 +63,7 @@ class DataProcJobBuilder:
self.job_type = job_type
self.job = {
"job": {
- "reference": {"project_id": project_id, "job_id": name,},
+ "reference": {"project_id": project_id, "job_id": name},
"placement": {"cluster_name": cluster_name},
"labels": {'airflow-version': 'v' +
airflow_version.replace('.', '-').replace('+', '-')},
job_type: {},
@@ -250,8 +250,10 @@ class DataprocHook(GoogleBaseHook):
def create_cluster(
self,
region: str,
- cluster: Union[Dict, Cluster],
project_id: str,
+ cluster_name: str,
+ cluster_config: Union[Dict, Cluster],
+ labels: Optional[Dict[str, str]] = None,
request_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
@@ -264,10 +266,14 @@ class DataprocHook(GoogleBaseHook):
:type project_id: str
:param region: Required. The Cloud Dataproc region in which to handle
the request.
:type region: str
- :param cluster: Required. The cluster to create.
+ :param cluster_name: Name of the cluster to create
+ :type cluster_name: str
+ :param labels: Labels that will be assigned to created cluster
+ :type labels: Dict[str, str]
+ :param cluster_config: Required. The cluster config to create.
If a dict is provided, it must be of the same form as the protobuf
message
- :class:`~google.cloud.dataproc_v1.types.Cluster`
- :type cluster: Union[Dict, google.cloud.dataproc_v1.types.Cluster]
+ :class:`~google.cloud.dataproc_v1.types.ClusterConfig`
+ :type cluster_config: Union[Dict,
google.cloud.dataproc_v1.types.ClusterConfig]
:param request_id: Optional. A unique id used to identify the request.
If the server receives two
``CreateClusterRequest`` requests with the same id, then the
second request will be ignored and
the first ``google.longrunning.Operation`` created and stored in
the backend is returned.
@@ -281,6 +287,19 @@ class DataprocHook(GoogleBaseHook):
:param metadata: Additional metadata that is provided to the method.
:type metadata: Sequence[Tuple[str, str]]
"""
+ # Dataproc labels must conform to the following regex:
+ # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows
+ # semantic versioning spec: x.y.z).
+ labels = labels or {}
+ labels.update({'airflow-version': 'v' + airflow_version.replace('.',
'-').replace('+', '-')})
+
+ cluster = {
+ "project_id": project_id,
+ "cluster_name": cluster_name,
+ "config": cluster_config,
+ "labels": labels,
+ }
+
client = self.get_cluster_client(location=region)
result = client.create_cluster(
project_id=project_id,
diff --git a/airflow/providers/google/cloud/operators/dataproc.py
b/airflow/providers/google/cloud/operators/dataproc.py
index 1438103..a8b3780 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -46,7 +46,6 @@ from airflow.providers.google.cloud.hooks.dataproc import
DataprocHook, DataProc
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.utils import timezone
from airflow.utils.decorators import apply_defaults
-from airflow.version import version as airflow_version
# pylint: disable=too-many-instance-attributes
@@ -157,8 +156,7 @@ class ClusterGenerator:
# pylint: disable=too-many-arguments,too-many-locals
def __init__(
self,
- project_id: Optional[str] = None,
- cluster_name: Optional[str] = None,
+ project_id: str,
num_workers: Optional[int] = None,
zone: Optional[str] = None,
network_uri: Optional[str] = None,
@@ -183,8 +181,6 @@ class ClusterGenerator:
worker_disk_type: str = 'pd-standard',
worker_disk_size: int = 1024,
num_preemptible_workers: int = 0,
- labels: Optional[Dict] = None,
- region: Optional[str] = None,
service_account: Optional[str] = None,
service_account_scopes: Optional[List[str]] = None,
idle_delete_ttl: Optional[int] = None,
@@ -194,9 +190,7 @@ class ClusterGenerator:
**kwargs,
) -> None:
- self.cluster_name = cluster_name
self.project_id = project_id
- self.region = region
self.num_masters = num_masters
self.num_workers = num_workers
self.num_preemptible_workers = num_preemptible_workers
@@ -216,7 +210,6 @@ class ClusterGenerator:
self.worker_machine_type = worker_machine_type
self.worker_disk_type = worker_disk_type
self.worker_disk_size = worker_disk_size
- self.labels = labels
self.zone = zone
self.network_uri = network_uri
self.subnetwork_uri = subnetwork_uri
@@ -239,11 +232,11 @@ class ClusterGenerator:
def _get_init_action_timeout(self):
match = re.match(r"^(\d+)([sm])$", self.init_action_timeout)
if match:
+ val = float(match.group(1))
if match.group(2) == "s":
- return self.init_action_timeout
+ return {"seconds": int(val)}
elif match.group(2) == "m":
- val = float(match.group(1))
- return
"{}s".format(int(timedelta(minutes=val).total_seconds()))
+ return {"seconds": int(timedelta(minutes=val).total_seconds())}
raise AirflowException(
"DataprocClusterCreateOperator init_action_timeout"
@@ -255,93 +248,85 @@ class ClusterGenerator:
zone_uri =
'https://www.googleapis.com/compute/v1/projects/{}/zones/{}'.format(
self.project_id, self.zone
)
- cluster_data['config']['gce_cluster_config']['zone_uri'] = zone_uri
+ cluster_data['gce_cluster_config']['zone_uri'] = zone_uri
if self.metadata:
- cluster_data['config']['gce_cluster_config']['metadata'] =
self.metadata
+ cluster_data['gce_cluster_config']['metadata'] = self.metadata
if self.network_uri:
- cluster_data['config']['gce_cluster_config']['network_uri'] =
self.network_uri
+ cluster_data['gce_cluster_config']['network_uri'] =
self.network_uri
if self.subnetwork_uri:
- cluster_data['config']['gce_cluster_config']['subnetwork_uri'] =
self.subnetwork_uri
+ cluster_data['gce_cluster_config']['subnetwork_uri'] =
self.subnetwork_uri
if self.internal_ip_only:
if not self.subnetwork_uri:
raise AirflowException("Set internal_ip_only to true only
when" " you pass a subnetwork_uri.")
- cluster_data['config']['gce_cluster_config']['internal_ip_only'] =
True
+ cluster_data['gce_cluster_config']['internal_ip_only'] = True
if self.tags:
- cluster_data['config']['gce_cluster_config']['tags'] = self.tags
+ cluster_data['gce_cluster_config']['tags'] = self.tags
if self.service_account:
- cluster_data['config']['gce_cluster_config']['service_account'] =
self.service_account
+ cluster_data['gce_cluster_config']['service_account'] =
self.service_account
if self.service_account_scopes:
- cluster_data['config']['gce_cluster_config'][
- 'service_account_scopes'
- ] = self.service_account_scopes
+ cluster_data['gce_cluster_config']['service_account_scopes'] =
self.service_account_scopes
return cluster_data
def _build_lifecycle_config(self, cluster_data):
if self.idle_delete_ttl:
- cluster_data['config']['lifecycle_config']['idle_delete_ttl'] =
"{}s".format(self.idle_delete_ttl)
+ cluster_data['lifecycle_config']['idle_delete_ttl'] = {"seconds":
self.idle_delete_ttl}
if self.auto_delete_time:
utc_auto_delete_time =
timezone.convert_to_utc(self.auto_delete_time)
- cluster_data['config']['lifecycle_config']['auto_delete_time'] =
utc_auto_delete_time.strftime(
+ cluster_data['lifecycle_config']['auto_delete_time'] =
utc_auto_delete_time.strftime(
'%Y-%m-%dT%H:%M:%S.%fZ'
)
elif self.auto_delete_ttl:
- cluster_data['config']['lifecycle_config']['auto_delete_ttl'] =
"{}s".format(self.auto_delete_ttl)
+ cluster_data['lifecycle_config']['auto_delete_ttl'] = {"seconds":
int(self.auto_delete_ttl)}
return cluster_data
def _build_cluster_data(self):
if self.zone:
master_type_uri = (
- "https://www.googleapis.com/compute/v1/projects"
-
f"/{self.project_id}/zones/{self.zone}/machineTypes/{self.master_machine_type}"
+
f"projects/{self.project_id}/zones/{self.zone}/machineTypes/{self.master_machine_type}"
)
worker_type_uri = (
- "https://www.googleapis.com/compute/v1/projects"
-
f"/{self.project_id}/zones/{self.zone}/machineTypes/{self.worker_machine_type}"
+
f"projects/{self.project_id}/zones/{self.zone}/machineTypes/{self.worker_machine_type}"
)
else:
master_type_uri = self.master_machine_type
worker_type_uri = self.worker_machine_type
cluster_data = {
- 'project_id': self.project_id,
- 'cluster_name': self.cluster_name,
- 'config': {
- 'gce_cluster_config': {},
- 'master_config': {
- 'num_instances': self.num_masters,
- 'machine_type_uri': master_type_uri,
- 'disk_config': {
- 'boot_disk_type': self.master_disk_type,
- 'boot_disk_size_gb': self.master_disk_size,
- },
+ 'gce_cluster_config': {},
+ 'master_config': {
+ 'num_instances': self.num_masters,
+ 'machine_type_uri': master_type_uri,
+ 'disk_config': {
+ 'boot_disk_type': self.master_disk_type,
+ 'boot_disk_size_gb': self.master_disk_size,
},
- 'worker_config': {
- 'num_instances': self.num_workers,
- 'machine_type_uri': worker_type_uri,
- 'disk_config': {
- 'boot_disk_type': self.worker_disk_type,
- 'boot_disk_size_gb': self.worker_disk_size,
- },
+ },
+ 'worker_config': {
+ 'num_instances': self.num_workers,
+ 'machine_type_uri': worker_type_uri,
+ 'disk_config': {
+ 'boot_disk_type': self.worker_disk_type,
+ 'boot_disk_size_gb': self.worker_disk_size,
},
- 'secondary_worker_config': {},
- 'software_config': {},
- 'lifecycle_config': {},
- 'encryption_config': {},
- 'autoscaling_config': {},
},
+ 'secondary_worker_config': {},
+ 'software_config': {},
+ 'lifecycle_config': {},
+ 'encryption_config': {},
+ 'autoscaling_config': {},
}
if self.num_preemptible_workers > 0:
- cluster_data['config']['secondary_worker_config'] = {
+ cluster_data['secondary_worker_config'] = {
'num_instances': self.num_preemptible_workers,
'machine_type_uri': worker_type_uri,
'disk_config': {
@@ -351,19 +336,11 @@ class ClusterGenerator:
'is_preemptible': True,
}
- cluster_data['labels'] = self.labels or {}
-
- # Dataproc labels must conform to the following regex:
- # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows
- # semantic versioning spec: x.y.z).
- cluster_data['labels'].update(
- {'airflow-version': 'v' + airflow_version.replace('.',
'-').replace('+', '-')}
- )
if self.storage_bucket:
- cluster_data['config']['config_bucket'] = self.storage_bucket
+ cluster_data['config_bucket'] = self.storage_bucket
if self.image_version:
- cluster_data['config']['software_config']['image_version'] =
self.image_version
+ cluster_data['software_config']['image_version'] =
self.image_version
elif self.custom_image:
project_id = self.custom_image_project_id or self.project_id
@@ -371,9 +348,9 @@ class ClusterGenerator:
'https://www.googleapis.com/compute/beta/projects/'
'{}/global/images/{}'.format(project_id, self.custom_image)
)
- cluster_data['config']['master_config']['image_uri'] =
custom_image_url
+ cluster_data['master_config']['image_uri'] = custom_image_url
if not self.single_node:
- cluster_data['config']['worker_config']['image_uri'] =
custom_image_url
+ cluster_data['worker_config']['image_uri'] = custom_image_url
cluster_data = self._build_gce_cluster_config(cluster_data)
@@ -381,10 +358,10 @@ class ClusterGenerator:
self.properties["dataproc:dataproc.allow.zero.workers"] = "true"
if self.properties:
- cluster_data['config']['software_config']['properties'] =
self.properties
+ cluster_data['software_config']['properties'] = self.properties
if self.optional_components:
- cluster_data['config']['software_config']['optional_components'] =
self.optional_components
+ cluster_data['software_config']['optional_components'] =
self.optional_components
cluster_data = self._build_lifecycle_config(cluster_data)
@@ -393,12 +370,12 @@ class ClusterGenerator:
{'executable_file': uri, 'execution_timeout':
self._get_init_action_timeout()}
for uri in self.init_actions_uris
]
- cluster_data['config']['initialization_actions'] =
init_actions_dict
+ cluster_data['initialization_actions'] = init_actions_dict
if self.customer_managed_key:
- cluster_data['config']['encryption_config'] =
{'gce_pd_kms_key_name': self.customer_managed_key}
+ cluster_data['encryption_config'] = {'gce_pd_kms_key_name':
self.customer_managed_key}
if self.autoscaling_policy:
- cluster_data['config']['autoscaling_config'] = {'policy_uri':
self.autoscaling_policy}
+ cluster_data['autoscaling_config'] = {'policy_uri':
self.autoscaling_policy}
return cluster_data
@@ -435,6 +412,14 @@ class DataprocCreateClusterOperator(BaseOperator):
:param project_id: The ID of the google cloud project in which
to create the cluster. (templated)
:type project_id: str
+ :param cluster_name: Name of the cluster to create
+ :type cluster_name: str
+ :param labels: Labels that will be assigned to created cluster
+ :type labels: Dict[str, str]
+ :param cluster_config: Required. The cluster config to create.
+ If a dict is provided, it must be of the same form as the protobuf
message
+ :class:`~google.cloud.dataproc_v1.types.ClusterConfig`
+ :type cluster_config: Union[Dict,
google.cloud.dataproc_v1.types.ClusterConfig]
:param region: leave as 'global', might become relevant in the future.
(templated)
:type region: str
:parm delete_on_error: If true the cluster will be deleted if created with
ERROR state. Default
@@ -470,7 +455,9 @@ class DataprocCreateClusterOperator(BaseOperator):
template_fields = (
'project_id',
'region',
- 'cluster',
+ 'cluster_config',
+ 'cluster_name',
+ 'labels',
'impersonation_chain',
)
@@ -478,9 +465,11 @@ class DataprocCreateClusterOperator(BaseOperator):
def __init__( # pylint: disable=too-many-arguments
self,
*,
+ cluster_name: str,
region: str = 'global',
project_id: Optional[str] = None,
- cluster: Optional[Dict] = None,
+ cluster_config: Optional[Dict] = None,
+ labels: Optional[Dict] = None,
request_id: Optional[str] = None,
delete_on_error: bool = True,
use_if_exists: bool = True,
@@ -492,10 +481,10 @@ class DataprocCreateClusterOperator(BaseOperator):
**kwargs,
) -> None:
# TODO: remove one day
- if cluster is None:
+ if cluster_config is None:
warnings.warn(
"Passing cluster parameters by keywords to `{}` "
- "will be deprecated. Please provide cluster object using
`cluster` parameter. "
+ "will be deprecated. Please provide cluster_config object
using `cluster_config` parameter. "
"You can use
`airflow.dataproc.ClusterGenerator.generate_cluster` method to "
"obtain cluster object.".format(type(self).__name__),
DeprecationWarning,
@@ -506,9 +495,12 @@ class DataprocCreateClusterOperator(BaseOperator):
del kwargs['params']
# Create cluster object from kwargs
- kwargs['region'] = region
- kwargs['project_id'] = project_id
- cluster = ClusterGenerator(**kwargs).make()
+ if project_id is None:
+ raise AirflowException(
+ "project_id argument is required when building cluster
from keywords parameters"
+ )
+ kwargs["project_id"] = project_id
+ cluster_config = ClusterGenerator(**kwargs).make()
# Remove from kwargs cluster params passed for backward
compatibility
cluster_params =
inspect.signature(ClusterGenerator.__init__).parameters
@@ -518,11 +510,9 @@ class DataprocCreateClusterOperator(BaseOperator):
super().__init__(**kwargs)
- self.cluster = cluster
- try:
- self.cluster_name = cluster['cluster_name']
- except KeyError:
- raise AirflowException("`config` misses `cluster_name` key")
+ self.cluster_config = cluster_config
+ self.cluster_name = cluster_name
+ self.labels = labels
self.project_id = project_id
self.region = region
self.request_id = request_id
@@ -534,11 +524,13 @@ class DataprocCreateClusterOperator(BaseOperator):
self.use_if_exists = use_if_exists
self.impersonation_chain = impersonation_chain
- def _create_cluster(self, hook):
+ def _create_cluster(self, hook: DataprocHook):
operation = hook.create_cluster(
project_id=self.project_id,
region=self.region,
- cluster=self.cluster,
+ cluster_name=self.cluster_name,
+ labels=self.labels,
+ cluster_config=self.cluster_config,
request_id=self.request_id,
retry=self.retry,
timeout=self.timeout,
@@ -550,9 +542,7 @@ class DataprocCreateClusterOperator(BaseOperator):
def _delete_cluster(self, hook):
self.log.info("Deleting the cluster")
- hook.delete_cluster(
- region=self.region, cluster_name=self.cluster_name,
project_id=self.project_id,
- )
+ hook.delete_cluster(region=self.region,
cluster_name=self.cluster_name, project_id=self.project_id)
def _get_cluster(self, hook: DataprocHook):
return hook.get_cluster(
@@ -569,7 +559,7 @@ class DataprocCreateClusterOperator(BaseOperator):
return
self.log.info("Cluster is in ERROR state")
gcs_uri = hook.diagnose_cluster(
- region=self.region, cluster_name=self.cluster_name,
project_id=self.project_id,
+ region=self.region, cluster_name=self.cluster_name,
project_id=self.project_id
)
self.log.info('Diagnostic information for cluster %s available at:
%s', self.cluster_name, gcs_uri)
if self.delete_on_error:
@@ -604,7 +594,7 @@ class DataprocCreateClusterOperator(BaseOperator):
def execute(self, context):
self.log.info('Creating cluster: %s', self.cluster_name)
- hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,)
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
try:
# First try to create a new cluster
cluster = self._create_cluster(hook)
@@ -677,12 +667,7 @@ class DataprocScaleClusterOperator(BaseOperator):
:type impersonation_chain: Union[str, Sequence[str]]
"""
- template_fields = [
- 'cluster_name',
- 'project_id',
- 'region',
- 'impersonation_chain',
- ]
+ template_fields = ['cluster_name', 'project_id', 'region',
'impersonation_chain']
@apply_defaults
def __init__(
@@ -727,7 +712,7 @@ class DataprocScaleClusterOperator(BaseOperator):
return scale_data
@property
- def _graceful_decommission_timeout_object(self) -> Optional[Dict]:
+ def _graceful_decommission_timeout_object(self) -> Optional[Dict[str,
int]]:
if not self.graceful_decommission_timeout:
return None
@@ -764,7 +749,7 @@ class DataprocScaleClusterOperator(BaseOperator):
scaling_cluster_data = self._build_scale_cluster_data()
update_mask = ["config.worker_config.num_instances",
"config.secondary_worker_config.num_instances"]
- hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,)
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
operation = hook.update_cluster(
project_id=self.project_id,
location=self.region,
@@ -846,7 +831,7 @@ class DataprocDeleteClusterOperator(BaseOperator):
self.impersonation_chain = impersonation_chain
def execute(self, context: Dict):
- hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,)
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
self.log.info("Deleting cluster: %s", self.cluster_name)
operation = hook.delete_cluster(
project_id=self.project_id,
@@ -983,7 +968,7 @@ class DataprocJobBaseOperator(BaseOperator):
self.dataproc_job_id = self.job["job"]["reference"]["job_id"]
self.log.info('Submitting %s job %s', self.job_type,
self.dataproc_job_id)
job_object = self.hook.submit_job(
- project_id=self.project_id, job=self.job["job"],
location=self.region,
+ project_id=self.project_id, job=self.job["job"],
location=self.region
)
job_id = job_object.reference.job_id
self.log.info('Job %s submitted successfully.', job_id)
@@ -1060,10 +1045,7 @@ class
DataprocSubmitPigJobOperator(DataprocJobBaseOperator):
'dataproc_properties',
'impersonation_chain',
]
- template_ext = (
- '.pg',
- '.pig',
- )
+ template_ext = ('.pg', '.pig')
ui_color = '#0273d4'
job_type = 'pig_job'
@@ -1138,10 +1120,7 @@ class
DataprocSubmitHiveJobOperator(DataprocJobBaseOperator):
'dataproc_properties',
'impersonation_chain',
]
- template_ext = (
- '.q',
- '.hql',
- )
+ template_ext = ('.q', '.hql')
ui_color = '#0273d4'
job_type = 'hive_job'
@@ -1607,10 +1586,7 @@ class
DataprocInstantiateWorkflowTemplateOperator(BaseOperator):
:type impersonation_chain: Union[str, Sequence[str]]
"""
- template_fields = [
- 'template_id',
- 'impersonation_chain',
- ]
+ template_fields = ['template_id', 'impersonation_chain']
@apply_defaults
def __init__( # pylint: disable=too-many-arguments
@@ -1644,7 +1620,7 @@ class
DataprocInstantiateWorkflowTemplateOperator(BaseOperator):
self.impersonation_chain = impersonation_chain
def execute(self, context):
- hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,)
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
self.log.info('Instantiating template %s', self.template_id)
operation = hook.instantiate_workflow_template(
project_id=self.project_id,
@@ -1712,10 +1688,7 @@ class
DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator):
:type impersonation_chain: Union[str, Sequence[str]]
"""
- template_fields = [
- 'template',
- 'impersonation_chain',
- ]
+ template_fields = ['template', 'impersonation_chain']
@apply_defaults
def __init__(
@@ -1746,7 +1719,7 @@ class
DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator):
def execute(self, context):
self.log.info('Instantiating Inline Template')
- hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,)
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
operation = hook.instantiate_inline_workflow_template(
template=self.template,
project_id=self.project_id,
@@ -1802,12 +1775,7 @@ class DataprocSubmitJobOperator(BaseOperator):
:type asynchronous: bool
"""
- template_fields = (
- 'project_id',
- 'location',
- 'job',
- 'impersonation_chain',
- )
+ template_fields = ('project_id', 'location', 'job', 'impersonation_chain')
@apply_defaults
def __init__(
@@ -1839,7 +1807,7 @@ class DataprocSubmitJobOperator(BaseOperator):
def execute(self, context: Dict):
self.log.info("Submitting job")
- hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,)
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
job_object = hook.submit_job(
project_id=self.project_id,
location=self.location,
@@ -1947,7 +1915,7 @@ class DataprocUpdateClusterOperator(BaseOperator):
self.impersonation_chain = impersonation_chain
def execute(self, context: Dict):
- hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,)
+ hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
self.log.info("Updating %s cluster.", self.cluster_name)
operation = hook.update_cluster(
project_id=self.project_id,
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py
b/tests/providers/google/cloud/hooks/test_dataproc.py
index 5205953..a2c1c4d 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -32,8 +32,15 @@ JOB_ID = "test-id"
TASK_ID = "test-task-id"
GCP_LOCATION = "global"
GCP_PROJECT = "test-project"
-CLUSTER = {"test": "test"}
+CLUSTER_CONFIG = {"test": "test"}
+LABELS = {"test": "test"}
CLUSTER_NAME = "cluster-name"
+CLUSTER = {
+ "cluster_name": CLUSTER_NAME,
+ "config": CLUSTER_CONFIG,
+ "labels": LABELS,
+ "project_id": GCP_PROJECT,
+}
PARENT = "parent"
NAME = "name"
@@ -51,9 +58,7 @@ class TestDataprocHook(unittest.TestCase):
self.hook = DataprocHook(gcp_conn_id="test")
@mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
- @mock.patch(
- DATAPROC_STRING.format("DataprocHook.client_info"),
new_callable=mock.PropertyMock,
- )
+ @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
new_callable=mock.PropertyMock)
@mock.patch(DATAPROC_STRING.format("ClusterControllerClient"))
def test_get_cluster_client(self, mock_client, mock_client_info,
mock_get_credentials):
self.hook.get_cluster_client(location=GCP_LOCATION)
@@ -64,20 +69,16 @@ class TestDataprocHook(unittest.TestCase):
)
@mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
- @mock.patch(
- DATAPROC_STRING.format("DataprocHook.client_info"),
new_callable=mock.PropertyMock,
- )
+ @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
new_callable=mock.PropertyMock)
@mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
def test_get_template_client(self, mock_client, mock_client_info,
mock_get_credentials):
_ = self.hook.get_template_client
mock_client.assert_called_once_with(
- credentials=mock_get_credentials.return_value,
client_info=mock_client_info.return_value,
+ credentials=mock_get_credentials.return_value,
client_info=mock_client_info.return_value
)
@mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
- @mock.patch(
- DATAPROC_STRING.format("DataprocHook.client_info"),
new_callable=mock.PropertyMock,
- )
+ @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
new_callable=mock.PropertyMock)
@mock.patch(DATAPROC_STRING.format("JobControllerClient"))
def test_get_job_client(self, mock_client, mock_client_info,
mock_get_credentials):
self.hook.get_job_client(location=GCP_LOCATION)
@@ -89,7 +90,13 @@ class TestDataprocHook(unittest.TestCase):
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
def test_create_cluster(self, mock_client):
- self.hook.create_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION,
cluster=CLUSTER)
+ self.hook.create_cluster(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ cluster_config=CLUSTER_CONFIG,
+ labels=LABELS,
+ )
mock_client.assert_called_once_with(location=GCP_LOCATION)
mock_client.return_value.create_cluster.assert_called_once_with(
project_id=GCP_PROJECT,
@@ -202,13 +209,7 @@ class TestDataprocHook(unittest.TestCase):
)
mock_client.workflow_template_path.assert_called_once_with(GCP_PROJECT,
GCP_LOCATION, template_name)
mock_client.instantiate_workflow_template.assert_called_once_with(
- name=NAME,
- version=None,
- parameters=None,
- request_id=None,
- retry=None,
- timeout=None,
- metadata=None,
+ name=NAME, version=None, parameters=None, request_id=None,
retry=None, timeout=None, metadata=None
)
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
@@ -220,7 +221,7 @@ class TestDataprocHook(unittest.TestCase):
)
mock_client.region_path.assert_called_once_with(GCP_PROJECT,
GCP_LOCATION)
mock_client.instantiate_inline_workflow_template.assert_called_once_with(
- parent=PARENT, template=template, request_id=None, retry=None,
timeout=None, metadata=None,
+ parent=PARENT, template=template, request_id=None, retry=None,
timeout=None, metadata=None
)
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
@@ -230,9 +231,7 @@ class TestDataprocHook(unittest.TestCase):
mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
]
with self.assertRaises(AirflowException):
- self.hook.wait_for_job(
- job_id=JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT,
wait_time=0,
- )
+ self.hook.wait_for_job(job_id=JOB_ID, location=GCP_LOCATION,
project_id=GCP_PROJECT, wait_time=0)
calls = [
mock.call(location=GCP_LOCATION, job_id=JOB_ID,
project_id=GCP_PROJECT),
mock.call(location=GCP_LOCATION, job_id=JOB_ID,
project_id=GCP_PROJECT),
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py
b/tests/providers/google/cloud/operators/test_dataproc.py
index 69f8e9d..aa4a1ec 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -56,58 +56,58 @@ GCP_CONN_ID = "test-conn"
IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
CLUSTER_NAME = "cluster_name"
-CLUSTER = {
- "project_id": "project_id",
- "cluster_name": CLUSTER_NAME,
- "config": {
- "gce_cluster_config": {
- "zone_uri": "https://www.googleapis.com/compute/v1/projects/"
"project_id/zones/zone",
- "metadata": {"metadata": "data"},
- "network_uri": "network_uri",
- "subnetwork_uri": "subnetwork_uri",
- "internal_ip_only": True,
- "tags": ["tags"],
- "service_account": "service_account",
- "service_account_scopes": ["service_account_scopes"],
- },
- "master_config": {
- "num_instances": 2,
- "machine_type_uri":
"https://www.googleapis.com/compute/v1/projects/"
- "project_id/zones/zone/machineTypes/master_machine_type",
- "disk_config": {"boot_disk_type": "master_disk_type",
"boot_disk_size_gb": 128,},
- "image_uri": "https://www.googleapis.com/compute/beta/projects/"
- "custom_image_project_id/global/images/custom_image",
- },
- "worker_config": {
- "num_instances": 2,
- "machine_type_uri":
"https://www.googleapis.com/compute/v1/projects/"
- "project_id/zones/zone/machineTypes/worker_machine_type",
- "disk_config": {"boot_disk_type": "worker_disk_type",
"boot_disk_size_gb": 256,},
- "image_uri": "https://www.googleapis.com/compute/beta/projects/"
- "custom_image_project_id/global/images/custom_image",
- },
- "secondary_worker_config": {
- "num_instances": 4,
- "machine_type_uri":
"https://www.googleapis.com/compute/v1/projects/"
- "project_id/zones/zone/machineTypes/worker_machine_type",
- "disk_config": {"boot_disk_type": "worker_disk_type",
"boot_disk_size_gb": 256,},
- "is_preemptible": True,
- },
- "software_config": {
- "properties": {"properties": "data"},
- "optional_components": ["optional_components"],
- },
- "lifecycle_config": {"idle_delete_ttl": "60s", "auto_delete_time":
"2019-09-12T00:00:00.000000Z",},
- "encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"},
- "autoscaling_config": {"policy_uri": "autoscaling_policy"},
- "config_bucket": "storage_bucket",
- "initialization_actions": [{"executable_file": "init_actions_uris",
"execution_timeout": "600s"}],
+CONFIG = {
+ "gce_cluster_config": {
+ "zone_uri": "https://www.googleapis.com/compute/v1/projects/"
"project_id/zones/zone",
+ "metadata": {"metadata": "data"},
+ "network_uri": "network_uri",
+ "subnetwork_uri": "subnetwork_uri",
+ "internal_ip_only": True,
+ "tags": ["tags"],
+ "service_account": "service_account",
+ "service_account_scopes": ["service_account_scopes"],
},
- "labels": {"labels": "data", "airflow-version": AIRFLOW_VERSION},
+ "master_config": {
+ "num_instances": 2,
+ "machine_type_uri":
"projects/project_id/zones/zone/machineTypes/master_machine_type",
+ "disk_config": {"boot_disk_type": "master_disk_type",
"boot_disk_size_gb": 128},
+ "image_uri": "https://www.googleapis.com/compute/beta/projects/"
+ "custom_image_project_id/global/images/custom_image",
+ },
+ "worker_config": {
+ "num_instances": 2,
+ "machine_type_uri":
"projects/project_id/zones/zone/machineTypes/worker_machine_type",
+ "disk_config": {"boot_disk_type": "worker_disk_type",
"boot_disk_size_gb": 256},
+ "image_uri": "https://www.googleapis.com/compute/beta/projects/"
+ "custom_image_project_id/global/images/custom_image",
+ },
+ "secondary_worker_config": {
+ "num_instances": 4,
+ "machine_type_uri":
"projects/project_id/zones/zone/machineTypes/worker_machine_type",
+ "disk_config": {"boot_disk_type": "worker_disk_type",
"boot_disk_size_gb": 256},
+ "is_preemptible": True,
+ },
+ "software_config": {"properties": {"properties": "data"},
"optional_components": ["optional_components"]},
+ "lifecycle_config": {
+ "idle_delete_ttl": {'seconds': 60},
+ "auto_delete_time": "2019-09-12T00:00:00.000000Z",
+ },
+ "encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"},
+ "autoscaling_config": {"policy_uri": "autoscaling_policy"},
+ "config_bucket": "storage_bucket",
+ "initialization_actions": [
+ {"executable_file": "init_actions_uris", "execution_timeout":
{'seconds': 600}}
+ ],
}
+LABELS = {"labels": "data", "airflow-version": AIRFLOW_VERSION}
+
+LABELS.update({'airflow-version': 'v' + airflow_version.replace('.',
'-').replace('+', '-')})
+
+CLUSTER = {"project_id": "project_id", "cluster_name": CLUSTER_NAME, "config":
CONFIG, "labels": LABELS}
+
UPDATE_MASK = {
- "paths": ["config.worker_config.num_instances",
"config.secondary_worker_config.num_instances",]
+ "paths": ["config.worker_config.num_instances",
"config.secondary_worker_config.num_instances"]
}
TIMEOUT = 120
@@ -123,18 +123,24 @@ def assert_warning(msg: str, warning: Any):
class TestsClusterGenerator(unittest.TestCase):
def test_image_version(self):
with self.assertRaises(ValueError) as err:
- ClusterGenerator(custom_image="custom_image",
image_version="image_version")
+ ClusterGenerator(
+ custom_image="custom_image",
+ image_version="image_version",
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ )
self.assertIn("custom_image and image_version", str(err))
def test_nodes_number(self):
with self.assertRaises(AssertionError) as err:
- ClusterGenerator(num_workers=0, num_preemptible_workers=0)
+ ClusterGenerator(
+ num_workers=0, num_preemptible_workers=0,
project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME
+ )
self.assertIn("num_workers == 0 means single", str(err))
def test_build(self):
generator = ClusterGenerator(
project_id="project_id",
- cluster_name="cluster_name",
num_workers=2,
zone="zone",
network_uri="network_uri",
@@ -158,7 +164,6 @@ class TestsClusterGenerator(unittest.TestCase):
worker_disk_type="worker_disk_type",
worker_disk_size=256,
num_preemptible_workers=4,
- labels={"labels": "data"},
region="region",
service_account="service_account",
service_account_scopes=["service_account_scopes"],
@@ -168,13 +173,13 @@ class TestsClusterGenerator(unittest.TestCase):
customer_managed_key="customer_managed_key",
)
cluster = generator.make()
- self.assertDictEqual(CLUSTER, cluster)
+ self.assertDictEqual(CONFIG, cluster)
class TestDataprocClusterCreateOperator(unittest.TestCase):
def test_deprecation_warning(self):
with self.assertWarns(DeprecationWarning) as warning:
- cluster_operator = DataprocCreateClusterOperator(
+ op = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
@@ -183,20 +188,21 @@ class
TestDataprocClusterCreateOperator(unittest.TestCase):
zone="zone",
)
assert_warning("Passing cluster parameters by keywords", warning)
- cluster = cluster_operator.cluster
- self.assertEqual(cluster['project_id'], GCP_PROJECT)
- self.assertEqual(cluster['cluster_name'], "cluster_name")
- self.assertEqual(cluster['config']['worker_config']['num_instances'],
2)
- self.assertIn("zones/zone",
cluster["config"]['master_config']["machine_type_uri"])
+ self.assertEqual(op.project_id, GCP_PROJECT)
+ self.assertEqual(op.cluster_name, "cluster_name")
+ self.assertEqual(op.cluster_config['worker_config']['num_instances'],
2)
+ self.assertIn("zones/zone",
op.cluster_config['master_config']["machine_type_uri"])
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
op = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_LOCATION,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
project_id=GCP_PROJECT,
- cluster=CLUSTER,
+ cluster_config=CONFIG,
request_id=REQUEST_ID,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
@@ -205,13 +211,13 @@ class
TestDataprocClusterCreateOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_cluster.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
- cluster=CLUSTER,
+ cluster_config=CONFIG,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
@@ -226,7 +232,9 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
- cluster=CLUSTER,
+ cluster_config=CONFIG,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
timeout=TIMEOUT,
@@ -235,13 +243,13 @@ class
TestDataprocClusterCreateOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_cluster.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
- cluster=CLUSTER,
+ cluster_config=CONFIG,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
@@ -264,7 +272,9 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
- cluster=CLUSTER,
+ cluster_config=CONFIG,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
timeout=TIMEOUT,
@@ -286,7 +296,9 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
- cluster=CLUSTER,
+ cluster_config=CONFIG,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
delete_on_error=True,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
@@ -298,10 +310,10 @@ class
TestDataprocClusterCreateOperator(unittest.TestCase):
op.execute(context={})
mock_hook.return_value.diagnose_cluster.assert_called_once_with(
- region=GCP_LOCATION, project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME,
+ region=GCP_LOCATION, project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME
)
mock_hook.return_value.delete_cluster.assert_called_once_with(
- region=GCP_LOCATION, project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME,
+ region=GCP_LOCATION, project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME
)
@mock.patch(DATAPROC_PATH.format("exponential_sleep_generator"))
@@ -327,7 +339,9 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
task_id=TASK_ID,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
- cluster=CLUSTER,
+ cluster_config=CONFIG,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
delete_on_error=True,
gcp_conn_id=GCP_CONN_ID,
)
@@ -338,7 +352,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
mock_get_cluster.assert_has_calls(calls)
mock_create_cluster.assert_has_calls(calls)
mock_hook.return_value.diagnose_cluster.assert_called_once_with(
- region=GCP_LOCATION, project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME,
+ region=GCP_LOCATION, project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME
)
@@ -351,10 +365,7 @@ class TestDataprocClusterScaleOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
cluster_update = {
- "config": {
- "worker_config": {"num_instances": 3},
- "secondary_worker_config": {"num_instances": 4},
- }
+ "config": {"worker_config": {"num_instances": 3},
"secondary_worker_config": {"num_instances": 4}}
}
op = DataprocScaleClusterOperator(
@@ -370,9 +381,7 @@ class TestDataprocClusterScaleOperator(unittest.TestCase):
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.update_cluster.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
@@ -399,9 +408,7 @@ class TestDataprocClusterDeleteOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.delete_cluster.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
@@ -436,9 +443,7 @@ class TestDataprocSubmitJobOperator(unittest.TestCase):
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
@@ -508,9 +513,7 @@ class TestDataprocUpdateClusterOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.update_cluster.assert_called_once_with(
location=GCP_LOCATION,
project_id=GCP_PROJECT,
@@ -547,9 +550,7 @@ class
TestDataprocWorkflowTemplateInstantiateOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.instantiate_workflow_template.assert_called_once_with(
template_name=template_id,
location=GCP_LOCATION,
@@ -581,9 +582,7 @@ class
TestDataprocWorkflowTemplateInstantiateInlineOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with(
template=template,
location=GCP_LOCATION,
@@ -600,7 +599,7 @@ class TestDataProcHiveOperator(unittest.TestCase):
variables = {"key": "value"}
job_id = "uuid_id"
job = {
- "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id,},
+ "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"hive_job": {"query_list": {"queries": [query]}, "script_variables":
variables},
@@ -609,9 +608,7 @@ class TestDataProcHiveOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
with self.assertWarns(DeprecationWarning) as warning:
- DataprocSubmitHiveJobOperator(
- task_id=TASK_ID, region=GCP_LOCATION, query="query",
- )
+ DataprocSubmitHiveJobOperator(task_id=TASK_ID,
region=GCP_LOCATION, query="query")
assert_warning("DataprocSubmitJobOperator", warning)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@@ -631,9 +628,7 @@ class TestDataProcHiveOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
)
@@ -663,7 +658,7 @@ class TestDataProcPigOperator(unittest.TestCase):
variables = {"key": "value"}
job_id = "uuid_id"
job = {
- "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id,},
+ "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"pig_job": {"query_list": {"queries": [query]}, "script_variables":
variables},
@@ -672,9 +667,7 @@ class TestDataProcPigOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
with self.assertWarns(DeprecationWarning) as warning:
- DataprocSubmitPigJobOperator(
- task_id=TASK_ID, region=GCP_LOCATION, query="query",
- )
+ DataprocSubmitPigJobOperator(task_id=TASK_ID, region=GCP_LOCATION,
query="query")
assert_warning("DataprocSubmitJobOperator", warning)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@@ -694,9 +687,7 @@ class TestDataProcPigOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
)
@@ -726,18 +717,16 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
variables = {"key": "value"}
job_id = "uuid_id"
job = {
- "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id,},
+ "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
- "spark_sql_job": {"query_list": {"queries": [query]},
"script_variables": variables,},
+ "spark_sql_job": {"query_list": {"queries": [query]},
"script_variables": variables},
}
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
with self.assertWarns(DeprecationWarning) as warning:
- DataprocSubmitSparkSqlJobOperator(
- task_id=TASK_ID, region=GCP_LOCATION, query="query",
- )
+ DataprocSubmitSparkSqlJobOperator(task_id=TASK_ID,
region=GCP_LOCATION, query="query")
assert_warning("DataprocSubmitJobOperator", warning)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
@@ -757,9 +746,7 @@ class TestDataProcSparkSqlOperator(unittest.TestCase):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
- mock_hook.assert_called_once_with(
- gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN,
- )
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
)
@@ -789,7 +776,7 @@ class TestDataProcSparkOperator(unittest.TestCase):
jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"]
job_id = "uuid_id"
job = {
- "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id,},
+ "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"spark_job": {"jar_file_uris": jars, "main_class": main_class},
@@ -799,7 +786,7 @@ class TestDataProcSparkOperator(unittest.TestCase):
def test_deprecation_warning(self, mock_hook):
with self.assertWarns(DeprecationWarning) as warning:
DataprocSubmitSparkJobOperator(
- task_id=TASK_ID, region=GCP_LOCATION,
main_class=self.main_class, dataproc_jars=self.jars,
+ task_id=TASK_ID, region=GCP_LOCATION,
main_class=self.main_class, dataproc_jars=self.jars
)
assert_warning("DataprocSubmitJobOperator", warning)
@@ -826,7 +813,7 @@ class TestDataProcHadoopOperator(unittest.TestCase):
jar = "file:///usr/lib/spark/examples/jars/spark-examples.jar"
job_id = "uuid_id"
job = {
- "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id,},
+ "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"hadoop_job": {"main_jar_file_uri": jar, "args": args},
@@ -836,7 +823,7 @@ class TestDataProcHadoopOperator(unittest.TestCase):
def test_deprecation_warning(self, mock_hook):
with self.assertWarns(DeprecationWarning) as warning:
DataprocSubmitHadoopJobOperator(
- task_id=TASK_ID, region=GCP_LOCATION, main_jar=self.jar,
arguments=self.args,
+ task_id=TASK_ID, region=GCP_LOCATION, main_jar=self.jar,
arguments=self.args
)
assert_warning("DataprocSubmitJobOperator", warning)
@@ -862,7 +849,7 @@ class TestDataProcPySparkOperator(unittest.TestCase):
uri = "gs://{}/{}"
job_id = "uuid_id"
job = {
- "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id,},
+ "reference": {"project_id": GCP_PROJECT, "job_id":
"{{task.task_id}}_{{ds_nodash}}_" + job_id},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"pyspark_job": {"main_python_file_uri": uri},
@@ -871,9 +858,7 @@ class TestDataProcPySparkOperator(unittest.TestCase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_deprecation_warning(self, mock_hook):
with self.assertWarns(DeprecationWarning) as warning:
- DataprocSubmitPySparkJobOperator(
- task_id=TASK_ID, region=GCP_LOCATION, main=self.uri,
- )
+ DataprocSubmitPySparkJobOperator(task_id=TASK_ID,
region=GCP_LOCATION, main=self.uri)
assert_warning("DataprocSubmitJobOperator", warning)
@mock.patch(DATAPROC_PATH.format("uuid.uuid4"))
diff --git a/tests/providers/google/cloud/operators/test_dataproc_system.py
b/tests/providers/google/cloud/operators/test_dataproc_system.py
index 816ee48..48d4fdf 100644
--- a/tests/providers/google/cloud/operators/test_dataproc_system.py
+++ b/tests/providers/google/cloud/operators/test_dataproc_system.py
@@ -15,18 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import os
-
import pytest
+from airflow.providers.google.cloud.example_dags.example_dataproc import
PYSPARK_MAIN, BUCKET, SPARKR_MAIN
from tests.providers.google.cloud.utils.gcp_authenticator import
GCP_DATAPROC_KEY
from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER,
GoogleSystemTest, provide_gcp_context
-BUCKET = os.environ.get("GCP_DATAPROC_BUCKET", "dataproc-system-tests")
-PYSPARK_MAIN = os.environ.get("PYSPARK_MAIN", "hello_world.py")
-PYSPARK_URI = "gs://{}/{}".format(BUCKET, PYSPARK_MAIN)
-SPARKR_MAIN = os.environ.get("SPARKR_MAIN", "hello_world.R")
-SPARKR_URI = "gs://{}/{}".format(BUCKET, SPARKR_MAIN)
+GCS_URI = f"gs://{BUCKET}"
pyspark_file = """
#!/usr/bin/python
@@ -57,8 +52,8 @@ class DataprocExampleDagsTest(GoogleSystemTest):
def setUp(self):
super().setUp()
self.create_gcs_bucket(BUCKET)
- self.upload_content_to_gcs(lines=pyspark_file, bucket=PYSPARK_URI,
filename=PYSPARK_MAIN)
- self.upload_content_to_gcs(lines=sparkr_file, bucket=SPARKR_URI,
filename=SPARKR_MAIN)
+ self.upload_content_to_gcs(lines=pyspark_file, bucket=GCS_URI,
filename=PYSPARK_MAIN)
+ self.upload_content_to_gcs(lines=sparkr_file, bucket=GCS_URI,
filename=SPARKR_MAIN)
@provide_gcp_context(GCP_DATAPROC_KEY)
def tearDown(self):
diff --git a/tests/test_utils/gcp_system_helpers.py
b/tests/test_utils/gcp_system_helpers.py
index cb44aff..979ad50 100644
--- a/tests/test_utils/gcp_system_helpers.py
+++ b/tests/test_utils/gcp_system_helpers.py
@@ -153,7 +153,7 @@ class GoogleSystemTest(SystemTest):
with open(tmp_path, "w") as file:
file.writelines(lines)
file.flush()
- os.chmod(tmp_path, 555)
+ os.chmod(tmp_path, 777)
cls.upload_to_gcs(tmp_path, bucket_name)
@classmethod