This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b99cb7ce108 Create operators for working with Topics for GCP Apache
Kafka (#46865)
b99cb7ce108 is described below
commit b99cb7ce108ea013f755f770928a00d9e8d0944d
Author: Maksim <[email protected]>
AuthorDate: Thu Feb 20 12:25:32 2025 -0800
Create operators for working with Topics for GCP Apache Kafka (#46865)
---
.../google/docs/operators/cloud/managed_kafka.rst | 48 +++
providers/google/provider.yaml | 1 +
.../providers/google/cloud/hooks/managed_kafka.py | 198 +++++++++++-
.../providers/google/cloud/links/managed_kafka.py | 29 ++
.../google/cloud/operators/managed_kafka.py | 337 +++++++++++++++++++++
.../airflow/providers/google/get_provider_info.py | 1 +
.../managed_kafka/example_managed_kafka_topic.py | 172 +++++++++++
.../unit/google/cloud/hooks/test_managed_kafka.py | 246 +++++++++++++++
.../unit/google/cloud/links/test_managed_kafka.py | 38 +++
.../google/cloud/operators/test_managed_kafka.py | 172 +++++++++++
10 files changed, 1240 insertions(+), 2 deletions(-)
diff --git a/providers/google/docs/operators/cloud/managed_kafka.rst
b/providers/google/docs/operators/cloud/managed_kafka.rst
index 0016076f183..a81f81592ee 100644
--- a/providers/google/docs/operators/cloud/managed_kafka.rst
+++ b/providers/google/docs/operators/cloud/managed_kafka.rst
@@ -69,6 +69,54 @@ To update cluster you can use
:start-after: [START how_to_cloud_managed_kafka_update_cluster_operator]
:end-before: [END how_to_cloud_managed_kafka_update_cluster_operator]
+Interacting with Apache Kafka Topics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To create an Apache Kafka topic you can use
+:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaCreateTopicOperator`.
+
+.. exampleinclude::
/../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_create_topic_operator]
+ :end-before: [END how_to_cloud_managed_kafka_create_topic_operator]
+
+To delete a topic you can use
+:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaDeleteTopicOperator`.
+
+.. exampleinclude::
/../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_delete_topic_operator]
+ :end-before: [END how_to_cloud_managed_kafka_delete_topic_operator]
+
+To get a topic you can use
+:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaGetTopicOperator`.
+
+.. exampleinclude::
/../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_get_topic_operator]
+ :end-before: [END how_to_cloud_managed_kafka_get_topic_operator]
+
+To get a list of topics you can use
+:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaListTopicsOperator`.
+
+.. exampleinclude::
/../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_list_topic_operator]
+ :end-before: [END how_to_cloud_managed_kafka_list_topic_operator]
+
+To update a topic you can use
+:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaUpdateTopicOperator`.
+
+.. exampleinclude::
/../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_update_topic_operator]
+ :end-before: [END how_to_cloud_managed_kafka_update_topic_operator]
+
Reference
^^^^^^^^^
diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml
index 5575cfc747f..5a67f96bb33 100644
--- a/providers/google/provider.yaml
+++ b/providers/google/provider.yaml
@@ -1229,6 +1229,7 @@ extra-links:
-
airflow.providers.google.cloud.links.translate.TranslationGlossariesListLink
- airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink
-
airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink
+ - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink
secrets-backends:
diff --git
a/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py
b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py
index 48768666f8f..aec8d92f997 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py
@@ -27,12 +27,12 @@ from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
-from google.cloud.managedkafka_v1 import Cluster, ManagedKafkaClient, types
+from google.cloud.managedkafka_v1 import Cluster, ManagedKafkaClient, Topic,
types
if TYPE_CHECKING:
from google.api_core.operation import Operation
from google.api_core.retry import Retry
- from google.cloud.managedkafka_v1.services.managed_kafka.pagers import
ListClustersPager
+ from google.cloud.managedkafka_v1.services.managed_kafka.pagers import
ListClustersPager, ListTopicsPager
from google.protobuf.field_mask_pb2 import FieldMask
@@ -286,3 +286,197 @@ class ManagedKafkaHook(GoogleBaseHook):
metadata=metadata,
)
return operation
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def create_topic(
+ self,
+ project_id: str,
+ location: str,
+ cluster_id: str,
+ topic_id: str,
+ topic: types.Topic | dict,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> types.Topic:
+ """
+ Create a new topic in a given project and location.
+
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster in which to create
the topic.
+ :param topic_id: Required. The ID to use for the topic, which will
become the final component of the
+ topic's name.
+ :param topic: Required. Configuration of the topic to create.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request
as metadata.
+ """
+ client = self.get_managed_kafka_client()
+ parent = client.cluster_path(project_id, location, cluster_id)
+
+ result = client.create_topic(
+ request={
+ "parent": parent,
+ "topic_id": topic_id,
+ "topic": topic,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def list_topics(
+ self,
+ project_id: str,
+ location: str,
+ cluster_id: str,
+ page_size: int | None = None,
+ page_token: str | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> ListTopicsPager:
+ """
+ List the topics in a given cluster.
+
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose topics are to
be listed.
+ :param page_size: Optional. The maximum number of topics to return.
The service may return fewer than
+ this value. If unset or zero, all topics for the parent are
returned.
+ :param page_token: Optional. A page token, received from a previous
``ListTopics`` call. Provide this
+ to retrieve the subsequent page. When paginating, all other
parameters provided to ``ListTopics``
+ must match the call that provided the page token.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request
as metadata.
+ """
+ client = self.get_managed_kafka_client()
+ parent = client.cluster_path(project_id, location, cluster_id)
+
+ result = client.list_topics(
+ request={
+ "parent": parent,
+ "page_size": page_size,
+ "page_token": page_token,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def get_topic(
+ self,
+ project_id: str,
+ location: str,
+ cluster_id: str,
+ topic_id: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> types.Topic:
+ """
+ Return the properties of a single topic.
+
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose topic is to
be returned.
+ :param topic_id: Required. The ID of the topic whose configuration to
return.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request
as metadata.
+ """
+ client = self.get_managed_kafka_client()
+ name = client.topic_path(project_id, location, cluster_id, topic_id)
+
+ result = client.get_topic(
+ request={
+ "name": name,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def update_topic(
+ self,
+ project_id: str,
+ location: str,
+ cluster_id: str,
+ topic_id: str,
+ topic: types.Topic | dict,
+ update_mask: FieldMask | dict,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> types.Topic:
+ """
+ Update the properties of a single topic.
+
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose topic is to
be updated.
+ :param topic_id: Required. The ID of the topic whose configuration to
update.
+ :param topic: Required. The topic to update. Its ``name`` field must
be populated.
+ :param update_mask: Required. Field mask is used to specify the fields
to be overwritten in the Topic
+ resource by the update. The fields specified in the update_mask
are relative to the resource, not
+ the full request. A field will be overwritten if it is in the mask.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request
as metadata.
+ """
+ client = self.get_managed_kafka_client()
+ _topic = deepcopy(topic) if isinstance(topic, dict) else
Topic.to_dict(topic)
+ _topic["name"] = client.topic_path(project_id, location, cluster_id,
topic_id)
+
+ result = client.update_topic(
+ request={
+ "update_mask": update_mask,
+ "topic": _topic,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def delete_topic(
+ self,
+ project_id: str,
+ location: str,
+ cluster_id: str,
+ topic_id: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> None:
+ """
+ Delete a single topic.
+
+ :param project_id: Required. The ID of the Google Cloud project that
the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose topic is to
be deleted.
+ :param topic_id: Required. The ID of the topic to delete.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request
as metadata.
+ """
+ client = self.get_managed_kafka_client()
+ name = client.topic_path(project_id, location, cluster_id, topic_id)
+
+ client.delete_topic(
+ request={
+ "name": name,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
diff --git
a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py
b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py
index 00c626b3814..0aafe2f202d 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py
@@ -28,6 +28,9 @@ MANAGED_KAFKA_CLUSTER_LINK = (
MANAGED_KAFKA_BASE_LINK +
"/{location}/clusters/{cluster_id}?project={project_id}"
)
MANAGED_KAFKA_CLUSTER_LIST_LINK = MANAGED_KAFKA_BASE_LINK +
"/clusters?project={project_id}"
+MANAGED_KAFKA_TOPIC_LINK = (
+ MANAGED_KAFKA_BASE_LINK +
"/{location}/clusters/{cluster_id}/topics/{topic_id}?project={project_id}"
+)
class ApacheKafkaClusterLink(BaseGoogleLink):
@@ -73,3 +76,29 @@ class ApacheKafkaClusterListLink(BaseGoogleLink):
"project_id": task_instance.project_id,
},
)
+
+
+class ApacheKafkaTopicLink(BaseGoogleLink):
+ """Helper class for constructing Apache Kafka Topic link."""
+
+ name = "Apache Kafka Topic"
+ key = "topic_conf"
+ format_str = MANAGED_KAFKA_TOPIC_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ task_instance,
+ cluster_id: str,
+ topic_id: str,
+ ):
+ task_instance.xcom_push(
+ context=context,
+ key=ApacheKafkaTopicLink.key,
+ value={
+ "location": task_instance.location,
+ "cluster_id": cluster_id,
+ "topic_id": topic_id,
+ "project_id": task_instance.project_id,
+ },
+ )
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
index ebf03856216..2afb30fede9 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
@@ -28,6 +28,7 @@ from airflow.providers.google.cloud.hooks.managed_kafka
import ManagedKafkaHook
from airflow.providers.google.cloud.links.managed_kafka import (
ApacheKafkaClusterLink,
ApacheKafkaClusterListLink,
+ ApacheKafkaTopicLink,
)
from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
from google.api_core.exceptions import AlreadyExists, NotFound
@@ -449,3 +450,339 @@ class
ManagedKafkaDeleteClusterOperator(ManagedKafkaBaseOperator):
except NotFound as not_found_err:
self.log.info("The Apache Kafka cluster ID %s does not exist.",
self.cluster_id)
raise AirflowException(not_found_err)
+
+
+class ManagedKafkaCreateTopicOperator(ManagedKafkaBaseOperator):
+ """
+ Create a new topic in a given project and location.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster in which to create the
topic.
+ :param topic_id: Required. The ID to use for the topic, which will become
the final component of the
+ topic's name.
+ :param topic: Required. Configuration of the topic to create.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as
metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields: Sequence[str] = tuple(
+ {"cluster_id", "topic_id", "topic"} |
set(ManagedKafkaBaseOperator.template_fields)
+ )
+ operator_extra_links = (ApacheKafkaTopicLink(),)
+
+ def __init__(
+ self,
+ cluster_id: str,
+ topic_id: str,
+ topic: types.Topic | dict,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.cluster_id = cluster_id
+ self.topic_id = topic_id
+ self.topic = topic
+
+ def execute(self, context: Context):
+ self.log.info("Creating an Apache Kafka topic.")
+ ApacheKafkaTopicLink.persist(
+ context=context,
+ task_instance=self,
+ cluster_id=self.cluster_id,
+ topic_id=self.topic_id,
+ )
+ try:
+ topic_obj = self.hook.create_topic(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ topic_id=self.topic_id,
+ topic=self.topic,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("Apache Kafka topic for %s cluster was created.",
self.cluster_id)
+ return types.Topic.to_dict(topic_obj)
+ except AlreadyExists:
+ self.log.info("Apache Kafka topic %s already exists.",
self.topic_id)
+ topic_obj = self.hook.get_topic(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ topic_id=self.topic_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ return types.Topic.to_dict(topic_obj)
+
+
+class ManagedKafkaListTopicsOperator(ManagedKafkaBaseOperator):
+ """
+ List the topics in a given cluster.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose topics are to be
listed.
+ :param page_size: Optional. The maximum number of topics to return. The
service may return fewer than
+ this value. If unset or zero, all topics for the parent are returned.
+ :param page_token: Optional. A page token, received from a previous
``ListTopics`` call. Provide this
+ to retrieve the subsequent page. When paginating, all other parameters
provided to ``ListTopics``
+ must match the call that provided the page token.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as
metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields: Sequence[str] = tuple({"cluster_id"} |
set(ManagedKafkaBaseOperator.template_fields))
+ operator_extra_links = (ApacheKafkaClusterLink(),)
+
+ def __init__(
+ self,
+ cluster_id: str,
+ page_size: int | None = None,
+ page_token: str | None = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.cluster_id = cluster_id
+ self.page_size = page_size
+ self.page_token = page_token
+
+ def execute(self, context: Context):
+ ApacheKafkaClusterLink.persist(context=context, task_instance=self,
cluster_id=self.cluster_id)
+ self.log.info("Listing Topics for cluster %s.", self.cluster_id)
+ try:
+ topic_list_pager = self.hook.list_topics(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ page_size=self.page_size,
+ page_token=self.page_token,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.xcom_push(
+ context=context,
+ key="topic_page",
+
value=types.ListTopicsResponse.to_dict(topic_list_pager._response),
+ )
+ except Exception as error:
+ raise AirflowException(error)
+ return [types.Topic.to_dict(topic) for topic in topic_list_pager]
+
+
+class ManagedKafkaGetTopicOperator(ManagedKafkaBaseOperator):
+ """
+ Return the properties of a single topic.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose topic is to be
returned.
+ :param topic_id: Required. The ID of the topic whose configuration to
return.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as
metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields: Sequence[str] = tuple(
+ {"cluster_id", "topic_id"} |
set(ManagedKafkaBaseOperator.template_fields)
+ )
+ operator_extra_links = (ApacheKafkaTopicLink(),)
+
+ def __init__(
+ self,
+ cluster_id: str,
+ topic_id: str,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.cluster_id = cluster_id
+ self.topic_id = topic_id
+
+ def execute(self, context: Context):
+ ApacheKafkaTopicLink.persist(
+ context=context,
+ task_instance=self,
+ cluster_id=self.cluster_id,
+ topic_id=self.topic_id,
+ )
+ self.log.info("Getting Topic: %s", self.topic_id)
+ try:
+ topic = self.hook.get_topic(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ topic_id=self.topic_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("The topic %s from cluster %s was retrieved.",
self.topic_id, self.cluster_id)
+ return types.Topic.to_dict(topic)
+ except NotFound as not_found_err:
+ self.log.info("The Topic %s does not exist.", self.topic_id)
+ raise AirflowException(not_found_err)
+
+
+class ManagedKafkaUpdateTopicOperator(ManagedKafkaBaseOperator):
+ """
+ Update the properties of a single topic.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose topic is to be
updated.
+ :param topic_id: Required. The ID of the topic whose configuration to
update.
+ :param topic: Required. The topic to update. Its ``name`` field must be
populated.
+ :param update_mask: Required. Field mask is used to specify the fields to
be overwritten in the Topic
+ resource by the update. The fields specified in the update_mask are
relative to the resource, not
+ the full request. A field will be overwritten if it is in the mask.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as
metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields: Sequence[str] = tuple(
+ {"cluster_id", "topic_id", "topic", "update_mask"} |
set(ManagedKafkaBaseOperator.template_fields)
+ )
+ operator_extra_links = (ApacheKafkaTopicLink(),)
+
+ def __init__(
+ self,
+ cluster_id: str,
+ topic_id: str,
+ topic: types.Topic | dict,
+ update_mask: FieldMask | dict,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.cluster_id = cluster_id
+ self.topic_id = topic_id
+ self.topic = topic
+ self.update_mask = update_mask
+
+ def execute(self, context: Context):
+ ApacheKafkaTopicLink.persist(
+ context=context,
+ task_instance=self,
+ cluster_id=self.cluster_id,
+ topic_id=self.topic_id,
+ )
+ self.log.info("Updating an Apache Kafka topic.")
+ try:
+ topic_obj = self.hook.update_topic(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ topic_id=self.topic_id,
+ topic=self.topic,
+ update_mask=self.update_mask,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("Apache Kafka topic %s was updated.", self.topic_id)
+ return types.Topic.to_dict(topic_obj)
+ except NotFound as not_found_err:
+ self.log.info("The Topic %s does not exist.", self.topic_id)
+ raise AirflowException(not_found_err)
+ except Exception as error:
+ raise AirflowException(error)
+
+
+class ManagedKafkaDeleteTopicOperator(ManagedKafkaBaseOperator):
+ """
+ Delete a single topic.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose topic is to be
deleted.
+ :param topic_id: Required. The ID of the topic to delete.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as
metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields: Sequence[str] = tuple(
+ {"cluster_id", "topic_id"} |
set(ManagedKafkaBaseOperator.template_fields)
+ )
+
+ def __init__(
+ self,
+ cluster_id: str,
+ topic_id: str,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.cluster_id = cluster_id
+ self.topic_id = topic_id
+
+ def execute(self, context: Context):
+ try:
+ self.log.info("Deleting Apache Kafka topic: %s", self.topic_id)
+ self.hook.delete_topic(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ topic_id=self.topic_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("Apache Kafka topic was deleted.")
+ except NotFound as not_found_err:
+ self.log.info("The Apache Kafka topic ID %s does not exist.",
self.topic_id)
+ raise AirflowException(not_found_err)
diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py
b/providers/google/src/airflow/providers/google/get_provider_info.py
index 64e191a3280..45316c69a39 100644
--- a/providers/google/src/airflow/providers/google/get_provider_info.py
+++ b/providers/google/src/airflow/providers/google/get_provider_info.py
@@ -1568,6 +1568,7 @@ def get_provider_info():
"airflow.providers.google.cloud.links.translate.TranslationGlossariesListLink",
"airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink",
"airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink",
+
"airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink",
],
"secrets-backends": [
"airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend"
diff --git
a/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py
b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py
new file mode 100644
index 00000000000..719891600b6
--- /dev/null
+++
b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_topic.py
@@ -0,0 +1,172 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+"""
+Example Airflow DAG for Google Cloud Managed Service for Apache Kafka testing Topic operations.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.managed_kafka import (
+ ManagedKafkaCreateClusterOperator,
+ ManagedKafkaCreateTopicOperator,
+ ManagedKafkaDeleteClusterOperator,
+ ManagedKafkaDeleteTopicOperator,
+ ManagedKafkaGetTopicOperator,
+ ManagedKafkaListTopicsOperator,
+ ManagedKafkaUpdateTopicOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+DAG_ID = "managed_kafka_topic_operations"
+LOCATION = "us-central1"
+
+CLUSTER_ID = f"cluster_{DAG_ID}_{ENV_ID}".replace("_", "-")
+CLUSTER_CONF = {
+ "gcp_config": {
+ "access_config": {
+ "network_configs": [
+ {"subnet": f"projects/{PROJECT_ID}/regions/{LOCATION}/subnetworks/default"},
+ ],
+ },
+ },
+ "capacity_config": {
+ "vcpu_count": 3,
+ "memory_bytes": 3221225472,
+ },
+}
+TOPIC_ID = f"topic_{DAG_ID}_{ENV_ID}".replace("_", "-")
+TOPIC_CONF = {
+ "partition_count": 3,
+ "replication_factor": 3,
+}
+TOPIC_TO_UPDATE = {
+ "partition_count": 30,
+ "replication_factor": 3,
+}
+TOPIC_UPDATE_MASK: dict = {"paths": ["partition_count"]}
+
+
+with DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "managed_kafka", "topic"],
+) as dag:
+ create_cluster = ManagedKafkaCreateClusterOperator(
+ task_id="create_cluster",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster=CLUSTER_CONF,
+ cluster_id=CLUSTER_ID,
+ )
+
+ # [START how_to_cloud_managed_kafka_create_topic_operator]
+ create_topic = ManagedKafkaCreateTopicOperator(
+ task_id="create_topic",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ topic_id=TOPIC_ID,
+ topic=TOPIC_CONF,
+ )
+ # [END how_to_cloud_managed_kafka_create_topic_operator]
+
+ # [START how_to_cloud_managed_kafka_update_topic_operator]
+ update_topic = ManagedKafkaUpdateTopicOperator(
+ task_id="update_topic",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ topic_id=TOPIC_ID,
+ topic=TOPIC_TO_UPDATE,
+ update_mask=TOPIC_UPDATE_MASK,
+ )
+ # [END how_to_cloud_managed_kafka_update_topic_operator]
+
+ # [START how_to_cloud_managed_kafka_get_topic_operator]
+ get_topic = ManagedKafkaGetTopicOperator(
+ task_id="get_topic",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ topic_id=TOPIC_ID,
+ )
+ # [END how_to_cloud_managed_kafka_get_topic_operator]
+
+ # [START how_to_cloud_managed_kafka_delete_topic_operator]
+ delete_topic = ManagedKafkaDeleteTopicOperator(
+ task_id="delete_topic",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ topic_id=TOPIC_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+ # [END how_to_cloud_managed_kafka_delete_topic_operator]
+
+ delete_cluster = ManagedKafkaDeleteClusterOperator(
+ task_id="delete_cluster",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ # [START how_to_cloud_managed_kafka_list_topic_operator]
+ list_topics = ManagedKafkaListTopicsOperator(
+ task_id="list_topics",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ )
+ # [END how_to_cloud_managed_kafka_list_topic_operator]
+
+ (
+ # TEST SETUP
+ create_cluster
+ # TEST BODY
+ >> create_topic
+ >> update_topic
+ >> get_topic
+ >> list_topics
+ >> delete_topic
+ # TEST TEARDOWN
+ >> delete_cluster
+ )
+
+ # ### Everything below this line is not part of example ###
+ # ### Just for system tests purpose ###
+ from tests_common.test_utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests_common.test_utils.system_tests import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git
a/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py
b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py
index 16cb0d35cb9..7261f079555 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py
@@ -55,6 +55,17 @@ TEST_UPDATED_CLUSTER: dict = {
},
}
+TEST_TOPIC_ID: str = "test-topic-id"
+TEST_TOPIC: dict = {
+ "partition_count": 1634,
+ "replication_factor": 1912,
+}
+TEST_TOPIC_UPDATE_MASK: dict = {"paths": ["partition_count"]}
+TEST_UPDATED_TOPIC: dict = {
+ "partition_count": 2000,
+ "replication_factor": 1912,
+}
+
BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
MANAGED_KAFKA_STRING = "airflow.providers.google.cloud.hooks.managed_kafka.{}"
@@ -174,6 +185,122 @@ class TestManagedKafkaWithDefaultProjectIdHook:
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID,
TEST_LOCATION)
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_create_topic(self, mock_client) -> None:
+ self.hook.create_topic(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_TOPIC,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.create_topic.assert_called_once_with(
+ request=dict(
+ parent=mock_client.return_value.cluster_path.return_value,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_TOPIC,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.cluster_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID
+ )
+
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_delete_topic(self, mock_client) -> None:
+ self.hook.delete_topic(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.delete_topic.assert_called_once_with(
+ request=dict(name=mock_client.return_value.topic_path.return_value),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.topic_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_TOPIC_ID
+ )
+
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_get_topic(self, mock_client) -> None:
+ self.hook.get_topic(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.get_topic.assert_called_once_with(
+ request=dict(
+ name=mock_client.return_value.topic_path.return_value,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.topic_path.assert_called_once_with(
+ TEST_PROJECT_ID,
+ TEST_LOCATION,
+ TEST_CLUSTER_ID,
+ TEST_TOPIC_ID,
+ )
+
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_update_topic(self, mock_client) -> None:
+ self.hook.update_topic(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_UPDATED_TOPIC,
+ update_mask=TEST_TOPIC_UPDATE_MASK,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.update_topic.assert_called_once_with(
+ request=dict(
+ update_mask=TEST_TOPIC_UPDATE_MASK,
+ topic={
+ "name": mock_client.return_value.topic_path.return_value,
+ **TEST_UPDATED_TOPIC,
+ },
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.topic_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_TOPIC_ID
+ )
+
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_list_topics(self, mock_client) -> None:
+ self.hook.list_topics(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.list_topics.assert_called_once_with(
+ request=dict(
+ parent=mock_client.return_value.cluster_path.return_value,
+ page_size=None,
+ page_token=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.cluster_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID
+ )
+
class TestManagedKafkaWithoutDefaultProjectIdHook:
def setup_method(self):
@@ -289,3 +416,122 @@ class TestManagedKafkaWithoutDefaultProjectIdHook:
timeout=None,
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID,
TEST_LOCATION)
+
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_create_topic(self, mock_client) -> None:
+ self.hook.create_topic(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_TOPIC,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.create_topic.assert_called_once_with(
+ request=dict(
+ parent=mock_client.return_value.cluster_path.return_value,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_TOPIC,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.cluster_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID
+ )
+
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_delete_topic(self, mock_client) -> None:
+ self.hook.delete_topic(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.delete_topic.assert_called_once_with(
+ request=dict(name=mock_client.return_value.topic_path.return_value),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.topic_path.assert_called_once_with(
+ TEST_PROJECT_ID,
+ TEST_LOCATION,
+ TEST_CLUSTER_ID,
+ TEST_TOPIC_ID,
+ )
+
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_get_topic(self, mock_client) -> None:
+ self.hook.get_topic(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.get_topic.assert_called_once_with(
+ request=dict(
+ name=mock_client.return_value.topic_path.return_value,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.topic_path.assert_called_once_with(
+ TEST_PROJECT_ID,
+ TEST_LOCATION,
+ TEST_CLUSTER_ID,
+ TEST_TOPIC_ID,
+ )
+
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_update_topic(self, mock_client) -> None:
+ self.hook.update_topic(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_UPDATED_TOPIC,
+ update_mask=TEST_TOPIC_UPDATE_MASK,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.update_topic.assert_called_once_with(
+ request=dict(
+ update_mask=TEST_TOPIC_UPDATE_MASK,
+ topic={
+ "name": mock_client.return_value.topic_path.return_value,
+ **TEST_UPDATED_TOPIC,
+ },
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.topic_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_TOPIC_ID
+ )
+
+
@mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_list_topics(self, mock_client) -> None:
+ self.hook.list_topics(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.list_topics.assert_called_once_with(
+ request=dict(
+ parent=mock_client.return_value.cluster_path.return_value,
+ page_size=None,
+ page_token=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.cluster_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID
+ )
diff --git
a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py
b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py
index add83f74d56..7bf671c68e6 100644
--- a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py
+++ b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py
@@ -22,11 +22,13 @@ from unittest import mock
from airflow.providers.google.cloud.links.managed_kafka import (
ApacheKafkaClusterLink,
ApacheKafkaClusterListLink,
+ ApacheKafkaTopicLink,
)
TEST_LOCATION = "test-location"
TEST_CLUSTER_ID = "test-cluster-id"
TEST_PROJECT_ID = "test-project-id"
+TEST_TOPIC_ID = "test-topic-id"
EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_NAME = "Apache Kafka Cluster"
EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_KEY = "cluster_conf"
EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_FORMAT_STR = (
@@ -35,6 +37,11 @@ EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_FORMAT_STR = (
EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_NAME = "Apache Kafka Cluster List"
EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_KEY = "cluster_list_conf"
EXPECTED_MANAGED_KAFKA_CLUSTER_LIST_LINK_FORMAT_STR =
"/managedkafka/clusters?project={project_id}"
+EXPECTED_MANAGED_KAFKA_TOPIC_LINK_NAME = "Apache Kafka Topic"
+EXPECTED_MANAGED_KAFKA_TOPIC_LINK_KEY = "topic_conf"
+EXPECTED_MANAGED_KAFKA_TOPIC_LINK_FORMAT_STR = (
+ "/managedkafka/{location}/clusters/{cluster_id}/topics/{topic_id}?project={project_id}"
+)
class TestApacheKafkaClusterLink:
@@ -87,3 +94,34 @@ class TestApacheKafkaClusterListLink:
"project_id": TEST_PROJECT_ID,
},
)
+
+
+class TestApacheKafkaTopicLink:
+ def test_class_attributes(self):
+ assert ApacheKafkaTopicLink.key == EXPECTED_MANAGED_KAFKA_TOPIC_LINK_KEY
+ assert ApacheKafkaTopicLink.name == EXPECTED_MANAGED_KAFKA_TOPIC_LINK_NAME
+ assert ApacheKafkaTopicLink.format_str == EXPECTED_MANAGED_KAFKA_TOPIC_LINK_FORMAT_STR
+
+ def test_persist(self):
+ mock_context, mock_task_instance = (
+ mock.MagicMock(),
+ mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID),
+ )
+
+ ApacheKafkaTopicLink.persist(
+ context=mock_context,
+ task_instance=mock_task_instance,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ )
+
+ mock_task_instance.xcom_push.assert_called_once_with(
+ context=mock_context,
+ key=EXPECTED_MANAGED_KAFKA_TOPIC_LINK_KEY,
+ value={
+ "location": TEST_LOCATION,
+ "cluster_id": TEST_CLUSTER_ID,
+ "topic_id": TEST_TOPIC_ID,
+ "project_id": TEST_PROJECT_ID,
+ },
+ )
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
index 4b5bc5c7125..e9407cc0a50 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
@@ -22,10 +22,15 @@ from google.api_core.retry import Retry
from airflow.providers.google.cloud.operators.managed_kafka import (
ManagedKafkaCreateClusterOperator,
+ ManagedKafkaCreateTopicOperator,
ManagedKafkaDeleteClusterOperator,
+ ManagedKafkaDeleteTopicOperator,
ManagedKafkaGetClusterOperator,
+ ManagedKafkaGetTopicOperator,
ManagedKafkaListClustersOperator,
+ ManagedKafkaListTopicsOperator,
ManagedKafkaUpdateClusterOperator,
+ ManagedKafkaUpdateTopicOperator,
)
MANAGED_KAFKA_PATH =
"airflow.providers.google.cloud.operators.managed_kafka.{}"
@@ -64,6 +69,17 @@ TEST_UPDATED_CLUSTER: dict = {
},
}
+TEST_TOPIC_ID: str = "test-topic-id"
+TEST_TOPIC: dict = {
+ "partition_count": 1634,
+ "replication_factor": 1912,
+}
+TEST_TOPIC_UPDATE_MASK: dict = {"paths": ["partition_count"]}
+TEST_UPDATED_TOPIC: dict = {
+ "partition_count": 2000,
+ "replication_factor": 1912,
+}
+
class TestManagedKafkaCreateClusterOperator:
@mock.patch(MANAGED_KAFKA_PATH.format("types.Cluster.to_dict"))
@@ -221,3 +237,159 @@ class TestManagedKafkaDeleteClusterOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+
+
+class TestManagedKafkaCreateTopicOperator:
+ @mock.patch(MANAGED_KAFKA_PATH.format("types.Topic.to_dict"))
+ @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
+ def test_execute(self, mock_hook, to_dict_mock):
+ op = ManagedKafkaCreateTopicOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_TOPIC,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.create_topic.assert_called_once_with(
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_TOPIC,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
+class TestManagedKafkaListTopicsOperator:
+ @mock.patch(MANAGED_KAFKA_PATH.format("types.ListTopicsResponse.to_dict"))
+ @mock.patch(MANAGED_KAFKA_PATH.format("types.Topic.to_dict"))
+ @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
+ def test_execute(self, mock_hook, to_cluster_dict_mock,
to_clusters_dict_mock):
+ page_token = "page_token"
+ page_size = 42
+
+ op = ManagedKafkaListTopicsOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ page_size=page_size,
+ page_token=page_token,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.list_topics.assert_called_once_with(
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ page_size=page_size,
+ page_token=page_token,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
+class TestManagedKafkaGetTopicOperator:
+ @mock.patch(MANAGED_KAFKA_PATH.format("types.Topic.to_dict"))
+ @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
+ def test_execute(self, mock_hook, to_dict_mock):
+ op = ManagedKafkaGetTopicOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.get_topic.assert_called_once_with(
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
+class TestManagedKafkaUpdateTopicOperator:
+ @mock.patch(MANAGED_KAFKA_PATH.format("types.Topic.to_dict"))
+ @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
+ def test_execute(self, mock_hook, to_dict_mock):
+ op = ManagedKafkaUpdateTopicOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_UPDATED_TOPIC,
+ update_mask=TEST_TOPIC_UPDATE_MASK,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.update_topic.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ topic=TEST_UPDATED_TOPIC,
+ update_mask=TEST_TOPIC_UPDATE_MASK,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
+class TestManagedKafkaDeleteTopicOperator:
+ @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
+ def test_execute(self, mock_hook):
+ op = ManagedKafkaDeleteTopicOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.delete_topic.assert_called_once_with(
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ topic_id=TEST_TOPIC_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )