This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bd9e8ce Fix Google Mypy Dataproc errors (#20570)
bd9e8ce is described below
commit bd9e8cef2687de0b047003e159fd8f3f08c6c61f
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu Dec 30 15:08:04 2021 +0100
Fix Google Mypy Dataproc errors (#20570)
Part of #19891
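The fixes follow one pattern throughout: DataProcJobBuilder setters that callers invoke with possibly-None values now declare explicit `Optional[...] = None` defaults, and Optional attributes (project_id, query_uri, region, timeout, batch_id) are narrowed with a None check that raises AirflowException before use. A minimal, illustrative sketch of that pattern (names simplified, not part of the commit):

    from typing import List, Optional

    from airflow.exceptions import AirflowException


    class Builder:
        # Simplified stand-in for DataProcJobBuilder, for illustration only.
        def __init__(self, project_id: str) -> None:
            self.project_id = project_id
            self.jar_file_uris: List[str] = []

        # Before: `jars: List[str]` even though callers pass None; mypy rejects that.
        # After: the None default is explicit and guarded at runtime.
        def add_jar_file_uris(self, jars: Optional[List[str]] = None) -> None:
            if jars is not None:
                self.jar_file_uris = jars


    def create_job_template(project_id: Optional[str]) -> Builder:
        # Narrow the Optional first; without this check mypy flags Builder(project_id).
        if project_id is None:
            raise AirflowException("project_id must be set or resolved from the connection")
        return Builder(project_id)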
---
airflow/providers/google/cloud/hooks/dataproc.py | 26 ++--
.../providers/google/cloud/operators/dataproc.py | 155 ++++++++++++---------
2 files changed, 99 insertions(+), 82 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index 282b6f9..b2540f3 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -69,7 +69,7 @@ class DataProcJobBuilder:
if properties is not None:
self.job["job"][job_type]["properties"] = properties
- def add_labels(self, labels: dict) -> None:
+ def add_labels(self, labels: Optional[dict] = None) -> None:
"""
Set labels for Dataproc job.
@@ -79,17 +79,17 @@ class DataProcJobBuilder:
if labels:
self.job["job"]["labels"].update(labels)
- def add_variables(self, variables: List[str]) -> None:
+ def add_variables(self, variables: Optional[Dict] = None) -> None:
"""
Set variables for Dataproc job.
:param variables: Variables for the job query.
- :type variables: List[str]
+ :type variables: Dict
"""
if variables is not None:
self.job["job"][self.job_type]["script_variables"] = variables
- def add_args(self, args: List[str]) -> None:
+ def add_args(self, args: Optional[List[str]] = None) -> None:
"""
Set args for Dataproc job.
@@ -99,12 +99,12 @@ class DataProcJobBuilder:
if args is not None:
self.job["job"][self.job_type]["args"] = args
- def add_query(self, query: List[str]) -> None:
+ def add_query(self, query: str) -> None:
"""
- Set query uris for Dataproc job.
+ Set query for Dataproc job.
- :param query: URIs for the job queries.
- :type query: List[str]
+ :param query: query for the job.
+ :type query: str
"""
self.job["job"][self.job_type]["query_list"] = {'queries': [query]}
@@ -117,7 +117,7 @@ class DataProcJobBuilder:
"""
self.job["job"][self.job_type]["query_file_uri"] = query_uri
- def add_jar_file_uris(self, jars: List[str]) -> None:
+ def add_jar_file_uris(self, jars: Optional[List[str]] = None) -> None:
"""
Set jars uris for Dataproc job.
@@ -127,7 +127,7 @@ class DataProcJobBuilder:
if jars is not None:
self.job["job"][self.job_type]["jar_file_uris"] = jars
- def add_archive_uris(self, archives: List[str]) -> None:
+ def add_archive_uris(self, archives: Optional[List[str]] = None) -> None:
"""
Set archives uris for Dataproc job.
@@ -137,7 +137,7 @@ class DataProcJobBuilder:
if archives is not None:
self.job["job"][self.job_type]["archive_uris"] = archives
- def add_file_uris(self, files: List[str]) -> None:
+ def add_file_uris(self, files: Optional[List[str]] = None) -> None:
"""
Set file uris for Dataproc job.
@@ -147,7 +147,7 @@ class DataProcJobBuilder:
if files is not None:
self.job["job"][self.job_type]["file_uris"] = files
- def add_python_file_uris(self, pyfiles: List[str]) -> None:
+ def add_python_file_uris(self, pyfiles: Optional[List[str]] = None) -> None:
"""
Set python file uris for Dataproc job.
@@ -157,7 +157,7 @@ class DataProcJobBuilder:
if pyfiles is not None:
self.job["job"][self.job_type]["python_file_uris"] = pyfiles
- def set_main(self, main_jar: Optional[str], main_class: Optional[str]) -> None:
+ def set_main(self, main_jar: Optional[str] = None, main_class: Optional[str] = None) -> None:
"""
Set Dataproc main class.
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 148c65c..fdfabb7 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -1036,23 +1036,30 @@ class DataprocJobBaseOperator(BaseOperator):
self.impersonation_chain = impersonation_chain
self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
self.project_id = self.hook.project_id if project_id is None else project_id
- self.job_template = None
- self.job = None
+ self.job_template: Optional[DataProcJobBuilder] = None
+ self.job: Optional[dict] = None
self.dataproc_job_id = None
self.asynchronous = asynchronous
- def create_job_template(self):
+ def create_job_template(self) -> DataProcJobBuilder:
"""Initialize `self.job_template` with default values"""
- self.job_template = DataProcJobBuilder(
+ if self.project_id is None:
+ raise AirflowException(
+ "project id should either be set via project_id "
+ "parameter or retrieved from the connection,"
+ )
+ job_template = DataProcJobBuilder(
project_id=self.project_id,
task_id=self.task_id,
cluster_name=self.cluster_name,
job_type=self.job_type,
properties=self.dataproc_properties,
)
- self.job_template.set_job_name(self.job_name)
- self.job_template.add_jar_file_uris(self.dataproc_jars)
- self.job_template.add_labels(self.labels)
+ job_template.set_job_name(self.job_name)
+ job_template.add_jar_file_uris(self.dataproc_jars)
+ job_template.add_labels(self.labels)
+ self.job_template = job_template
+ return job_template
def _generate_job_template(self) -> str:
if self.job_template:
@@ -1180,23 +1187,26 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
- self.create_job_template()
+ job_template = self.create_job_template()
if self.query is None:
- self.job_template.add_query_uri(self.query_uri)
+ if self.query_uri is None:
+ raise AirflowException('One of query or query_uri should be set here')
+ job_template.add_query_uri(self.query_uri)
else:
- self.job_template.add_query(self.query)
- self.job_template.add_variables(self.variables)
+ job_template.add_query(self.query)
+ job_template.add_variables(self.variables)
return self._generate_job_template()
def execute(self, context: 'Context'):
- self.create_job_template()
-
+ job_template = self.create_job_template()
if self.query is None:
- self.job_template.add_query_uri(self.query_uri)
+ if self.query_uri is None:
+ raise AirflowException('One of query or query_uri should be set here')
+ job_template.add_query_uri(self.query_uri)
else:
- self.job_template.add_query(self.query)
- self.job_template.add_variables(self.variables)
+ job_template.add_query(self.query)
+ job_template.add_variables(self.variables)
super().execute(context)
@@ -1256,22 +1266,25 @@ class DataprocSubmitHiveJobOperator(DataprocJobBaseOperator):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
- self.create_job_template()
+ job_template = self.create_job_template()
if self.query is None:
- self.job_template.add_query_uri(self.query_uri)
+ if self.query_uri is None:
+ raise AirflowException('One of query or query_uri should be set here')
+ job_template.add_query_uri(self.query_uri)
else:
- self.job_template.add_query(self.query)
- self.job_template.add_variables(self.variables)
+ job_template.add_query(self.query)
+ job_template.add_variables(self.variables)
return self._generate_job_template()
def execute(self, context: 'Context'):
- self.create_job_template()
+ job_template = self.create_job_template()
if self.query is None:
- self.job_template.add_query_uri(self.query_uri)
+ if self.query_uri is None:
+ raise AirflowException('One of query or query_uri should be set here')
+ job_template.add_query_uri(self.query_uri)
else:
- self.job_template.add_query(self.query)
- self.job_template.add_variables(self.variables)
-
+ job_template.add_query(self.query)
+ job_template.add_variables(self.variables)
super().execute(context)
@@ -1330,22 +1343,23 @@ class DataprocSubmitSparkSqlJobOperator(DataprocJobBaseOperator):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
- self.create_job_template()
+ job_template = self.create_job_template()
if self.query is None:
- self.job_template.add_query_uri(self.query_uri)
+ job_template.add_query_uri(self.query_uri)
else:
- self.job_template.add_query(self.query)
- self.job_template.add_variables(self.variables)
+ job_template.add_query(self.query)
+ job_template.add_variables(self.variables)
return self._generate_job_template()
def execute(self, context: 'Context'):
- self.create_job_template()
+ job_template = self.create_job_template()
if self.query is None:
- self.job_template.add_query_uri(self.query_uri)
+ if self.query_uri is None:
+ raise AirflowException('One of query or query_uri should be set here')
+ job_template.add_query_uri(self.query_uri)
else:
- self.job_template.add_query(self.query)
- self.job_template.add_variables(self.variables)
-
+ job_template.add_query(self.query)
+ job_template.add_variables(self.variables)
super().execute(context)
@@ -1411,20 +1425,19 @@ class DataprocSubmitSparkJobOperator(DataprocJobBaseOperator):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
- self.create_job_template()
- self.job_template.set_main(self.main_jar, self.main_class)
- self.job_template.add_args(self.arguments)
- self.job_template.add_archive_uris(self.archives)
- self.job_template.add_file_uris(self.files)
+ job_template = self.create_job_template()
+ job_template.set_main(self.main_jar, self.main_class)
+ job_template.add_args(self.arguments)
+ job_template.add_archive_uris(self.archives)
+ job_template.add_file_uris(self.files)
return self._generate_job_template()
def execute(self, context: 'Context'):
- self.create_job_template()
- self.job_template.set_main(self.main_jar, self.main_class)
- self.job_template.add_args(self.arguments)
- self.job_template.add_archive_uris(self.archives)
- self.job_template.add_file_uris(self.files)
-
+ job_template = self.create_job_template()
+ job_template.set_main(self.main_jar, self.main_class)
+ job_template.add_args(self.arguments)
+ job_template.add_archive_uris(self.archives)
+ job_template.add_file_uris(self.files)
super().execute(context)
@@ -1490,20 +1503,19 @@ class DataprocSubmitHadoopJobOperator(DataprocJobBaseOperator):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
- self.create_job_template()
- self.job_template.set_main(self.main_jar, self.main_class)
- self.job_template.add_args(self.arguments)
- self.job_template.add_archive_uris(self.archives)
- self.job_template.add_file_uris(self.files)
+ job_template = self.create_job_template()
+ job_template.set_main(self.main_jar, self.main_class)
+ job_template.add_args(self.arguments)
+ job_template.add_archive_uris(self.archives)
+ job_template.add_file_uris(self.files)
return self._generate_job_template()
def execute(self, context: 'Context'):
- self.create_job_template()
- self.job_template.set_main(self.main_jar, self.main_class)
- self.job_template.add_args(self.arguments)
- self.job_template.add_archive_uris(self.archives)
- self.job_template.add_file_uris(self.files)
-
+ job_template = self.create_job_template()
+ job_template.set_main(self.main_jar, self.main_class)
+ job_template.add_args(self.arguments)
+ job_template.add_archive_uris(self.archives)
+ job_template.add_file_uris(self.files)
super().execute(context)
@@ -1594,7 +1606,7 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator):
Helper method for easier migration to `DataprocSubmitJobOperator`.
:return: Dict representing Dataproc job
"""
- self.create_job_template()
+ job_template = self.create_job_template()
# Check if the file is local, if that is the case, upload it to a bucket
if os.path.isfile(self.main):
cluster_info = self.hook.get_cluster(
@@ -1602,16 +1614,16 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator):
)
bucket = cluster_info['config']['config_bucket']
self.main = f"gs://{bucket}/{self.main}"
- self.job_template.set_python_main(self.main)
- self.job_template.add_args(self.arguments)
- self.job_template.add_archive_uris(self.archives)
- self.job_template.add_file_uris(self.files)
- self.job_template.add_python_file_uris(self.pyfiles)
+ job_template.set_python_main(self.main)
+ job_template.add_args(self.arguments)
+ job_template.add_archive_uris(self.archives)
+ job_template.add_file_uris(self.files)
+ job_template.add_python_file_uris(self.pyfiles)
return self._generate_job_template()
def execute(self, context: 'Context'):
- self.create_job_template()
+ job_template = self.create_job_template()
# Check if the file is local, if that is the case, upload it to a bucket
if os.path.isfile(self.main):
cluster_info = self.hook.get_cluster(
@@ -1620,12 +1632,11 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator):
bucket = cluster_info['config']['config_bucket']
self.main = self._upload_file_temp(bucket, self.main)
- self.job_template.set_python_main(self.main)
- self.job_template.add_args(self.arguments)
- self.job_template.add_archive_uris(self.archives)
- self.job_template.add_file_uris(self.files)
- self.job_template.add_python_file_uris(self.pyfiles)
-
+ job_template.set_python_main(self.main)
+ job_template.add_args(self.arguments)
+ job_template.add_archive_uris(self.archives)
+ job_template.add_file_uris(self.files)
+ job_template.add_python_file_uris(self.pyfiles)
super().execute(context)
@@ -2243,6 +2254,8 @@ class DataprocCreateBatchOperator(BaseOperator):
def execute(self, context: 'Context'):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
self.log.info("Creating batch")
+ if self.region is None:
+ raise AirflowException('Region should be set here')
try:
self.operation = hook.create_batch(
region=self.region,
@@ -2254,10 +2267,14 @@ class DataprocCreateBatchOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
+ if self.timeout is None:
+ raise AirflowException('Timeout should be set here')
result = hook.wait_for_operation(self.timeout, self.operation)
self.log.info("Batch %s created", self.batch_id)
except AlreadyExists:
self.log.info("Batch with given id already exists")
+ if self.batch_id is None:
+ raise AirflowException('Batch Id should be set here')
result = hook.get_batch(
batch_id=self.batch_id,
region=self.region,