This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2a78648848d Update Cohere to 5.13.4 v2 API (#45267)
2a78648848d is described below
commit 2a78648848d6f8edeed0e5df8c7120902965d1ce
Author: Albert Okiri <[email protected]>
AuthorDate: Sun Dec 29 23:44:01 2024 +0300
Update Cohere to 5.13.4 v2 API (#45267)
---
generated/provider_dependencies.json | 2 +-
.../src/airflow/providers/cohere/hooks/cohere.py | 70 +++++++++++++++++-----
.../providers/cohere/operators/embedding.py | 25 +++++++-
.../src/airflow/providers/cohere/provider.yaml | 2 +-
providers/tests/cohere/hooks/test_cohere.py | 13 ++--
providers/tests/cohere/operators/test_embedding.py | 18 +++---
6 files changed, 95 insertions(+), 35 deletions(-)
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 96a1b420e35..f58c010b3f5 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -383,7 +383,7 @@
"cohere": {
"deps": [
"apache-airflow>=2.9.0",
- "cohere>=4.37,<5"
+ "cohere>=5.13.4"
],
"devel-deps": [],
"plugins": [],
diff --git a/providers/src/airflow/providers/cohere/hooks/cohere.py b/providers/src/airflow/providers/cohere/hooks/cohere.py
index 2ce40c74d1e..b2d8c5d4476 100644
--- a/providers/src/airflow/providers/cohere/hooks/cohere.py
+++ b/providers/src/airflow/providers/cohere/hooks/cohere.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,25 +14,38 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
+import logging
+import warnings
from functools import cached_property
-from typing import Any
+from typing import TYPE_CHECKING, Any
import cohere
+from cohere.types import UserChatMessageV2
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
+if TYPE_CHECKING:
+ from cohere.core.request_options import RequestOptions
+ from cohere.types import ChatMessages, EmbedByTypeResponseEmbeddings
+
+
+logger = logging.getLogger(__name__)
+
class CohereHook(BaseHook):
"""
- Use Cohere Python SDK to interact with Cohere platform.
+ Use Cohere Python SDK to interact with Cohere platform using API v2.
.. seealso:: https://docs.cohere.com/docs
:param conn_id: :ref:`Cohere connection id <howto/connection:cohere>`
- :param timeout: Request timeout in seconds.
- :param max_retries: Maximal number of retries for requests.
+ :param timeout: Request timeout in seconds. Optional.
+ :param max_retries: Maximal number of retries for requests. Deprecated, use request_options instead. Optional.
+ :param request_options: Dictionary for function-specific request configuration. Optional.
"""
conn_name_attr = "conn_id"
@@ -46,23 +58,45 @@ class CohereHook(BaseHook):
conn_id: str = default_conn_name,
timeout: int | None = None,
max_retries: int | None = None,
+ request_options: RequestOptions | None = None,
) -> None:
super().__init__()
self.conn_id = conn_id
self.timeout = timeout
self.max_retries = max_retries
+ self.request_options = request_options
+
+ if self.max_retries:
+ warnings.warn(
+ "Argument `max_retries` is deprecated. Use `request_options` dict for function-specific request configuration.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ if self.request_options is None:
+ self.request_options = {"max_retries": self.max_retries}
+ else:
+ self.request_options.update({"max_retries": self.max_retries})
@cached_property
- def get_conn(self) -> cohere.Client: # type: ignore[override]
+ def get_conn(self) -> cohere.ClientV2: # type: ignore[override]
conn = self.get_connection(self.conn_id)
- return cohere.Client(
- api_key=conn.password, timeout=self.timeout, max_retries=self.max_retries, api_url=conn.host
+ return cohere.ClientV2(
+ api_key=conn.password,
+ timeout=self.timeout,
+ base_url=conn.host or None,
)
def create_embeddings(
- self, texts: list[str], model: str = "embed-multilingual-v2.0"
- ) -> list[list[float]]:
- response = self.get_conn.embed(texts=texts, model=model)
+ self, texts: list[str], model: str = "embed-multilingual-v3.0"
+ ) -> EmbedByTypeResponseEmbeddings:
+ logger.info("Creating embeddings with model: embed-multilingual-v3.0")
+ response = self.get_conn.embed(
+ texts=texts,
+ model=model,
+ input_type="search_document",
+ embedding_types=["float"],
+ request_options=self.request_options,
+ )
embeddings = response.embeddings
return embeddings
@@ -75,9 +109,15 @@ class CohereHook(BaseHook):
},
}
- def test_connection(self) -> tuple[bool, str]:
+ def test_connection(
+ self,
+ model: str = "command-r-plus-08-2024",
+ messages: ChatMessages | None = None,
+ ) -> tuple[bool, str]:
try:
- self.get_conn.generate("Test", max_tokens=10)
- return True, "Connection established"
+ if messages is None:
+ messages = [UserChatMessageV2(role="user", content="hello world!")]
+ self.get_conn.chat(model=model, messages=messages)
+ return True, "Connection successfully established."
except Exception as e:
- return False, str(e)
+ return False, f"Unexpected error: {str(e)}"
diff --git a/providers/src/airflow/providers/cohere/operators/embedding.py b/providers/src/airflow/providers/cohere/operators/embedding.py
index 85a585a9f73..c5de22b9b58 100644
--- a/providers/src/airflow/providers/cohere/operators/embedding.py
+++ b/providers/src/airflow/providers/cohere/operators/embedding.py
@@ -25,6 +25,9 @@ from airflow.models import BaseOperator
from airflow.providers.cohere.hooks.cohere import CohereHook
if TYPE_CHECKING:
+ from cohere.core.request_options import RequestOptions
+ from cohere.types import EmbedByTypeResponseEmbeddings
+
from airflow.utils.context import Context
@@ -41,6 +44,17 @@ class CohereEmbeddingOperator(BaseOperator):
information for Cohere. Defaults to "cohere_default".
:param timeout: Timeout in seconds for Cohere API.
:param max_retries: Number of times to retry before failing.
+ :param request_options: Request-specific configuration.
+ Fields:
+ - timeout_in_seconds: int. The number of seconds to await an API call before timing out.
+
+ - max_retries: int. The max number of retries to attempt if the API call fails.
+
+ - additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict
+
+ - additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict
+
+ - additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict
"""
template_fields: Sequence[str] = ("input_text",)
@@ -51,6 +65,7 @@ class CohereEmbeddingOperator(BaseOperator):
conn_id: str = CohereHook.default_conn_name,
timeout: int | None = None,
max_retries: int | None = None,
+ request_options: RequestOptions | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
@@ -60,12 +75,18 @@ class CohereEmbeddingOperator(BaseOperator):
self.input_text = input_text
self.timeout = timeout
self.max_retries = max_retries
+ self.request_options = request_options
@cached_property
def hook(self) -> CohereHook:
"""Return an instance of the CohereHook."""
- return CohereHook(conn_id=self.conn_id, timeout=self.timeout, max_retries=self.max_retries)
+ return CohereHook(
+ conn_id=self.conn_id,
+ timeout=self.timeout,
+ max_retries=self.max_retries,
+ request_options=self.request_options,
+ )
- def execute(self, context: Context) -> list[list[float]]:
+ def execute(self, context: Context) -> EmbedByTypeResponseEmbeddings:
"""Embed texts using Cohere embed services."""
return self.hook.create_embeddings(self.input_text)
diff --git a/providers/src/airflow/providers/cohere/provider.yaml b/providers/src/airflow/providers/cohere/provider.yaml
index 341645ed915..4c8b1fea3d4 100644
--- a/providers/src/airflow/providers/cohere/provider.yaml
+++ b/providers/src/airflow/providers/cohere/provider.yaml
@@ -47,7 +47,7 @@ integrations:
dependencies:
- apache-airflow>=2.9.0
- - cohere>=4.37,<5
+ - cohere>=5.13.4
hooks:
- integration-name: Cohere
diff --git a/providers/tests/cohere/hooks/test_cohere.py b/providers/tests/cohere/hooks/test_cohere.py
index 28aef3ebaf7..00c73cd1f29 100644
--- a/providers/tests/cohere/hooks/test_cohere.py
+++ b/providers/tests/cohere/hooks/test_cohere.py
@@ -31,19 +31,16 @@ class TestCohereHook:
def test__get_api_key(self):
api_key = "test"
- api_url = "http://some_host.com"
+ base_url = "http://some_host.com"
timeout = 150
- max_retries = 5
with (
patch.object(
CohereHook,
"get_connection",
- return_value=Connection(conn_type="cohere", password=api_key, host=api_url),
+ return_value=Connection(conn_type="cohere", password=api_key, host=base_url),
),
- patch("cohere.Client") as client,
+ patch("cohere.ClientV2") as client,
):
- hook = CohereHook(timeout=timeout, max_retries=max_retries)
+ hook = CohereHook(timeout=timeout)
_ = hook.get_conn
- client.assert_called_once_with(
- api_key=api_key, timeout=timeout, max_retries=max_retries, api_url=api_url
- )
+ client.assert_called_once_with(api_key=api_key, timeout=timeout, base_url=base_url)
diff --git a/providers/tests/cohere/operators/test_embedding.py b/providers/tests/cohere/operators/test_embedding.py
index 32dd83aa261..640690f1f1d 100644
--- a/providers/tests/cohere/operators/test_embedding.py
+++ b/providers/tests/cohere/operators/test_embedding.py
@@ -23,7 +23,7 @@ from airflow.providers.cohere.operators.embedding import CohereEmbeddingOperator
@patch("airflow.providers.cohere.hooks.cohere.CohereHook.get_connection")
-@patch("cohere.Client")
+@patch("cohere.ClientV2")
def test_cohere_embedding_operator(cohere_client, get_connection):
"""
Test Cohere client is getting called with the correct key and that
@@ -35,22 +35,24 @@ def test_cohere_embedding_operator(cohere_client, get_connection):
embeddings = embedded_obj
api_key = "test"
- api_url = "http://some_host.com"
+ base_url = "http://some_host.com"
timeout = 150
- max_retries = 5
texts = ["On Kernel-Target Alignment. We describe a family of global optimization procedures"]
+ request_options = None
- get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=api_url)
+ get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=base_url)
client_obj = MagicMock()
cohere_client.return_value = client_obj
client_obj.embed.return_value = resp
op = CohereEmbeddingOperator(
- task_id="embed", conn_id="some_conn", input_text=texts, timeout=timeout, max_retries=max_retries
+ task_id="embed",
+ conn_id="some_conn",
+ input_text=texts,
+ timeout=timeout,
+ request_options=request_options,
)
val = op.execute(context={})
- cohere_client.assert_called_once_with(
- api_key=api_key, api_url=api_url, timeout=timeout, max_retries=max_retries
- )
+ cohere_client.assert_called_once_with(api_key=api_key, base_url=base_url, timeout=timeout)
assert val == embedded_obj