This is an automated email from the ASF dual-hosted git repository.

utkarsharma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ff245591e Fix the argument type of input_vectors in pinecone upsert (#39688)
9ff245591e is described below

commit 9ff245591e3be55f63f5803258af93f82111785e
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Fri May 17 22:49:39 2024 +0545

    Fix the argument type of input_vectors in pinecone upsert (#39688)
    
    * Fix the argument type of input_vectors
    
    * Fix typing and docstring
---
 airflow/providers/pinecone/hooks/pinecone.py              |  3 ++-
 airflow/providers/pinecone/operators/pinecone.py          |  8 +++++---
 .../system/providers/pinecone/example_pinecone_cohere.py  | 15 +++++++++------
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py
index 35aa66c320..b5e73ae4c6 100644
--- a/airflow/providers/pinecone/hooks/pinecone.py
+++ b/airflow/providers/pinecone/hooks/pinecone.py
@@ -29,6 +29,7 @@ from pinecone import Pinecone, PodSpec, ServerlessSpec
 from airflow.hooks.base import BaseHook
 
 if TYPE_CHECKING:
+    from pinecone import Vector
     from pinecone.core.client.model.sparse_values import SparseValues
     from pinecone.core.client.models import DescribeIndexStatsResponse, QueryResponse, UpsertResponse
 
@@ -137,7 +138,7 @@ class PineconeHook(BaseHook):
     def upsert(
         self,
         index_name: str,
-        vectors: list[Any],
+        vectors: list[Vector] | list[tuple] | list[dict],
         namespace: str = "",
         batch_size: int | None = None,
         show_progress: bool = True,
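
With this change the hook's type hint reflects the vector formats that the docstring says upsert accepts. A minimal sketch of exercising the widened parameter, assuming a configured Pinecone connection and an existing index; the index name and embedding values below are illustrative only and not part of the commit:

    from airflow.providers.pinecone.hooks.pinecone import PineconeHook

    hook = PineconeHook()

    # Tuple form: (id, values), as used before this change.
    hook.upsert(index_name="example-index", vectors=[("id1", [0.1, 0.2, 0.3])])

    # Dictionary form: {"id": ..., "values": ...}, now also covered by the annotation.
    hook.upsert(
        index_name="example-index",
        vectors=[{"id": "id2", "values": [0.4, 0.5, 0.6]}],
    )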
diff --git a/airflow/providers/pinecone/operators/pinecone.py b/airflow/providers/pinecone/operators/pinecone.py
index bb3d44214d..70711e0623 100644
--- a/airflow/providers/pinecone/operators/pinecone.py
+++ b/airflow/providers/pinecone/operators/pinecone.py
@@ -25,6 +25,8 @@ from airflow.providers.pinecone.hooks.pinecone import PineconeHook
 from airflow.utils.context import Context
 
 if TYPE_CHECKING:
+    from pinecone import Vector
+
     from airflow.utils.context import Context
 
 
@@ -38,8 +40,8 @@ class PineconeIngestOperator(BaseOperator):
 
     :param conn_id: The connection id to use when connecting to Pinecone.
     :param index_name: Name of the Pinecone index.
-    :param input_vectors: Data to be ingested, in the form of a list of tuples where each tuple
-        contains (id, vector_embedding, metadata).
+    :param input_vectors: Data to be ingested, in the form of a list of vectors, list of tuples,
+        or list of dictionaries.
     :param namespace: The namespace to write to. If not specified, the default namespace is used.
     :param batch_size: The number of vectors to upsert in each batch.
     :param upsert_kwargs: .. seealso:: https://docs.pinecone.io/reference/upsert
@@ -52,7 +54,7 @@ class PineconeIngestOperator(BaseOperator):
         *,
         conn_id: str = PineconeHook.default_conn_name,
         index_name: str,
-        input_vectors: list[tuple],
+        input_vectors: list[Vector] | list[tuple] | list[dict],
         namespace: str = "",
         batch_size: int | None = None,
         upsert_kwargs: dict | None = None,
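
For illustration, a hedged sketch of instantiating the operator with the dictionary form now permitted by the type hint; the task id, index name, namespace, metadata, and vector values are made up for this example and are not part of the commit:

    from airflow.providers.pinecone.operators.pinecone import PineconeIngestOperator

    # Previously only list[tuple] matched the annotation; dicts now do as well.
    ingest = PineconeIngestOperator(
        task_id="example_ingest",
        index_name="example-index",
        input_vectors=[
            {"id": "id1", "values": [0.1, 0.2, 0.3], "metadata": {"source": "doc-1"}},
        ],
        namespace="example-namespace",
        batch_size=1,
    )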
diff --git a/tests/system/providers/pinecone/example_pinecone_cohere.py b/tests/system/providers/pinecone/example_pinecone_cohere.py
index c74a376f61..80e6766484 100644
--- a/tests/system/providers/pinecone/example_pinecone_cohere.py
+++ b/tests/system/providers/pinecone/example_pinecone_cohere.py
@@ -17,7 +17,6 @@
 from __future__ import annotations
 
 import os
-import time
 from datetime import datetime
 
 from airflow import DAG
@@ -46,19 +45,23 @@ with DAG(
         hook = PineconeHook()
         pod_spec = hook.get_pod_spec_obj()
         hook.create_index(index_name=index_name, dimension=768, spec=pod_spec)
-        time.sleep(60)
 
     embed_task = CohereEmbeddingOperator(
         task_id="embed_task",
         input_text=data,
     )
 
+    @task
+    def transform_output(embedding_output) -> list[dict]:
+        # Convert each embedding to a map with an ID and the embedding vector
+        return [dict(id=str(i), values=embedding) for i, embedding in enumerate(embedding_output)]
+
+    transformed_output = transform_output(embed_task.output)
+
     perform_ingestion = PineconeIngestOperator(
         task_id="perform_ingestion",
         index_name=index_name,
-        input_vectors=[
-            ("id1", embed_task.output),
-        ],
+        input_vectors=transformed_output,
         namespace=namespace,
         batch_size=1,
     )
@@ -71,7 +74,7 @@ with DAG(
         hook = PineconeHook()
         hook.delete_index(index_name=index_name)
 
-    create_index() >> embed_task >> perform_ingestion >> delete_index()
+    create_index() >> embed_task >> transformed_output >> perform_ingestion >> delete_index()
 
 from tests.system.utils import get_test_run  # noqa: E402
 
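The new transform_output task converts the embeddings returned by CohereEmbeddingOperator into a dictionary form accepted by the ingest operator. Roughly, for a hypothetical two-element embedding output, the mapping would look like this:

    # Hypothetical input, standing in for embed_task.output at runtime.
    embedding_output = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

    # Same comprehension as in transform_output above.
    transformed = [dict(id=str(i), values=embedding) for i, embedding in enumerate(embedding_output)]

    # transformed == [
    #     {"id": "0", "values": [0.1, 0.2, 0.3]},
    #     {"id": "1", "values": [0.4, 0.5, 0.6]},
    # ]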
