This is an automated email from the ASF dual-hosted git repository.

rahulvats pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3a7e521197e Make Cohere provider AF3 compatible (#51396)
3a7e521197e is described below

commit 3a7e521197e43d6c7444c9170c258588906918da
Author: Rahul Vats <[email protected]>
AuthorDate: Fri Jun 6 11:40:23 2025 +0530

    Make Cohere provider AF3 compatible (#51396)
    
    * make cohere provider AF3 compatible
---
 providers/cohere/src/airflow/providers/cohere/hooks/cohere.py  |  9 +++++----
 .../cohere/src/airflow/providers/cohere/operators/embedding.py |  8 +++++---
 providers/cohere/tests/unit/cohere/operators/test_embedding.py | 10 +++++-----
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py
index 16b313c1d71..bfaba2d658a 100644
--- a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py
+++ b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py
@@ -29,7 +29,7 @@ from airflow.hooks.base import BaseHook
 
 if TYPE_CHECKING:
     from cohere.core.request_options import RequestOptions
-    from cohere.types import ChatMessages, EmbedByTypeResponseEmbeddings
+    from cohere.types import ChatMessages
 
 
 logger = logging.getLogger(__name__)
@@ -91,7 +91,7 @@ class CohereHook(BaseHook):
 
     def create_embeddings(
         self, texts: list[str], model: str = "embed-multilingual-v3.0"
-    ) -> EmbedByTypeResponseEmbeddings:
+    ) -> list[list[float]]:
         logger.info("Creating embeddings with model: embed-multilingual-v3.0")
         response = self.get_conn().embed(
             texts=texts,
@@ -100,8 +100,9 @@ class CohereHook(BaseHook):
             embedding_types=["float"],
             request_options=self.request_options,
         )
-        embeddings = response.embeddings
-        return embeddings
+        if response.embeddings.float_ is None:
+            raise ValueError("Embeddings response is missing float_ field")
+        return response.embeddings.float_
 
     @classmethod
     def get_ui_field_behaviour(cls) -> dict[str, Any]:
diff --git a/providers/cohere/src/airflow/providers/cohere/operators/embedding.py b/providers/cohere/src/airflow/providers/cohere/operators/embedding.py
index b06f13ab021..1504858fc21 100644
--- a/providers/cohere/src/airflow/providers/cohere/operators/embedding.py
+++ b/providers/cohere/src/airflow/providers/cohere/operators/embedding.py
@@ -26,7 +26,6 @@ from airflow.providers.cohere.hooks.cohere import CohereHook
 
 if TYPE_CHECKING:
     from cohere.core.request_options import RequestOptions
-    from cohere.types import EmbedByTypeResponseEmbeddings
 
     try:
         from airflow.sdk.definitions.context import Context
@@ -91,6 +90,9 @@ class CohereEmbeddingOperator(BaseOperator):
             request_options=self.request_options,
         )
 
-    def execute(self, context: Context) -> EmbedByTypeResponseEmbeddings:
+    def execute(self, context: Context) -> list[list[float]]:
         """Embed texts using Cohere embed services."""
-        return self.hook.create_embeddings(self.input_text)
+        embedding_response = self.hook.create_embeddings(self.input_text)
+
+        # Extract just the embeddings list, which is serializable
+        return embedding_response
diff --git a/providers/cohere/tests/unit/cohere/operators/test_embedding.py b/providers/cohere/tests/unit/cohere/operators/test_embedding.py
index 640690f1f1d..1b6bf810e93 100644
--- a/providers/cohere/tests/unit/cohere/operators/test_embedding.py
+++ b/providers/cohere/tests/unit/cohere/operators/test_embedding.py
@@ -27,12 +27,12 @@ from airflow.providers.cohere.operators.embedding import CohereEmbeddingOperator
 def test_cohere_embedding_operator(cohere_client, get_connection):
     """
     Test Cohere client is getting called with the correct key and that
-     the execute methods returns expected response.
+    the execute method returns expected response.
     """
-    embedded_obj = [1, 2, 3]
+    embedded_obj = [[1.0, 2.0, 3.0]]
 
-    class resp:
-        embeddings = embedded_obj
+    mock_response = MagicMock()
+    mock_response.embeddings.float_ = embedded_obj
 
     api_key = "test"
     base_url = "http://some_host.com";
@@ -43,7 +43,7 @@ def test_cohere_embedding_operator(cohere_client, get_connection):
     get_connection.return_value = Connection(conn_type="cohere", 
password=api_key, host=base_url)
     client_obj = MagicMock()
     cohere_client.return_value = client_obj
-    client_obj.embed.return_value = resp
+    client_obj.embed.return_value = mock_response
 
     op = CohereEmbeddingOperator(
         task_id="embed",

Reply via email to