This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bd9e8ce  Fix Google Mypy Dataproc errors (#20570)
bd9e8ce is described below

commit bd9e8cef2687de0b047003e159fd8f3f08c6c61f
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu Dec 30 15:08:04 2021 +0100

    Fix Google Mypy Dataproc errors (#20570)
    
    Part of #19891
---
 airflow/providers/google/cloud/hooks/dataproc.py   |  26 ++--
 .../providers/google/cloud/operators/dataproc.py   | 155 ++++++++++++---------
 2 files changed, 99 insertions(+), 82 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index 282b6f9..b2540f3 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -69,7 +69,7 @@ class DataProcJobBuilder:
         if properties is not None:
             self.job["job"][job_type]["properties"] = properties
 
-    def add_labels(self, labels: dict) -> None:
+    def add_labels(self, labels: Optional[dict] = None) -> None:
         """
         Set labels for Dataproc job.
 
@@ -79,17 +79,17 @@ class DataProcJobBuilder:
         if labels:
             self.job["job"]["labels"].update(labels)
 
-    def add_variables(self, variables: List[str]) -> None:
+    def add_variables(self, variables: Optional[Dict] = None) -> None:
         """
         Set variables for Dataproc job.
 
         :param variables: Variables for the job query.
-        :type variables: List[str]
+        :type variables: Dict
         """
         if variables is not None:
             self.job["job"][self.job_type]["script_variables"] = variables
 
-    def add_args(self, args: List[str]) -> None:
+    def add_args(self, args: Optional[List[str]] = None) -> None:
         """
         Set args for Dataproc job.
 
@@ -99,12 +99,12 @@ class DataProcJobBuilder:
         if args is not None:
             self.job["job"][self.job_type]["args"] = args
 
-    def add_query(self, query: List[str]) -> None:
+    def add_query(self, query: str) -> None:
         """
-        Set query uris for Dataproc job.
+        Set query for Dataproc job.
 
-        :param query: URIs for the job queries.
-        :type query: List[str]
+        :param query: query for the job.
+        :type query: str
         """
         self.job["job"][self.job_type]["query_list"] = {'queries': [query]}
 
@@ -117,7 +117,7 @@ class DataProcJobBuilder:
         """
         self.job["job"][self.job_type]["query_file_uri"] = query_uri
 
-    def add_jar_file_uris(self, jars: List[str]) -> None:
+    def add_jar_file_uris(self, jars: Optional[List[str]] = None) -> None:
         """
         Set jars uris for Dataproc job.
 
@@ -127,7 +127,7 @@ class DataProcJobBuilder:
         if jars is not None:
             self.job["job"][self.job_type]["jar_file_uris"] = jars
 
-    def add_archive_uris(self, archives: List[str]) -> None:
+    def add_archive_uris(self, archives: Optional[List[str]] = None) -> None:
         """
         Set archives uris for Dataproc job.
 
@@ -137,7 +137,7 @@ class DataProcJobBuilder:
         if archives is not None:
             self.job["job"][self.job_type]["archive_uris"] = archives
 
-    def add_file_uris(self, files: List[str]) -> None:
+    def add_file_uris(self, files: Optional[List[str]] = None) -> None:
         """
         Set file uris for Dataproc job.
 
@@ -147,7 +147,7 @@ class DataProcJobBuilder:
         if files is not None:
             self.job["job"][self.job_type]["file_uris"] = files
 
-    def add_python_file_uris(self, pyfiles: List[str]) -> None:
+    def add_python_file_uris(self, pyfiles: Optional[List[str]] = None) -> None:
         """
         Set python file uris for Dataproc job.
 
@@ -157,7 +157,7 @@ class DataProcJobBuilder:
         if pyfiles is not None:
             self.job["job"][self.job_type]["python_file_uris"] = pyfiles
 
-    def set_main(self, main_jar: Optional[str], main_class: Optional[str]) -> None:
+    def set_main(self, main_jar: Optional[str] = None, main_class: Optional[str] = None) -> None:
         """
         Set Dataproc main class.
 
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 148c65c..fdfabb7 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -1036,23 +1036,30 @@ class DataprocJobBaseOperator(BaseOperator):
         self.impersonation_chain = impersonation_chain
         self.hook = DataprocHook(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
         self.project_id = self.hook.project_id if project_id is None else project_id
-        self.job_template = None
-        self.job = None
+        self.job_template: Optional[DataProcJobBuilder] = None
+        self.job: Optional[dict] = None
         self.dataproc_job_id = None
         self.asynchronous = asynchronous
 
-    def create_job_template(self):
+    def create_job_template(self) -> DataProcJobBuilder:
         """Initialize `self.job_template` with default values"""
-        self.job_template = DataProcJobBuilder(
+        if self.project_id is None:
+            raise AirflowException(
+                "project id should either be set via project_id "
+                "parameter or retrieved from the connection,"
+            )
+        job_template = DataProcJobBuilder(
             project_id=self.project_id,
             task_id=self.task_id,
             cluster_name=self.cluster_name,
             job_type=self.job_type,
             properties=self.dataproc_properties,
         )
-        self.job_template.set_job_name(self.job_name)
-        self.job_template.add_jar_file_uris(self.dataproc_jars)
-        self.job_template.add_labels(self.labels)
+        job_template.set_job_name(self.job_name)
+        job_template.add_jar_file_uris(self.dataproc_jars)
+        job_template.add_labels(self.labels)
+        self.job_template = job_template
+        return job_template
 
     def _generate_job_template(self) -> str:
         if self.job_template:
@@ -1180,23 +1187,26 @@ class DataprocSubmitPigJobOperator(DataprocJobBaseOperator):
         Helper method for easier migration to `DataprocSubmitJobOperator`.
         :return: Dict representing Dataproc job
         """
-        self.create_job_template()
+        job_template = self.create_job_template()
 
         if self.query is None:
-            self.job_template.add_query_uri(self.query_uri)
+            if self.query_uri is None:
+                raise AirflowException('One of query or query_uri should be set here')
+            job_template.add_query_uri(self.query_uri)
         else:
-            self.job_template.add_query(self.query)
-        self.job_template.add_variables(self.variables)
+            job_template.add_query(self.query)
+        job_template.add_variables(self.variables)
         return self._generate_job_template()
 
     def execute(self, context: 'Context'):
-        self.create_job_template()
-
+        job_template = self.create_job_template()
         if self.query is None:
-            self.job_template.add_query_uri(self.query_uri)
+            if self.query_uri is None:
+                raise AirflowException('One of query or query_uri should be set here')
+            job_template.add_query_uri(self.query_uri)
         else:
-            self.job_template.add_query(self.query)
-        self.job_template.add_variables(self.variables)
+            job_template.add_query(self.query)
+        job_template.add_variables(self.variables)
 
         super().execute(context)
 
@@ -1256,22 +1266,25 @@ class DataprocSubmitHiveJobOperator(DataprocJobBaseOperator):
         Helper method for easier migration to `DataprocSubmitJobOperator`.
         :return: Dict representing Dataproc job
         """
-        self.create_job_template()
+        job_template = self.create_job_template()
         if self.query is None:
-            self.job_template.add_query_uri(self.query_uri)
+            if self.query_uri is None:
+                raise AirflowException('One of query or query_uri should be set here')
+            job_template.add_query_uri(self.query_uri)
         else:
-            self.job_template.add_query(self.query)
-        self.job_template.add_variables(self.variables)
+            job_template.add_query(self.query)
+        job_template.add_variables(self.variables)
         return self._generate_job_template()
 
     def execute(self, context: 'Context'):
-        self.create_job_template()
+        job_template = self.create_job_template()
         if self.query is None:
-            self.job_template.add_query_uri(self.query_uri)
+            if self.query_uri is None:
+                raise AirflowException('One of query or query_uri should be set here')
+            job_template.add_query_uri(self.query_uri)
         else:
-            self.job_template.add_query(self.query)
-        self.job_template.add_variables(self.variables)
-
+            job_template.add_query(self.query)
+        job_template.add_variables(self.variables)
         super().execute(context)
 
 
@@ -1330,22 +1343,23 @@ class DataprocSubmitSparkSqlJobOperator(DataprocJobBaseOperator):
         Helper method for easier migration to `DataprocSubmitJobOperator`.
         :return: Dict representing Dataproc job
         """
-        self.create_job_template()
+        job_template = self.create_job_template()
         if self.query is None:
-            self.job_template.add_query_uri(self.query_uri)
+            job_template.add_query_uri(self.query_uri)
         else:
-            self.job_template.add_query(self.query)
-        self.job_template.add_variables(self.variables)
+            job_template.add_query(self.query)
+        job_template.add_variables(self.variables)
         return self._generate_job_template()
 
     def execute(self, context: 'Context'):
-        self.create_job_template()
+        job_template = self.create_job_template()
         if self.query is None:
-            self.job_template.add_query_uri(self.query_uri)
+            if self.query_uri is None:
+                raise AirflowException('One of query or query_uri should be set here')
+            job_template.add_query_uri(self.query_uri)
         else:
-            self.job_template.add_query(self.query)
-        self.job_template.add_variables(self.variables)
-
+            job_template.add_query(self.query)
+        job_template.add_variables(self.variables)
         super().execute(context)
 
 
@@ -1411,20 +1425,19 @@ class DataprocSubmitSparkJobOperator(DataprocJobBaseOperator):
         Helper method for easier migration to `DataprocSubmitJobOperator`.
         :return: Dict representing Dataproc job
         """
-        self.create_job_template()
-        self.job_template.set_main(self.main_jar, self.main_class)
-        self.job_template.add_args(self.arguments)
-        self.job_template.add_archive_uris(self.archives)
-        self.job_template.add_file_uris(self.files)
+        job_template = self.create_job_template()
+        job_template.set_main(self.main_jar, self.main_class)
+        job_template.add_args(self.arguments)
+        job_template.add_archive_uris(self.archives)
+        job_template.add_file_uris(self.files)
         return self._generate_job_template()
 
     def execute(self, context: 'Context'):
-        self.create_job_template()
-        self.job_template.set_main(self.main_jar, self.main_class)
-        self.job_template.add_args(self.arguments)
-        self.job_template.add_archive_uris(self.archives)
-        self.job_template.add_file_uris(self.files)
-
+        job_template = self.create_job_template()
+        job_template.set_main(self.main_jar, self.main_class)
+        job_template.add_args(self.arguments)
+        job_template.add_archive_uris(self.archives)
+        job_template.add_file_uris(self.files)
         super().execute(context)
 
 
@@ -1490,20 +1503,19 @@ class DataprocSubmitHadoopJobOperator(DataprocJobBaseOperator):
         Helper method for easier migration to `DataprocSubmitJobOperator`.
         :return: Dict representing Dataproc job
         """
-        self.create_job_template()
-        self.job_template.set_main(self.main_jar, self.main_class)
-        self.job_template.add_args(self.arguments)
-        self.job_template.add_archive_uris(self.archives)
-        self.job_template.add_file_uris(self.files)
+        job_template = self.create_job_template()
+        job_template.set_main(self.main_jar, self.main_class)
+        job_template.add_args(self.arguments)
+        job_template.add_archive_uris(self.archives)
+        job_template.add_file_uris(self.files)
         return self._generate_job_template()
 
     def execute(self, context: 'Context'):
-        self.create_job_template()
-        self.job_template.set_main(self.main_jar, self.main_class)
-        self.job_template.add_args(self.arguments)
-        self.job_template.add_archive_uris(self.archives)
-        self.job_template.add_file_uris(self.files)
-
+        job_template = self.create_job_template()
+        job_template.set_main(self.main_jar, self.main_class)
+        job_template.add_args(self.arguments)
+        job_template.add_archive_uris(self.archives)
+        job_template.add_file_uris(self.files)
         super().execute(context)
 
 
@@ -1594,7 +1606,7 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator):
         Helper method for easier migration to `DataprocSubmitJobOperator`.
         :return: Dict representing Dataproc job
         """
-        self.create_job_template()
+        job_template = self.create_job_template()
         #  Check if the file is local, if that is the case, upload it to a bucket
         if os.path.isfile(self.main):
             cluster_info = self.hook.get_cluster(
@@ -1602,16 +1614,16 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator):
             )
             bucket = cluster_info['config']['config_bucket']
             self.main = f"gs://{bucket}/{self.main}"
-        self.job_template.set_python_main(self.main)
-        self.job_template.add_args(self.arguments)
-        self.job_template.add_archive_uris(self.archives)
-        self.job_template.add_file_uris(self.files)
-        self.job_template.add_python_file_uris(self.pyfiles)
+        job_template.set_python_main(self.main)
+        job_template.add_args(self.arguments)
+        job_template.add_archive_uris(self.archives)
+        job_template.add_file_uris(self.files)
+        job_template.add_python_file_uris(self.pyfiles)
 
         return self._generate_job_template()
 
     def execute(self, context: 'Context'):
-        self.create_job_template()
+        job_template = self.create_job_template()
         #  Check if the file is local, if that is the case, upload it to a bucket
         if os.path.isfile(self.main):
             cluster_info = self.hook.get_cluster(
@@ -1620,12 +1632,11 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator):
             bucket = cluster_info['config']['config_bucket']
             self.main = self._upload_file_temp(bucket, self.main)
 
-        self.job_template.set_python_main(self.main)
-        self.job_template.add_args(self.arguments)
-        self.job_template.add_archive_uris(self.archives)
-        self.job_template.add_file_uris(self.files)
-        self.job_template.add_python_file_uris(self.pyfiles)
-
+        job_template.set_python_main(self.main)
+        job_template.add_args(self.arguments)
+        job_template.add_archive_uris(self.archives)
+        job_template.add_file_uris(self.files)
+        job_template.add_python_file_uris(self.pyfiles)
         super().execute(context)
 
 
@@ -2243,6 +2254,8 @@ class DataprocCreateBatchOperator(BaseOperator):
     def execute(self, context: 'Context'):
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
         self.log.info("Creating batch")
+        if self.region is None:
+            raise AirflowException('Region should be set here')
         try:
             self.operation = hook.create_batch(
                 region=self.region,
@@ -2254,10 +2267,14 @@ class DataprocCreateBatchOperator(BaseOperator):
                 timeout=self.timeout,
                 metadata=self.metadata,
             )
+            if self.timeout is None:
+                raise AirflowException('Timeout should be set here')
             result = hook.wait_for_operation(self.timeout, self.operation)
             self.log.info("Batch %s created", self.batch_id)
         except AlreadyExists:
             self.log.info("Batch with given id already exists")
+            if self.batch_id is None:
+                raise AirflowException('Batch Id should be set here')
             result = hook.get_batch(
                 batch_id=self.batch_id,
                 region=self.region,
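
The recurring pattern in this diff — returning the freshly built DataProcJobBuilder from
create_job_template() and validating Optional inputs before use — addresses the mypy errors
named in the commit title: an Optional instance attribute such as self.job_template is not
narrowed across method calls, whereas a local variable (or a returned value) keeps its
non-Optional type. Below is a minimal, self-contained sketch of that idea; the names
JobBuilder and Operator are illustrative placeholders, not Airflow classes.

    from typing import Optional


    class JobBuilder:
        def set_job_name(self, name: str) -> None:
            print(f"job name set to {name}")


    class Operator:
        def __init__(self, project_id: Optional[str] = None) -> None:
            self.project_id = project_id
            # Annotated as Optional, like self.job_template in the diff above.
            self.job_template: Optional[JobBuilder] = None

        def create_job_template(self) -> JobBuilder:
            # Validate Optional inputs up front so later code sees a non-None value.
            if self.project_id is None:
                raise ValueError("project_id must be set or resolved from the connection")
            job_template = JobBuilder()
            self.job_template = job_template
            # Returning the builder lets callers hold a local, non-Optional reference.
            return job_template

        def execute(self) -> None:
            job_template = self.create_job_template()
            # mypy knows job_template is a JobBuilder here; going through
            # self.job_template would need a fresh None check at every use.
            job_template.set_job_name("example")


    if __name__ == "__main__":
        Operator(project_id="my-project").execute()

The same reasoning applies to the added query_uri, region, timeout, and batch_id guards:
raising early gives both mypy and the reader a non-Optional value past that point.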
