This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2a78648848d Update Cohere to 5.13.4 v2 API (#45267)
2a78648848d is described below

commit 2a78648848d6f8edeed0e5df8c7120902965d1ce
Author: Albert Okiri <[email protected]>
AuthorDate: Sun Dec 29 23:44:01 2024 +0300

    Update Cohere to 5.13.4 v2 API (#45267)
---
 generated/provider_dependencies.json               |  2 +-
 .../src/airflow/providers/cohere/hooks/cohere.py   | 70 +++++++++++++++++-----
 .../providers/cohere/operators/embedding.py        | 25 +++++++-
 .../src/airflow/providers/cohere/provider.yaml     |  2 +-
 providers/tests/cohere/hooks/test_cohere.py        | 13 ++--
 providers/tests/cohere/operators/test_embedding.py | 18 +++---
 6 files changed, 95 insertions(+), 35 deletions(-)

diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 96a1b420e35..f58c010b3f5 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -383,7 +383,7 @@
   "cohere": {
     "deps": [
       "apache-airflow>=2.9.0",
-      "cohere>=4.37,<5"
+      "cohere>=5.13.4"
     ],
     "devel-deps": [],
     "plugins": [],
diff --git a/providers/src/airflow/providers/cohere/hooks/cohere.py b/providers/src/airflow/providers/cohere/hooks/cohere.py
index 2ce40c74d1e..b2d8c5d4476 100644
--- a/providers/src/airflow/providers/cohere/hooks/cohere.py
+++ b/providers/src/airflow/providers/cohere/hooks/cohere.py
@@ -1,4 +1,3 @@
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,25 +14,38 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations
 
+import logging
+import warnings
 from functools import cached_property
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import cohere
+from cohere.types import UserChatMessageV2
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 
+if TYPE_CHECKING:
+    from cohere.core.request_options import RequestOptions
+    from cohere.types import ChatMessages, EmbedByTypeResponseEmbeddings
+
+
+logger = logging.getLogger(__name__)
+
 
 class CohereHook(BaseHook):
     """
-    Use Cohere Python SDK to interact with Cohere platform.
+    Use Cohere Python SDK to interact with Cohere platform using API v2.
 
     .. seealso:: https://docs.cohere.com/docs
 
     :param conn_id: :ref:`Cohere connection id <howto/connection:cohere>`
-    :param timeout: Request timeout in seconds.
-    :param max_retries: Maximal number of retries for requests.
+    :param timeout: Request timeout in seconds. Optional.
+    :param max_retries: Maximal number of retries for requests. Deprecated, use request_options instead. Optional.
+    :param request_options: Dictionary for function-specific request configuration. Optional.
     """
 
     conn_name_attr = "conn_id"
@@ -46,23 +58,45 @@ class CohereHook(BaseHook):
         conn_id: str = default_conn_name,
         timeout: int | None = None,
         max_retries: int | None = None,
+        request_options: RequestOptions | None = None,
     ) -> None:
         super().__init__()
         self.conn_id = conn_id
         self.timeout = timeout
         self.max_retries = max_retries
+        self.request_options = request_options
+
+        if self.max_retries:
+            warnings.warn(
+                "Argument `max_retries` is deprecated. Use `request_options` dict for function-specific request configuration.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            if self.request_options is None:
+                self.request_options = {"max_retries": self.max_retries}
+            else:
+                self.request_options.update({"max_retries": self.max_retries})
 
     @cached_property
-    def get_conn(self) -> cohere.Client:  # type: ignore[override]
+    def get_conn(self) -> cohere.ClientV2:  # type: ignore[override]
         conn = self.get_connection(self.conn_id)
-        return cohere.Client(
-            api_key=conn.password, timeout=self.timeout, max_retries=self.max_retries, api_url=conn.host
+        return cohere.ClientV2(
+            api_key=conn.password,
+            timeout=self.timeout,
+            base_url=conn.host or None,
         )
 
     def create_embeddings(
-        self, texts: list[str], model: str = "embed-multilingual-v2.0"
-    ) -> list[list[float]]:
-        response = self.get_conn.embed(texts=texts, model=model)
+        self, texts: list[str], model: str = "embed-multilingual-v3.0"
+    ) -> EmbedByTypeResponseEmbeddings:
+        logger.info("Creating embeddings with model: embed-multilingual-v3.0")
+        response = self.get_conn.embed(
+            texts=texts,
+            model=model,
+            input_type="search_document",
+            embedding_types=["float"],
+            request_options=self.request_options,
+        )
         embeddings = response.embeddings
         return embeddings
 
@@ -75,9 +109,15 @@ class CohereHook(BaseHook):
             },
         }
 
-    def test_connection(self) -> tuple[bool, str]:
+    def test_connection(
+        self,
+        model: str = "command-r-plus-08-2024",
+        messages: ChatMessages | None = None,
+    ) -> tuple[bool, str]:
         try:
-            self.get_conn.generate("Test", max_tokens=10)
-            return True, "Connection established"
+            if messages is None:
+                messages = [UserChatMessageV2(role="user", content="hello world!")]
+            self.get_conn.chat(model=model, messages=messages)
+            return True, "Connection successfully established."
         except Exception as e:
-            return False, str(e)
+            return False, f"Unexpected error: {str(e)}"
diff --git a/providers/src/airflow/providers/cohere/operators/embedding.py b/providers/src/airflow/providers/cohere/operators/embedding.py
index 85a585a9f73..c5de22b9b58 100644
--- a/providers/src/airflow/providers/cohere/operators/embedding.py
+++ b/providers/src/airflow/providers/cohere/operators/embedding.py
@@ -25,6 +25,9 @@ from airflow.models import BaseOperator
 from airflow.providers.cohere.hooks.cohere import CohereHook
 
 if TYPE_CHECKING:
+    from cohere.core.request_options import RequestOptions
+    from cohere.types import EmbedByTypeResponseEmbeddings
+
     from airflow.utils.context import Context
 
 
@@ -41,6 +44,17 @@ class CohereEmbeddingOperator(BaseOperator):
         information for Cohere. Defaults to "cohere_default".
     :param timeout: Timeout in seconds for Cohere API.
     :param max_retries: Number of times to retry before failing.
+    :param request_options: Request-specific configuration.
+        Fields:
+        - timeout_in_seconds: int. The number of seconds to await an API call before timing out.
+
+        - max_retries: int. The max number of retries to attempt if the API call fails.
+
+        - additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict
+
+        - additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict
+
+        - additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict
     """
 
     template_fields: Sequence[str] = ("input_text",)
@@ -51,6 +65,7 @@ class CohereEmbeddingOperator(BaseOperator):
         conn_id: str = CohereHook.default_conn_name,
         timeout: int | None = None,
         max_retries: int | None = None,
+        request_options: RequestOptions | None = None,
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
@@ -60,12 +75,18 @@ class CohereEmbeddingOperator(BaseOperator):
         self.input_text = input_text
         self.timeout = timeout
         self.max_retries = max_retries
+        self.request_options = request_options
 
     @cached_property
     def hook(self) -> CohereHook:
         """Return an instance of the CohereHook."""
-        return CohereHook(conn_id=self.conn_id, timeout=self.timeout, max_retries=self.max_retries)
+        return CohereHook(
+            conn_id=self.conn_id,
+            timeout=self.timeout,
+            max_retries=self.max_retries,
+            request_options=self.request_options,
+        )
 
-    def execute(self, context: Context) -> list[list[float]]:
+    def execute(self, context: Context) -> EmbedByTypeResponseEmbeddings:
         """Embed texts using Cohere embed services."""
         return self.hook.create_embeddings(self.input_text)
diff --git a/providers/src/airflow/providers/cohere/provider.yaml b/providers/src/airflow/providers/cohere/provider.yaml
index 341645ed915..4c8b1fea3d4 100644
--- a/providers/src/airflow/providers/cohere/provider.yaml
+++ b/providers/src/airflow/providers/cohere/provider.yaml
@@ -47,7 +47,7 @@ integrations:
 
 dependencies:
   - apache-airflow>=2.9.0
-  - cohere>=4.37,<5
+  - cohere>=5.13.4
 
 hooks:
   - integration-name: Cohere
diff --git a/providers/tests/cohere/hooks/test_cohere.py b/providers/tests/cohere/hooks/test_cohere.py
index 28aef3ebaf7..00c73cd1f29 100644
--- a/providers/tests/cohere/hooks/test_cohere.py
+++ b/providers/tests/cohere/hooks/test_cohere.py
@@ -31,19 +31,16 @@ class TestCohereHook:
 
     def test__get_api_key(self):
         api_key = "test"
-        api_url = "http://some_host.com"
+        base_url = "http://some_host.com"
         timeout = 150
-        max_retries = 5
         with (
             patch.object(
                 CohereHook,
                 "get_connection",
-                return_value=Connection(conn_type="cohere", password=api_key, host=api_url),
+                return_value=Connection(conn_type="cohere", password=api_key, host=base_url),
             ),
-            patch("cohere.Client") as client,
+            patch("cohere.ClientV2") as client,
         ):
-            hook = CohereHook(timeout=timeout, max_retries=max_retries)
+            hook = CohereHook(timeout=timeout)
             _ = hook.get_conn
-            client.assert_called_once_with(
-                api_key=api_key, timeout=timeout, max_retries=max_retries, api_url=api_url
-            )
+            client.assert_called_once_with(api_key=api_key, timeout=timeout, base_url=base_url)
diff --git a/providers/tests/cohere/operators/test_embedding.py b/providers/tests/cohere/operators/test_embedding.py
index 32dd83aa261..640690f1f1d 100644
--- a/providers/tests/cohere/operators/test_embedding.py
+++ b/providers/tests/cohere/operators/test_embedding.py
@@ -23,7 +23,7 @@ from airflow.providers.cohere.operators.embedding import CohereEmbeddingOperator
 
 
 @patch("airflow.providers.cohere.hooks.cohere.CohereHook.get_connection")
-@patch("cohere.Client")
+@patch("cohere.ClientV2")
 def test_cohere_embedding_operator(cohere_client, get_connection):
     """
     Test Cohere client is getting called with the correct key and that
@@ -35,22 +35,24 @@ def test_cohere_embedding_operator(cohere_client, get_connection):
         embeddings = embedded_obj
 
     api_key = "test"
-    api_url = "http://some_host.com"
+    base_url = "http://some_host.com"
     timeout = 150
-    max_retries = 5
     texts = ["On Kernel-Target Alignment. We describe a family of global optimization procedures"]
+    request_options = None
 
-    get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=api_url)
+    get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=base_url)
     client_obj = MagicMock()
     cohere_client.return_value = client_obj
     client_obj.embed.return_value = resp
 
     op = CohereEmbeddingOperator(
-        task_id="embed", conn_id="some_conn", input_text=texts, timeout=timeout, max_retries=max_retries
+        task_id="embed",
+        conn_id="some_conn",
+        input_text=texts,
+        timeout=timeout,
+        request_options=request_options,
     )
 
     val = op.execute(context={})
-    cohere_client.assert_called_once_with(
-        api_key=api_key, api_url=api_url, timeout=timeout, max_retries=max_retries
-    )
+    cohere_client.assert_called_once_with(api_key=api_key, base_url=base_url, timeout=timeout)
     assert val == embedded_obj

Reply via email to