This is an automated email from the ASF dual-hosted git repository.
pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 65020ee66a Add utils methods in pinecone provider (#35502)
65020ee66a is described below
commit 65020ee66afa803f9bda226f176233e47b59a8d0
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Tue Nov 7 20:09:46 2023 +0530
Add utils methods in pinecone provider (#35502)
* Add missing methods to pinecone provider
* Fix static
* Add words to spelling-wordlist.txt
* Update airflow/providers/pinecone/hooks/pinecone.py
Co-authored-by: Pankaj Singh <[email protected]>
* Update airflow/providers/pinecone/hooks/pinecone.py
Co-authored-by: Pankaj Singh <[email protected]>
---------
Co-authored-by: Pankaj Singh <[email protected]>
---
airflow/providers/pinecone/hooks/pinecone.py | 219 +++++++++++++++++++++++-
docs/spelling_wordlist.txt | 2 +
tests/providers/pinecone/hooks/test_pinecone.py | 89 ++++++++++
3 files changed, 308 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py
index 92fd620f76..9f6054ffe1 100644
--- a/airflow/providers/pinecone/hooks/pinecone.py
+++ b/airflow/providers/pinecone/hooks/pinecone.py
@@ -18,6 +18,7 @@
"""Hook for Pinecone."""
from __future__ import annotations
+import itertools
from typing import TYPE_CHECKING, Any
import pinecone
@@ -25,7 +26,8 @@ import pinecone
from airflow.hooks.base import BaseHook
if TYPE_CHECKING:
- from pinecone.core.client.models import UpsertResponse
+ from pinecone.core.client.model.sparse_values import SparseValues
+ from pinecone.core.client.models import DescribeIndexStatsResponse, QueryResponse, UpsertResponse
class PineconeHook(BaseHook):
@@ -86,11 +88,16 @@ class PineconeHook(BaseHook):
def test_connection(self) -> tuple[bool, str]:
try:
- pinecone.list_indexes()
+ self.list_indexes()
return True, "Connection established"
except Exception as e:
return False, str(e)
+ @staticmethod
+ def list_indexes() -> Any:
+ """Retrieve a list of all indexes in your project."""
+ return pinecone.list_indexes()
+
@staticmethod
def upsert(
index_name: str,
@@ -126,3 +133,211 @@ class PineconeHook(BaseHook):
show_progress=show_progress,
**kwargs,
)
+
+ @staticmethod
+ def create_index(
+ index_name: str,
+ dimension: int,
+ index_type: str | None = "approximated",
+ metric: str | None = "cosine",
+ replicas: int | None = 1,
+ shards: int | None = 1,
+ pods: int | None = 1,
+ pod_type: str | None = "p1",
+ index_config: dict[str, str] | None = None,
+ metadata_config: dict[str, str] | None = None,
+ source_collection: str | None = "",
+ timeout: int | None = None,
+ ) -> None:
+ """
+ Create a new index.
+
+ .. seealso:: https://docs.pinecone.io/reference/create_index/
+
+ :param index_name: The name of the index to create.
+ :param dimension: the dimension of vectors that would be inserted in the index
+ :param index_type: type of index, one of {"approximated", "exact"}, defaults to "approximated".
+ :param metric: type of metric used in the vector index, one of {"cosine", "dotproduct", "euclidean"}
+ :param replicas: the number of replicas, defaults to 1.
+ :param shards: the number of shards per index, defaults to 1.
+ :param pods: Total number of pods to be used by the index. pods = shard*replicas
+ :param pod_type: the pod type to be used for the index. can be one of p1 or s1.
+ :param index_config: Advanced configuration options for the index
+ :param metadata_config: Configuration related to the metadata index
+ :param source_collection: Collection name to create the index from
+ :param timeout: Timeout for wait until index gets ready.
+ """
+ pinecone.create_index(
+ name=index_name,
+ timeout=timeout,
+ index_type=index_type,
+ dimension=dimension,
+ metric=metric,
+ pods=pods,
+ replicas=replicas,
+ shards=shards,
+ pod_type=pod_type,
+ metadata_config=metadata_config,
+ source_collection=source_collection,
+ index_config=index_config,
+ )
+
+ @staticmethod
+ def describe_index(index_name: str) -> Any:
+ """
+ Retrieve information about a specific index.
+
+ :param index_name: The name of the index to describe.
+ """
+ return pinecone.describe_index(name=index_name)
+
+ @staticmethod
+ def delete_index(index_name: str, timeout: int | None = None) -> None:
+ """
+ Delete a specific index.
+
+ :param index_name: the name of the index.
+ :param timeout: Timeout for wait until index gets ready.
+ """
+ pinecone.delete_index(name=index_name, timeout=timeout)
+
+ @staticmethod
+ def configure_index(index_name: str, replicas: int | None = None, pod_type: str | None = "") -> None:
+ """
+ Changes current configuration of the index.
+
+ :param index_name: The name of the index to configure.
+ :param replicas: The new number of replicas.
+ :param pod_type: the new pod_type for the index.
+ """
+ pinecone.configure_index(name=index_name, replicas=replicas, pod_type=pod_type)
+
+ @staticmethod
+ def create_collection(collection_name: str, index_name: str) -> None:
+ """
+ Create a new collection from a specified index.
+
+ :param collection_name: The name of the collection to create.
+ :param index_name: The name of the source index.
+ """
+ pinecone.create_collection(name=collection_name, source=index_name)
+
+ @staticmethod
+ def delete_collection(collection_name: str) -> None:
+ """
+ Delete a specific collection.
+
+ :param collection_name: The name of the collection to delete.
+ """
+ pinecone.delete_collection(collection_name)
+
+ @staticmethod
+ def describe_collection(collection_name: str) -> Any:
+ """
+ Retrieve information about a specific collection.
+
+ :param collection_name: The name of the collection to describe.
+ """
+ return pinecone.describe_collection(collection_name)
+
+ @staticmethod
+ def list_collections() -> Any:
+ """Retrieve a list of all collections in the current project."""
+ return pinecone.list_collections()
+
+ @staticmethod
+ def query_vector(
+ index_name: str,
+ vector: list[Any],
+ query_id: str | None = None,
+ top_k: int = 10,
+ namespace: str | None = None,
+ query_filter: dict[str, str | float | int | bool | list[Any] | dict[Any, Any]] | None = None,
+ include_values: bool | None = None,
+ include_metadata: bool | None = None,
+ sparse_vector: SparseValues | dict[str, list[float] | list[int]] | None = None,
+ ) -> QueryResponse:
+ """
+ The Query operation searches a namespace, using a query vector.
+
+ It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
+ API reference: https://docs.pinecone.io/reference/query
+
+ :param index_name: The name of the index to query.
+ :param vector: The query vector.
+ :param query_id: The unique ID of the vector to be used as a query vector.
+ :param top_k: The number of results to return.
+ :param namespace: The namespace to fetch vectors from. If not specified, the default namespace is used.
+ :param query_filter: The filter to apply. See https://www.pinecone.io/docs/metadata-filtering/
+ :param include_values: Whether to include the vector values in the result.
+ :param include_metadata: Indicates whether metadata is included in the response as well as the ids.
+ :param sparse_vector: sparse values of the query vector. Expected to be either a SparseValues object or a dict
+ of the form: {'indices': List[int], 'values': List[float]}, where the lists each have the same length.
+ """
+ index = pinecone.Index(index_name)
+ return index.query(
+ vector=vector,
+ id=query_id,
+ top_k=top_k,
+ namespace=namespace,
+ filter=query_filter,
+ include_values=include_values,
+ include_metadata=include_metadata,
+ sparse_vector=sparse_vector,
+ )
+
+ @staticmethod
+ def _chunks(iterable: list[Any], batch_size: int = 100) -> Any:
+ """Helper function to break an iterable into chunks of size batch_size."""
+ it = iter(iterable)
+ chunk = tuple(itertools.islice(it, batch_size))
+ while chunk:
+ yield chunk
+ chunk = tuple(itertools.islice(it, batch_size))
+
+ def upsert_data_async(
+ self,
+ index_name: str,
+ data: list[tuple[Any]],
+ async_req: bool = False,
+ pool_threads: int | None = None,
+ ) -> None | list[Any]:
+ """
+ Upserts (insert/update) data into the Pinecone index.
+
+ :param index_name: Name of the index.
+ :param data: List of tuples to be upserted. Each tuple is of form (id, vector, metadata).
+ Metadata is optional.
+ :param async_req: If True, upsert operations will be asynchronous.
+ :param pool_threads: Number of threads for parallel upserting. If async_req is True, this must be provided.
+ """
+ responses = []
+ with pinecone.Index(index_name, pool_threads=pool_threads) as index:
+ if async_req and pool_threads:
+ async_results = [index.upsert(vectors=chunk, async_req=True) for chunk in self._chunks(data)]
+ responses = [async_result.get() for async_result in async_results]
+ else:
+ for chunk in self._chunks(data):
+ response = index.upsert(vectors=chunk)
+ responses.append(response)
+ return responses
+
+ @staticmethod
+ def describe_index_stats(
+ index_name: str,
+ stats_filter: dict[str, str | float | int | bool | list[Any] | dict[Any, Any]] | None = None,
+ **kwargs: Any,
+ ) -> DescribeIndexStatsResponse:
+ """
+ Describes the index statistics.
+
+ Returns statistics about the index's contents. For example: The vector count per
+ namespace and the number of dimensions.
+ API reference: https://docs.pinecone.io/reference/describe_index_stats_post
+
+ :param index_name: Name of the index.
+ :param stats_filter: If this parameter is present, the operation only returns statistics for vectors that
+ satisfy the filter. See https://www.pinecone.io/docs/metadata-filtering/
+ """
+ index = pinecone.Index(index_name)
+ return index.describe_index_stats(filter=stats_filter, **kwargs)
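
The hunk above only adds thin wrappers around the pinecone client, so here is a minimal usage sketch for reviewers. It is not part of the patch; the no-argument hook construction (relying on a default Pinecone connection configured in Airflow) and the index name are assumptions.

    # Illustrative sketch only, not part of the commit; assumes a default
    # Pinecone connection is configured and PineconeHook() can be built
    # without arguments.
    from airflow.providers.pinecone.hooks.pinecone import PineconeHook

    hook = PineconeHook()

    # Create a small index, upsert two vectors, then query and inspect it.
    hook.create_index(index_name="example-index", dimension=8)
    hook.upsert_data_async(
        index_name="example-index",
        data=[("id1", [0.1] * 8, {"source": "demo"}), ("id2", [0.2] * 8)],
    )
    response = hook.query_vector(index_name="example-index", vector=[0.1] * 8, top_k=2)
    stats = hook.describe_index_stats(index_name="example-index")
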
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index bcbda170ac..85af5b4d32 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -496,6 +496,7 @@ dogstatsd
donot
Dont
DOS'ing
+dotproduct
DownloadReportV
downscaling
downstreams
@@ -1668,6 +1669,7 @@ updateonly
Upsert
upsert
upserted
+upserting
upserts
Upsight
upstreams
diff --git a/tests/providers/pinecone/hooks/test_pinecone.py b/tests/providers/pinecone/hooks/test_pinecone.py
index d358ca9485..fb076cc0a3 100644
--- a/tests/providers/pinecone/hooks/test_pinecone.py
+++ b/tests/providers/pinecone/hooks/test_pinecone.py
@@ -42,3 +42,92 @@ class TestPineconeHook:
mock_index.return_value.upsert = mock_upsert
self.pinecone_hook.upsert(self.index_name, data)
mock_upsert.assert_called_once_with(vectors=data, namespace="", batch_size=None, show_progress=True)
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.list_indexes")
+ def test_list_indexes(self, mock_list_indexes):
+ """Test that the list_indexes method of PineconeHook is called correctly."""
+ self.pinecone_hook.list_indexes()
+ mock_list_indexes.assert_called_once()
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_index")
+ def test_create_index(self, mock_create_index):
+ """Test that the create_index method of PineconeHook is called with correct arguments."""
+ self.pinecone_hook.create_index(index_name=self.index_name, dimension=128)
+ mock_create_index.assert_called_once_with(index_name="test_index", dimension=128)
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_index")
+ def test_describe_index(self, mock_describe_index):
+ """Test that the describe_index method of PineconeHook is called with correct arguments."""
+ self.pinecone_hook.describe_index(index_name=self.index_name)
+ mock_describe_index.assert_called_once_with(index_name=self.index_name)
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.delete_index")
+ def test_delete_index(self, mock_delete_index):
+ """Test that the delete_index method of PineconeHook is called with the correct index name."""
+ self.pinecone_hook.delete_index(index_name="test_index")
+ mock_delete_index.assert_called_once_with(index_name="test_index")
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_collection")
+ def test_create_collection(self, mock_create_collection):
+ """
+ Test that the create_collection method of PineconeHook is called correctly.
+ """
+ self.pinecone_hook.create_collection(collection_name="test_collection")
+ mock_create_collection.assert_called_once_with(collection_name="test_collection")
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.configure_index")
+ def test_configure_index(self, mock_configure_index):
+ """
+ Test that the configure_index method of PineconeHook is called correctly.
+ """
+ self.pinecone_hook.configure_index(index_configuration={})
+ mock_configure_index.assert_called_once_with(index_configuration={})
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_collection")
+ def test_describe_collection(self, mock_describe_collection):
+ """
+ Test that the describe_collection method of PineconeHook is called correctly.
+ """
+ self.pinecone_hook.describe_collection(collection_name="test_collection")
+ mock_describe_collection.assert_called_once_with(collection_name="test_collection")
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.list_collections")
+ def test_list_collections(self, mock_list_collections):
+ """
+ Test that the list_collections method of PineconeHook is called correctly.
+ """
+ self.pinecone_hook.list_collections()
+ mock_list_collections.assert_called_once()
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.query_vector")
+ def test_query_vector(self, mock_query_vector):
+ """
+ Test that the query_vector method of PineconeHook is called correctly.
+ """
+ self.pinecone_hook.query_vector(vector=[1.0, 2.0, 3.0])
+ mock_query_vector.assert_called_once_with(vector=[1.0, 2.0, 3.0])
+
+ def test__chunks(self):
+ """
+ Test that the _chunks method of PineconeHook behaves as expected.
+ """
+ data = list(range(10))
+ chunked_data = list(self.pinecone_hook._chunks(data, 3))
+ assert chunked_data == [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)]
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.upsert_data_async")
+ def test_upsert_data_async_correctly(self, mock_upsert_data_async):
+ """
+ Test that the upsert_data_async method of PineconeHook is called correctly.
+ """
+ data = [("id1", [1.0, 2.0, 3.0], {"meta": "data"})]
+ self.pinecone_hook.upsert_data_async(index_name="test_index", data=data)
+ mock_upsert_data_async.assert_called_once_with(index_name="test_index", data=data)
+
+ @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_index_stats")
+ def test_describe_index_stats(self, mock_describe_index_stats):
+ """
+ Test that the describe_index_stats method of PineconeHook is called correctly.
+ """
+ self.pinecone_hook.describe_index_stats(index_name="test_index")
+ mock_describe_index_stats.assert_called_once_with(index_name="test_index")
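
As a closing note (not part of the patch): the tests above patch the hook methods themselves, so a complementary test could patch the wrapped pinecone module instead and exercise the argument mapping inside the hook. A hypothetical pytest-style sketch based on the delete_index wrapper shown earlier; the test name and timeout value are illustrative:

    # Hypothetical sketch, not part of the commit: patch the pinecone module
    # the hook delegates to, so the index_name -> name mapping is checked.
    from unittest.mock import patch

    from airflow.providers.pinecone.hooks.pinecone import PineconeHook


    @patch("airflow.providers.pinecone.hooks.pinecone.pinecone.delete_index")
    def test_delete_index_forwards_arguments(mock_delete_index):
        # delete_index is a staticmethod, so no hook instance or connection is needed.
        PineconeHook.delete_index(index_name="test_index", timeout=10)
        mock_delete_index.assert_called_once_with(name="test_index", timeout=10)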