This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 778e8c50b9 Pinecone provider support for `pinecone-client`>=3 (#37307)
778e8c50b9 is described below
commit 778e8c50b987176b15689bb681ac4c48d7a7805a
Author: Kalyan <[email protected]>
AuthorDate: Tue Apr 30 16:20:19 2024 +0530
Pinecone provider support for `pinecone-client`>=3 (#37307)
---
airflow/providers/pinecone/CHANGELOG.rst | 30 +++
airflow/providers/pinecone/hooks/pinecone.py | 221 +++++++++++++--------
airflow/providers/pinecone/operators/pinecone.py | 130 ++++++++++++
airflow/providers/pinecone/provider.yaml | 5 +-
.../connections.rst | 12 +-
docs/apache-airflow-providers-pinecone/index.rst | 4 +-
.../operators/pinecone.rst | 52 ++++-
generated/provider_dependencies.json | 2 +-
tests/providers/pinecone/hooks/test_pinecone.py | 44 +++-
.../providers/pinecone/example_create_pod_index.py | 51 +++++
.../pinecone/example_create_serverless_index.py | 50 +++++
11 files changed, 496 insertions(+), 105 deletions(-)
diff --git a/airflow/providers/pinecone/CHANGELOG.rst
b/airflow/providers/pinecone/CHANGELOG.rst
index 7b2a20deb0..a1482f9534 100644
--- a/airflow/providers/pinecone/CHANGELOG.rst
+++ b/airflow/providers/pinecone/CHANGELOG.rst
@@ -20,6 +20,36 @@
Changelog
---------
+2.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+.. warning::
+ This release of the provider has breaking changes from previous versions.
Changes are based on
+ the migration guide from pinecone -
<https://canyon-quilt-082.notion.site/Pinecone-Python-SDK-v3-0-0-Migration-Guide-056d3897d7634bf7be399676a4757c7b>
+
+* ``log_level`` field is removed from the Connections as it is not used by the
provider anymore.
+* ``PineconeHook.get_conn`` is removed in favor of ``conn`` property which
returns the Connection object. Use ``pinecone_client`` property to access the
Pinecone client.
+* Following ``PineconeHook`` methods are converted from static methods to
instance methods. Hence, the hook must be instantiated before calling them:
+
+ + ``PineconeHook.list_indexes``
+ + ``PineconeHook.upsert``
+ + ``PineconeHook.create_index``
+ + ``PineconeHook.describe_index``
+ + ``PineconeHook.delete_index``
+ + ``PineconeHook.configure_index``
+ + ``PineconeHook.create_collection``
+ + ``PineconeHook.delete_collection``
+ + ``PineconeHook.describe_collection``
+ + ``PineconeHook.list_collections``
+ + ``PineconeHook.query_vector``
+ + ``PineconeHook.describe_index_stats``
+
+* ``PineconeHook.create_index`` is updated to accept a ``ServerlessSpec`` or
``PodSpec`` instead of directly accepting index related configurations
+* To use ``PineconeHook``, the API key must be set as the password
(``Pinecone API key`` field) of the connection.
+
1.1.2
.....
diff --git a/airflow/providers/pinecone/hooks/pinecone.py
b/airflow/providers/pinecone/hooks/pinecone.py
index 3d11c74b64..a04ae60ce8 100644
--- a/airflow/providers/pinecone/hooks/pinecone.py
+++ b/airflow/providers/pinecone/hooks/pinecone.py
@@ -20,9 +20,11 @@
from __future__ import annotations
import itertools
+import os
+from functools import cached_property
from typing import TYPE_CHECKING, Any
-import pinecone
+from pinecone import Pinecone, PodSpec, ServerlessSpec
from airflow.hooks.base import BaseHook
@@ -30,6 +32,8 @@ if TYPE_CHECKING:
from pinecone.core.client.model.sparse_values import SparseValues
from pinecone.core.client.models import DescribeIndexStatsResponse,
QueryResponse, UpsertResponse
+ from airflow.models.connection import Connection
+
class PineconeHook(BaseHook):
"""
@@ -49,10 +53,11 @@ class PineconeHook(BaseHook):
"""Return connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
- from wtforms import StringField
+ from wtforms import BooleanField, StringField
return {
- "log_level": StringField(lazy_gettext("Log Level"),
widget=BS3TextFieldWidget(), default=None),
+ "region": StringField(lazy_gettext("Pinecone Region"),
widget=BS3TextFieldWidget(), default=None),
+ "debug_curl": BooleanField(lazy_gettext("PINECONE_DEBUG_CURL"),
default=False),
"project_id": StringField(
lazy_gettext("Project ID"),
widget=BS3TextFieldWidget(),
@@ -64,43 +69,73 @@ class PineconeHook(BaseHook):
"""Return custom field behaviour."""
return {
"hidden_fields": ["port", "schema"],
- "relabeling": {"login": "Pinecone Environment", "password":
"Pinecone API key"},
+ "relabeling": {
+ "login": "Pinecone Environment",
+ "host": "Pinecone Host",
+ "password": "Pinecone API key",
+ },
}
- def __init__(self, conn_id: str = default_conn_name) -> None:
+ def __init__(
+ self, conn_id: str = default_conn_name, environment: str | None =
None, region: str | None = None
+ ) -> None:
self.conn_id = conn_id
- self.get_conn()
-
- def get_conn(self) -> None:
- pinecone_connection = self.get_connection(self.conn_id)
- api_key = pinecone_connection.password
- pinecone_environment = pinecone_connection.login
- pinecone_host = pinecone_connection.host
- extras = pinecone_connection.extra_dejson
+ self._environment = environment
+ self._region = region
+
+ @property
+ def api_key(self) -> str:
+ key = self.conn.password
+ if not key:
+ raise LookupError("Pinecone API Key not found in connection")
+ return key
+
+ @cached_property
+ def environment(self) -> str:
+ if self._environment:
+ return self._environment
+ env = self.conn.login
+ if not env:
+ raise LookupError("Pinecone environment not found in connection")
+ return env
+
+ @cached_property
+ def region(self) -> str:
+ if self._region:
+ return self._region
+ region = self.conn.extra_dejson.get("region")
+ if not region:
+ raise LookupError("Pinecone region not found in connection")
+ return region
+
+ @cached_property
+ def pinecone_client(self) -> Pinecone:
+ """Pinecone object to interact with Pinecone."""
+ pinecone_host = self.conn.host
+ extras = self.conn.extra_dejson
pinecone_project_id = extras.get("project_id")
- log_level = extras.get("log_level", None)
- pinecone.init(
- api_key=api_key,
- environment=pinecone_environment,
- host=pinecone_host,
- project_name=pinecone_project_id,
- log_level=log_level,
- )
+ enable_curl_debug = extras.get("debug_curl")
+ if enable_curl_debug:
+ os.environ["PINECONE_DEBUG_CURL"] = "true"
+ return Pinecone(api_key=self.api_key, host=pinecone_host,
project_id=pinecone_project_id)
+
+ @cached_property
+ def conn(self) -> Connection:
+ return self.get_connection(self.conn_id)
def test_connection(self) -> tuple[bool, str]:
try:
- self.list_indexes()
+ self.pinecone_client.list_indexes()
return True, "Connection established"
except Exception as e:
return False, str(e)
- @staticmethod
- def list_indexes() -> Any:
+ def list_indexes(self) -> Any:
"""Retrieve a list of all indexes in your project."""
- return pinecone.list_indexes()
+ return self.pinecone_client.list_indexes()
- @staticmethod
def upsert(
+ self,
index_name: str,
vectors: list[Any],
namespace: str = "",
@@ -126,7 +161,7 @@ class PineconeHook(BaseHook):
:param show_progress: Whether to show a progress bar using tqdm.
Applied only
if batch_size is provided.
"""
- index = pinecone.Index(index_name)
+ index = self.pinecone_client.Index(index_name)
return index.upsert(
vectors=vectors,
namespace=namespace,
@@ -135,75 +170,93 @@ class PineconeHook(BaseHook):
**kwargs,
)
- @staticmethod
+ def get_pod_spec_obj(
+ self,
+ *,
+ replicas: int | None = None,
+ shards: int | None = None,
+ pods: int | None = None,
+ pod_type: str | None = "p1.x1",
+ metadata_config: dict | None = None,
+ source_collection: str | None = None,
+ environment: str | None = None,
+ ) -> PodSpec:
+ """
+ Get a PodSpec object.
+
+ :param replicas: The number of replicas.
+ :param shards: The number of shards.
+ :param pods: The number of pods.
+ :param pod_type: The type of pod.
+ :param metadata_config: The metadata configuration.
+ :param source_collection: The source collection.
+ :param environment: The environment to use when creating the index.
+ """
+ return PodSpec(
+ environment=environment or self.environment,
+ replicas=replicas,
+ shards=shards,
+ pods=pods,
+ pod_type=pod_type,
+ metadata_config=metadata_config,
+ source_collection=source_collection,
+ )
+
+ def get_serverless_spec_obj(self, *, cloud: str, region: str | None =
None) -> ServerlessSpec:
+ """
+ Get a ServerlessSpec object.
+
+ :param cloud: The cloud provider.
+ :param region: The region to use when creating the index.
+ """
+ return ServerlessSpec(cloud=cloud, region=region or self.region)
+
def create_index(
+ self,
index_name: str,
dimension: int,
- index_type: str | None = "approximated",
+ spec: ServerlessSpec | PodSpec,
metric: str | None = "cosine",
- replicas: int | None = 1,
- shards: int | None = 1,
- pods: int | None = 1,
- pod_type: str | None = "p1",
- index_config: dict[str, str] | None = None,
- metadata_config: dict[str, str] | None = None,
- source_collection: str | None = "",
timeout: int | None = None,
) -> None:
"""
Create a new index.
- .. seealso:: https://docs.pinecone.io/reference/create_index/
-
- :param index_name: The name of the index to create.
- :param dimension: the dimension of vectors that would be inserted in
the index
- :param index_type: type of index, one of {"approximated", "exact"},
defaults to "approximated".
- :param metric: type of metric used in the vector index, one of
{"cosine", "dotproduct", "euclidean"}
- :param replicas: the number of replicas, defaults to 1.
- :param shards: the number of shards per index, defaults to 1.
- :param pods: Total number of pods to be used by the index. pods =
shard*replicas
- :param pod_type: the pod type to be used for the index. can be one of
p1 or s1.
- :param index_config: Advanced configuration options for the index
- :param metadata_config: Configuration related to the metadata index
- :param source_collection: Collection name to create the index from
- :param timeout: Timeout for wait until index gets ready.
+ :param index_name: The name of the index.
+ :param dimension: The dimension of the vectors to be indexed.
+ :param spec: Pass a `ServerlessSpec` object to create a serverless
index or a `PodSpec` object to create a pod index.
+ ``get_serverless_spec_obj`` and ``get_pod_spec_obj`` can be used
to create the Spec objects.
+ :param metric: The metric to use.
+ :param timeout: The timeout to use.
"""
- pinecone.create_index(
+ self.pinecone_client.create_index(
name=index_name,
- timeout=timeout,
- index_type=index_type,
dimension=dimension,
+ spec=spec,
metric=metric,
- pods=pods,
- replicas=replicas,
- shards=shards,
- pod_type=pod_type,
- metadata_config=metadata_config,
- source_collection=source_collection,
- index_config=index_config,
+ timeout=timeout,
)
- @staticmethod
- def describe_index(index_name: str) -> Any:
+ def describe_index(self, index_name: str) -> Any:
"""
Retrieve information about a specific index.
:param index_name: The name of the index to describe.
"""
- return pinecone.describe_index(name=index_name)
+ return self.pinecone_client.describe_index(name=index_name)
- @staticmethod
- def delete_index(index_name: str, timeout: int | None = None) -> None:
+ def delete_index(self, index_name: str, timeout: int | None = None) ->
None:
"""
Delete a specific index.
:param index_name: the name of the index.
:param timeout: Timeout for wait until index gets ready.
"""
- pinecone.delete_index(name=index_name, timeout=timeout)
+ self.pinecone_client.delete_index(name=index_name, timeout=timeout)
- @staticmethod
- def configure_index(index_name: str, replicas: int | None = None,
pod_type: str | None = "") -> None:
+ def configure_index(
+ self, index_name: str, replicas: int | None = None, pod_type: str |
None = ""
+ ) -> None:
"""
Change the current configuration of the index.
@@ -211,43 +264,39 @@ class PineconeHook(BaseHook):
:param replicas: The new number of replicas.
:param pod_type: the new pod_type for the index.
"""
- pinecone.configure_index(name=index_name, replicas=replicas,
pod_type=pod_type)
+ self.pinecone_client.configure_index(name=index_name,
replicas=replicas, pod_type=pod_type)
- @staticmethod
- def create_collection(collection_name: str, index_name: str) -> None:
+ def create_collection(self, collection_name: str, index_name: str) -> None:
"""
Create a new collection from a specified index.
:param collection_name: The name of the collection to create.
:param index_name: The name of the source index.
"""
- pinecone.create_collection(name=collection_name, source=index_name)
+ self.pinecone_client.create_collection(name=collection_name,
source=index_name)
- @staticmethod
- def delete_collection(collection_name: str) -> None:
+ def delete_collection(self, collection_name: str) -> None:
"""
Delete a specific collection.
:param collection_name: The name of the collection to delete.
"""
- pinecone.delete_collection(collection_name)
+ self.pinecone_client.delete_collection(collection_name)
- @staticmethod
- def describe_collection(collection_name: str) -> Any:
+ def describe_collection(self, collection_name: str) -> Any:
"""
Retrieve information about a specific collection.
:param collection_name: The name of the collection to describe.
"""
- return pinecone.describe_collection(collection_name)
+ return self.pinecone_client.describe_collection(collection_name)
- @staticmethod
- def list_collections() -> Any:
+ def list_collections(self) -> Any:
"""Retrieve a list of all collections in the current project."""
- return pinecone.list_collections()
+ return self.pinecone_client.list_collections()
- @staticmethod
def query_vector(
+ self,
index_name: str,
vector: list[Any],
query_id: str | None = None,
@@ -275,7 +324,7 @@ class PineconeHook(BaseHook):
:param sparse_vector: sparse values of the query vector. Expected to
be either a SparseValues object or a dict
of the form: {'indices': List[int], 'values': List[float]}, where the
lists each have the same length.
"""
- index = pinecone.Index(index_name)
+ index = self.pinecone_client.Index(index_name)
return index.query(
vector=vector,
id=query_id,
@@ -313,7 +362,7 @@ class PineconeHook(BaseHook):
:param pool_threads: Number of threads for parallel upserting. If
async_req is True, this must be provided.
"""
responses = []
- with pinecone.Index(index_name, pool_threads=pool_threads) as index:
+ with self.pinecone_client.Index(index_name, pool_threads=pool_threads)
as index:
if async_req and pool_threads:
async_results = [index.upsert(vectors=chunk, async_req=True)
for chunk in self._chunks(data)]
responses = [async_result.get() for async_result in
async_results]
@@ -323,8 +372,8 @@ class PineconeHook(BaseHook):
responses.append(response)
return responses
- @staticmethod
def describe_index_stats(
+ self,
index_name: str,
stats_filter: dict[str, str | float | int | bool | list[Any] |
dict[Any, Any]] | None = None,
**kwargs: Any,
@@ -340,5 +389,5 @@ class PineconeHook(BaseHook):
:param stats_filter: If this parameter is present, the operation only
returns statistics for vectors that
satisfy the filter. See
https://www.pinecone.io/docs/metadata-filtering/
"""
- index = pinecone.Index(index_name)
+ index = self.pinecone_client.Index(index_name)
return index.describe_index_stats(filter=stats_filter, **kwargs)
diff --git a/airflow/providers/pinecone/operators/pinecone.py
b/airflow/providers/pinecone/operators/pinecone.py
index 1c757d8fa5..8431276206 100644
--- a/airflow/providers/pinecone/operators/pinecone.py
+++ b/airflow/providers/pinecone/operators/pinecone.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Any, Sequence
from airflow.models import BaseOperator
from airflow.providers.pinecone.hooks.pinecone import PineconeHook
+from airflow.utils.context import Context
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -81,3 +82,132 @@ class PineconeIngestOperator(BaseOperator):
)
self.log.info("Successfully ingested data into Pinecone index %s.",
self.index_name)
+
+
+class CreatePodIndexOperator(BaseOperator):
+ """
+ Create a pod based index in Pinecone.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:CreatePodIndexOperator`
+
+ :param conn_id: The connection id to use when connecting to Pinecone.
+ :param index_name: Name of the Pinecone index.
+ :param dimension: The dimension of the vectors to be indexed.
+ :param environment: The environment to use when creating the index.
+ :param replicas: The number of replicas to use.
+ :param shards: The number of shards to use.
+ :param pods: The number of pods to use.
+ :param pod_type: The type of pod to use.
+ :param metadata_config: The metadata configuration to use.
+ :param source_collection: The source collection to use.
+ :param metric: The metric to use.
+ :param timeout: The timeout to use.
+ """
+
+ def __init__(
+ self,
+ *,
+ conn_id: str = PineconeHook.default_conn_name,
+ index_name: str,
+ dimension: int,
+ environment: str | None = None,
+ replicas: int | None = None,
+ shards: int | None = None,
+ pods: int | None = None,
+ pod_type: str | None = None,
+ metadata_config: dict | None = None,
+ source_collection: str | None = None,
+ metric: str | None = None,
+ timeout: int | None = None,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+ self.conn_id = conn_id
+ self.index_name = index_name
+ self.dimension = dimension
+ self.environment = environment
+ self.replicas = replicas
+ self.shards = shards
+ self.pods = pods
+ self.pod_type = pod_type
+ self.metadata_config = metadata_config
+ self.source_collection = source_collection
+ self.metric = metric
+ self.timeout = timeout
+
+ @cached_property
+ def hook(self) -> PineconeHook:
+ return PineconeHook(conn_id=self.conn_id, environment=self.environment)
+
+ def execute(self, context: Context) -> None:
+ pod_spec_obj = self.hook.get_pod_spec_obj(
+ replicas=self.replicas,
+ shards=self.shards,
+ pods=self.pods,
+ pod_type=self.pod_type,
+ metadata_config=self.metadata_config,
+ source_collection=self.source_collection,
+ environment=self.environment,
+ )
+ self.hook.create_index(
+ index_name=self.index_name,
+ dimension=self.dimension,
+ spec=pod_spec_obj,
+ metric=self.metric,
+ timeout=self.timeout,
+ )
+
+
+class CreateServerlessIndexOperator(BaseOperator):
+ """
+ Create a serverless index in Pinecone.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:CreateServerlessIndexOperator`
+
+ :param conn_id: The connection id to use when connecting to Pinecone.
+ :param index_name: Name of the Pinecone index.
+ :param dimension: The dimension of the vectors to be indexed.
+ :param cloud: The cloud to use when creating the index.
+ :param region: The region to use when creating the index.
+ :param metric: The metric to use.
+ :param timeout: The timeout to use.
+ """
+
+ def __init__(
+ self,
+ *,
+ conn_id: str = PineconeHook.default_conn_name,
+ index_name: str,
+ dimension: int,
+ cloud: str,
+ region: str | None = None,
+ metric: str | None = None,
+ timeout: int | None = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.conn_id = conn_id
+ self.index_name = index_name
+ self.dimension = dimension
+ self.cloud = cloud
+ self.region = region
+ self.metric = metric
+ self.timeout = timeout
+
+ @cached_property
+ def hook(self) -> PineconeHook:
+ return PineconeHook(conn_id=self.conn_id, region=self.region)
+
+ def execute(self, context: Context) -> None:
+ serverless_spec_obj =
self.hook.get_serverless_spec_obj(cloud=self.cloud, region=self.region)
+ self.hook.create_index(
+ index_name=self.index_name,
+ dimension=self.dimension,
+ spec=serverless_spec_obj,
+ metric=self.metric,
+ timeout=self.timeout,
+ )
diff --git a/airflow/providers/pinecone/provider.yaml
b/airflow/providers/pinecone/provider.yaml
index a48f041fa8..0c6e3d9b4c 100644
--- a/airflow/providers/pinecone/provider.yaml
+++ b/airflow/providers/pinecone/provider.yaml
@@ -42,10 +42,7 @@ integrations:
dependencies:
- apache-airflow>=2.7.0
- # Pinecone Python SDK v3.0.0 was released at 2024-01-16 and introduce some
breaking changes.
- # It's crucial to adhere to the v3.0.0 Migration Guide before the
upper-bound limitation can be removed.
- #
https://canyon-quilt-082.notion.site/Pinecone-Python-SDK-v3-0-0-Migration-Guide-056d3897d7634bf7be399676a4757c7b
- - pinecone-client>=2.2.4,<3.0
+ - pinecone-client>=3.0.0
hooks:
- integration-name: Pinecone
diff --git a/docs/apache-airflow-providers-pinecone/connections.rst
b/docs/apache-airflow-providers-pinecone/connections.rst
index 07054a9388..50a72b133a 100644
--- a/docs/apache-airflow-providers-pinecone/connections.rst
+++ b/docs/apache-airflow-providers-pinecone/connections.rst
@@ -33,11 +33,17 @@ Configuring the Connection
Host (optional)
Host URL to connect to a specific Pinecone index.
-Pinecone Environment (required)
- Specify your Pinecone environment to connect to.
+Pinecone Environment (optional)
+ Specify your Pinecone environment for pod based indexes.
Pinecone API key (required)
Specify your Pinecone API Key to connect.
-Project ID (required)
+Project ID (optional)
Project ID corresponding to your API Key.
+
+Pinecone Region (optional)
+ Specify the region for Serverless Indexes in Pinecone.
+
+PINECONE_DEBUG_CURL (optional)
+ Set to ``true`` to enable curl debug output.
diff --git a/docs/apache-airflow-providers-pinecone/index.rst
b/docs/apache-airflow-providers-pinecone/index.rst
index d82935e318..91913b9168 100644
--- a/docs/apache-airflow-providers-pinecone/index.rst
+++ b/docs/apache-airflow-providers-pinecone/index.rst
@@ -69,7 +69,7 @@ Package apache-airflow-providers-pinecone
`Pinecone <https://docs.pinecone.io/docs/overview>`__
-Release: 1.1.2
+Release: 2.0.0
Provider package
----------------
@@ -93,5 +93,5 @@ The minimum Apache Airflow version supported by this provider
package is ``2.6.0
PIP package Version required
=================== ==================
``apache-airflow`` ``>=2.6.0``
-``pinecone-client`` ``>=2.2.4,<3.0``
+``pinecone-client`` ``>=3.0.0``
=================== ==================
diff --git a/docs/apache-airflow-providers-pinecone/operators/pinecone.rst
b/docs/apache-airflow-providers-pinecone/operators/pinecone.rst
index 71f847919f..b50e5300f0 100644
--- a/docs/apache-airflow-providers-pinecone/operators/pinecone.rst
+++ b/docs/apache-airflow-providers-pinecone/operators/pinecone.rst
@@ -15,10 +15,13 @@
specific language governing permissions and limitations
under the License.
+Operators
+---------
+
.. _howto/operator:PineconeIngestOperator:
-PineconeIngestOperator
-======================
+Ingest data into a pinecone index
+=================================
Use the
:class:`~airflow.providers.pinecone.operators.pinecone.PineconeIngestOperator`
to
interact with Pinecone APIs to ingest vectors.
@@ -38,3 +41,48 @@ An example using the operator in this way:
:dedent: 4
:start-after: [START howto_operator_pinecone_ingest]
:end-before: [END howto_operator_pinecone_ingest]
+
+.. _howto/operator:CreatePodIndexOperator:
+
+Create a Pod based Index
+========================
+
+Use the
:class:`~airflow.providers.pinecone.operators.pinecone.CreatePodIndexOperator`
to
+interact with Pinecone APIs to create a Pod based Index.
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+The ``CreatePodIndexOperator`` requires the index details as well as the pod
configuration details. ``environment`` can be
+passed via an argument to the operator or via the connection; the API key is read from the connection.
+
+An example using the operator in this way:
+
+.. exampleinclude::
/../../tests/system/providers/pinecone/example_create_pod_index.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_create_pod_index]
+ :end-before: [END howto_operator_create_pod_index]
+
+
+.. _howto/operator:CreateServerlessIndexOperator:
+
+Create a Serverless Index
+=========================
+
+Use the
:class:`~airflow.providers.pinecone.operators.pinecone.CreateServerlessIndexOperator`
to
+interact with Pinecone APIs to create a Serverless Index.
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+The ``CreateServerlessIndexOperator`` requires the index details as well as
the Serverless configuration details. ``region`` can be
+passed via an argument to the operator or via the connection; the API key is read from the connection.
+
+An example using the operator in this way:
+
+.. exampleinclude::
/../../tests/system/providers/pinecone/example_create_serverless_index.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_create_serverless_index]
+ :end-before: [END howto_operator_create_serverless_index]
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index d91b488580..6b3d2e7d81 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -910,7 +910,7 @@
"pinecone": {
"deps": [
"apache-airflow>=2.7.0",
- "pinecone-client>=2.2.4,<3.0"
+ "pinecone-client>=3.0.0"
],
"devel-deps": [],
"cross-providers-deps": [],
diff --git a/tests/providers/pinecone/hooks/test_pinecone.py
b/tests/providers/pinecone/hooks/test_pinecone.py
index fb076cc0a3..82a01e5319 100644
--- a/tests/providers/pinecone/hooks/test_pinecone.py
+++ b/tests/providers/pinecone/hooks/test_pinecone.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import os
from unittest.mock import Mock, patch
from airflow.providers.pinecone.hooks.pinecone import PineconeHook
@@ -28,13 +29,15 @@ class TestPineconeHook:
with patch("airflow.models.Connection.get_connection_from_secrets") as
mock_get_connection:
mock_conn = Mock()
mock_conn.host = "pinecone.io"
- mock_conn.login = "test_user"
- mock_conn.password = "test_password"
+ mock_conn.login = "us-west1-gcp" # Pinecone Environment
+ mock_conn.password = "test_password" # Pinecone API Key
+ mock_conn.extra_dejson = {"region": "us-east-1", "debug_curl":
True}
mock_get_connection.return_value = mock_conn
self.pinecone_hook = PineconeHook()
+ self.pinecone_hook.conn
self.index_name = "test_index"
- @patch("airflow.providers.pinecone.hooks.pinecone.pinecone.Index")
+ @patch("airflow.providers.pinecone.hooks.pinecone.Pinecone.Index")
def test_upsert(self, mock_index):
"""Test the upsert_data_async method of PineconeHook for correct data
insertion asynchronously."""
data = [("id1", [1.0, 2.0, 3.0], {"meta": "data"})]
@@ -49,11 +52,38 @@ class TestPineconeHook:
self.pinecone_hook.list_indexes()
mock_list_indexes.assert_called_once()
+
@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.list_indexes")
+ def test_debug_curl_setting(self, mock_list_indexes):
+ """Test that the PINECONE_DEBUG_CURL environment variable is set when
initializing Pinecone Object."""
+ self.pinecone_hook.list_indexes()
+ mock_list_indexes.assert_called_once()
+ assert os.environ.get("PINECONE_DEBUG_CURL") == "true"
+
+
@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_index")
+ def test_create_index_for_pod_based(self, mock_create_index):
+ """Test that the create_index method of PineconeHook is called with
correct arguments for pod based index."""
+ pod_spec = self.pinecone_hook.get_pod_spec_obj()
+ self.pinecone_hook.create_index(index_name=self.index_name,
dimension=128, spec=pod_spec)
+ mock_create_index.assert_called_once_with(index_name="test_index",
dimension=128, spec=pod_spec)
+
@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_index")
- def test_create_index(self, mock_create_index):
- """Test that the create_index method of PineconeHook is called with
correct arguments."""
- self.pinecone_hook.create_index(index_name=self.index_name,
dimension=128)
- mock_create_index.assert_called_once_with(index_name="test_index",
dimension=128)
+ def test_create_index_for_serverless_based(self, mock_create_index):
+ """Test that the create_index method of PineconeHook is called with
correct arguments for serverless index."""
+ serverless_spec =
self.pinecone_hook.get_serverless_spec_obj(cloud="aws")
+ self.pinecone_hook.create_index(index_name=self.index_name,
dimension=128, spec=serverless_spec)
+ mock_create_index.assert_called_once_with(
+ index_name="test_index", dimension=128, spec=serverless_spec
+ )
+
+ def test_get_pod_spec_obj(self):
+ """Test that the get_pod_spec_obj method of PineconeHook returns the
correct pod spec object."""
+ pod_spec = self.pinecone_hook.get_pod_spec_obj()
+ assert pod_spec.environment == "us-west1-gcp"
+
+ def test_get_serverless_spec_obj(self):
+ """Test that the get_serverless_spec_obj method of PineconeHook
returns the correct serverless spec object."""
+ serverless_spec =
self.pinecone_hook.get_serverless_spec_obj(cloud="gcp")
+ assert serverless_spec.region == "us-east-1"
@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_index")
def test_describe_index(self, mock_describe_index):
diff --git a/tests/system/providers/pinecone/example_create_pod_index.py
b/tests/system/providers/pinecone/example_create_pod_index.py
new file mode 100644
index 0000000000..9b6f7d7d88
--- /dev/null
+++ b/tests/system/providers/pinecone/example_create_pod_index.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.pinecone.operators.pinecone import
CreatePodIndexOperator
+
+index_name = os.getenv("INDEX_NAME", "test")
+
+
+with DAG(
+ "example_pinecone_create_pod_index",
+ schedule="@once",
+ start_date=datetime(2024, 1, 1),
+ catchup=False,
+) as dag:
+ # [START howto_operator_create_pod_index]
+ # reference:
https://docs.pinecone.io/reference/api/control-plane/create_index
+ CreatePodIndexOperator(
+ task_id="pinecone_create_pod_index",
+ index_name=index_name,
+ dimension=3,
+ replicas=1,
+ shards=1,
+ pods=1,
+ pod_type="p1.x1",
+ )
+ # [END howto_operator_create_pod_index]
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/pinecone/example_create_serverless_index.py
b/tests/system/providers/pinecone/example_create_serverless_index.py
new file mode 100644
index 0000000000..a7924e63ef
--- /dev/null
+++ b/tests/system/providers/pinecone/example_create_serverless_index.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.pinecone.operators.pinecone import
CreateServerlessIndexOperator
+
+index_name = os.getenv("INDEX_NAME", "test")
+
+
+with DAG(
+ "example_pinecone_create_serverless_index",
+ schedule="@once",
+ start_date=datetime(2024, 1, 1),
+ catchup=False,
+) as dag:
+ # [START howto_operator_create_serverless_index]
+ # reference:
https://docs.pinecone.io/reference/api/control-plane/create_index
+ CreateServerlessIndexOperator(
+ task_id="pinecone_create_serverless_index",
+ index_name=index_name,
+ dimension=128,
+ cloud="aws",
+ region="us-west-2",
+ metric="cosine",
+ )
+ # [END howto_operator_create_serverless_index]
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)