This is an automated email from the ASF dual-hosted git repository.
rahulvats pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3a7e521197e Make Cohere provider AF3 compatible (#51396)
3a7e521197e is described below
commit 3a7e521197e43d6c7444c9170c258588906918da
Author: Rahul Vats <[email protected]>
AuthorDate: Fri Jun 6 11:40:23 2025 +0530
Make Cohere provider AF3 compatible (#51396)
* make cohere provider AF3 compatible
---
providers/cohere/src/airflow/providers/cohere/hooks/cohere.py | 9 +++++----
.../cohere/src/airflow/providers/cohere/operators/embedding.py | 8 +++++---
providers/cohere/tests/unit/cohere/operators/test_embedding.py | 10 +++++-----
3 files changed, 15 insertions(+), 12 deletions(-)
diff --git a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py
index 16b313c1d71..bfaba2d658a 100644
--- a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py
+++ b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py
@@ -29,7 +29,7 @@ from airflow.hooks.base import BaseHook
if TYPE_CHECKING:
from cohere.core.request_options import RequestOptions
- from cohere.types import ChatMessages, EmbedByTypeResponseEmbeddings
+ from cohere.types import ChatMessages
logger = logging.getLogger(__name__)
@@ -91,7 +91,7 @@ class CohereHook(BaseHook):
def create_embeddings(
self, texts: list[str], model: str = "embed-multilingual-v3.0"
- ) -> EmbedByTypeResponseEmbeddings:
+ ) -> list[list[float]]:
logger.info("Creating embeddings with model: embed-multilingual-v3.0")
response = self.get_conn().embed(
texts=texts,
@@ -100,8 +100,9 @@ class CohereHook(BaseHook):
embedding_types=["float"],
request_options=self.request_options,
)
- embeddings = response.embeddings
- return embeddings
+ if response.embeddings.float_ is None:
+ raise ValueError("Embeddings response is missing float_ field")
+ return response.embeddings.float_
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
diff --git a/providers/cohere/src/airflow/providers/cohere/operators/embedding.py b/providers/cohere/src/airflow/providers/cohere/operators/embedding.py
index b06f13ab021..1504858fc21 100644
--- a/providers/cohere/src/airflow/providers/cohere/operators/embedding.py
+++ b/providers/cohere/src/airflow/providers/cohere/operators/embedding.py
@@ -26,7 +26,6 @@ from airflow.providers.cohere.hooks.cohere import CohereHook
if TYPE_CHECKING:
from cohere.core.request_options import RequestOptions
- from cohere.types import EmbedByTypeResponseEmbeddings
try:
from airflow.sdk.definitions.context import Context
@@ -91,6 +90,9 @@ class CohereEmbeddingOperator(BaseOperator):
request_options=self.request_options,
)
- def execute(self, context: Context) -> EmbedByTypeResponseEmbeddings:
+ def execute(self, context: Context) -> list[list[float]]:
"""Embed texts using Cohere embed services."""
- return self.hook.create_embeddings(self.input_text)
+ embedding_response = self.hook.create_embeddings(self.input_text)
+
+ # Extract just the embeddings list, which is serializable
+ return embedding_response
diff --git a/providers/cohere/tests/unit/cohere/operators/test_embedding.py b/providers/cohere/tests/unit/cohere/operators/test_embedding.py
index 640690f1f1d..1b6bf810e93 100644
--- a/providers/cohere/tests/unit/cohere/operators/test_embedding.py
+++ b/providers/cohere/tests/unit/cohere/operators/test_embedding.py
@@ -27,12 +27,12 @@ from airflow.providers.cohere.operators.embedding import CohereEmbeddingOperator
def test_cohere_embedding_operator(cohere_client, get_connection):
"""
Test Cohere client is getting called with the correct key and that
- the execute methods returns expected response.
+ the execute method returns expected response.
"""
- embedded_obj = [1, 2, 3]
+ embedded_obj = [[1.0, 2.0, 3.0]]
- class resp:
- embeddings = embedded_obj
+ mock_response = MagicMock()
+ mock_response.embeddings.float_ = embedded_obj
api_key = "test"
base_url = "http://some_host.com"
@@ -43,7 +43,7 @@ def test_cohere_embedding_operator(cohere_client, get_connection):
get_connection.return_value = Connection(conn_type="cohere",
password=api_key, host=base_url)
client_obj = MagicMock()
cohere_client.return_value = client_obj
- client_obj.embed.return_value = resp
+ client_obj.embed.return_value = mock_response
op = CohereEmbeddingOperator(
task_id="embed",