This is an automated email from the ASF dual-hosted git repository.
utkarsharma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9ff245591e Fix the argument type of input_vectors in pinecone upsert (#39688)
9ff245591e is described below
commit 9ff245591e3be55f63f5803258af93f82111785e
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Fri May 17 22:49:39 2024 +0545
Fix the argument type of input_vectors in pinecone upsert (#39688)
* Fix the argument type of input_vectors
* Fix typing and docstring
---
airflow/providers/pinecone/hooks/pinecone.py | 3 ++-
airflow/providers/pinecone/operators/pinecone.py | 8 +++++---
.../system/providers/pinecone/example_pinecone_cohere.py | 15 +++++++++------
3 files changed, 16 insertions(+), 10 deletions(-)
diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py
index 35aa66c320..b5e73ae4c6 100644
--- a/airflow/providers/pinecone/hooks/pinecone.py
+++ b/airflow/providers/pinecone/hooks/pinecone.py
@@ -29,6 +29,7 @@ from pinecone import Pinecone, PodSpec, ServerlessSpec
from airflow.hooks.base import BaseHook
if TYPE_CHECKING:
+ from pinecone import Vector
from pinecone.core.client.model.sparse_values import SparseValues
from pinecone.core.client.models import DescribeIndexStatsResponse,
QueryResponse, UpsertResponse
@@ -137,7 +138,7 @@ class PineconeHook(BaseHook):
def upsert(
self,
index_name: str,
- vectors: list[Any],
+ vectors: list[Vector] | list[tuple] | list[dict],
namespace: str = "",
batch_size: int | None = None,
show_progress: bool = True,
diff --git a/airflow/providers/pinecone/operators/pinecone.py b/airflow/providers/pinecone/operators/pinecone.py
index bb3d44214d..70711e0623 100644
--- a/airflow/providers/pinecone/operators/pinecone.py
+++ b/airflow/providers/pinecone/operators/pinecone.py
@@ -25,6 +25,8 @@ from airflow.providers.pinecone.hooks.pinecone import
PineconeHook
from airflow.utils.context import Context
if TYPE_CHECKING:
+ from pinecone import Vector
+
from airflow.utils.context import Context
@@ -38,8 +40,8 @@ class PineconeIngestOperator(BaseOperator):
:param conn_id: The connection id to use when connecting to Pinecone.
:param index_name: Name of the Pinecone index.
- :param input_vectors: Data to be ingested, in the form of a list of tuples where each tuple
- contains (id, vector_embedding, metadata).
+ :param input_vectors: Data to be ingested, in the form of a list of vectors, list of tuples,
+ or list of dictionaries.
:param namespace: The namespace to write to. If not specified, the default
namespace is used.
:param batch_size: The number of vectors to upsert in each batch.
:param upsert_kwargs: .. seealso::
https://docs.pinecone.io/reference/upsert
@@ -52,7 +54,7 @@ class PineconeIngestOperator(BaseOperator):
*,
conn_id: str = PineconeHook.default_conn_name,
index_name: str,
- input_vectors: list[tuple],
+ input_vectors: list[Vector] | list[tuple] | list[dict],
namespace: str = "",
batch_size: int | None = None,
upsert_kwargs: dict | None = None,
diff --git a/tests/system/providers/pinecone/example_pinecone_cohere.py b/tests/system/providers/pinecone/example_pinecone_cohere.py
index c74a376f61..80e6766484 100644
--- a/tests/system/providers/pinecone/example_pinecone_cohere.py
+++ b/tests/system/providers/pinecone/example_pinecone_cohere.py
@@ -17,7 +17,6 @@
from __future__ import annotations
import os
-import time
from datetime import datetime
from airflow import DAG
@@ -46,19 +45,23 @@ with DAG(
hook = PineconeHook()
pod_spec = hook.get_pod_spec_obj()
hook.create_index(index_name=index_name, dimension=768, spec=pod_spec)
- time.sleep(60)
embed_task = CohereEmbeddingOperator(
task_id="embed_task",
input_text=data,
)
+ @task
+ def transform_output(embedding_output) -> list[dict]:
+ # Convert each embedding to a map with an ID and the embedding vector
+ return [dict(id=str(i), values=embedding) for i, embedding in
enumerate(embedding_output)]
+
+ transformed_output = transform_output(embed_task.output)
+
perform_ingestion = PineconeIngestOperator(
task_id="perform_ingestion",
index_name=index_name,
- input_vectors=[
- ("id1", embed_task.output),
- ],
+ input_vectors=transformed_output,
namespace=namespace,
batch_size=1,
)
@@ -71,7 +74,7 @@ with DAG(
hook = PineconeHook()
hook.delete_index(index_name=index_name)
- create_index() >> embed_task >> perform_ingestion >> delete_index()
+ create_index() >> embed_task >> transformed_output >> perform_ingestion >>
delete_index()
from tests.system.utils import get_test_run # noqa: E402