This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 390bec1c82 Add Yandex Query support from Yandex.Cloud (#37458)
390bec1c82 is described below
commit 390bec1c82f6a6ac0efdc51a1355b5aae79516cb
Author: uzhastik <[email protected]>
AuthorDate: Wed Mar 20 19:35:43 2024 +0300
Add Yandex Query support from Yandex.Cloud (#37458)
* initial commit
* support web link
* move jwt logic out of base hook
* use http client in hook
* add yq_results
* add token_prefix, format exception message
* use YQResults inside client
* add tests, fix provider.yaml
* fix oauth token usage, add tests for complex results
* add tests for YQ operator
* fix test name
* linting
* restyling
* improve tests, fix close(), add link to YQ service
* trim spaces
* add docstrings, remove query description, move privates to bottom of the
file
* fix last newline
* restyling
* restyling
* refactor, restyling
* revert version
* change text to trigger CI checks
* fixes for linters
* rework
* restyling
* fix CI tests, add yq link tests
* add doc strings
* fix link style tests
* rename files, add deps, fix doc string
* replace SQLExecuteQueryOperator with BaseOperator
* fix static checks
* fight with static checks
* remove http client, use py package
* fix static checks
---
airflow/providers/yandex/hooks/yandex.py | 4 +-
airflow/providers/yandex/hooks/yq.py | 132 ++++++++++++++++++++++++++++
airflow/providers/yandex/links/__init__.py | 16 ++++
airflow/providers/yandex/links/yq.py | 41 +++++++++
airflow/providers/yandex/operators/yq.py | 92 +++++++++++++++++++
airflow/providers/yandex/provider.yaml | 21 +++++
generated/provider_dependencies.json | 3 +
pyproject.toml | 3 +
tests/providers/yandex/hooks/test_yq.py | 117 ++++++++++++++++++++++++
tests/providers/yandex/links/__init__.py | 16 ++++
tests/providers/yandex/links/test_yq.py | 57 ++++++++++++
tests/providers/yandex/operators/test_yq.py | 114 ++++++++++++++++++++++++
12 files changed, 614 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/yandex/hooks/yandex.py
b/airflow/providers/yandex/hooks/yandex.py
index aa9cf4302e..251a47b7b8 100644
--- a/airflow/providers/yandex/hooks/yandex.py
+++ b/airflow/providers/yandex/hooks/yandex.py
@@ -132,13 +132,13 @@ class YandexCloudBaseHook(BaseHook):
self.connection_id = yandex_conn_id or connection_id or
default_conn_name
self.connection = self.get_connection(self.connection_id)
self.extras = self.connection.extra_dejson
- credentials = get_credentials(
+ self.credentials = get_credentials(
oauth_token=self._get_field("oauth"),
service_account_json=self._get_field("service_account_json"),
service_account_json_path=self._get_field("service_account_json_path"),
)
sdk_config = self._get_endpoint()
- self.sdk = yandexcloud.SDK(user_agent=provider_user_agent(),
**sdk_config, **credentials)
+ self.sdk = yandexcloud.SDK(user_agent=provider_user_agent(),
**sdk_config, **self.credentials)
self.default_folder_id = default_folder_id or
self._get_field("folder_id")
self.default_public_ssh_key = default_public_ssh_key or
self._get_field("public_ssh_key")
self.default_service_account_id = default_service_account_id or
get_service_account_id(
diff --git a/airflow/providers/yandex/hooks/yq.py
b/airflow/providers/yandex/hooks/yq.py
new file mode 100644
index 0000000000..963709d89b
--- /dev/null
+++ b/airflow/providers/yandex/hooks/yq.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import time
+from datetime import timedelta
+from typing import Any
+
+import jwt
+import requests
+from urllib3.util.retry import Retry
+from yandex_query_client import YQHttpClient, YQHttpClientConfig
+
+from airflow.exceptions import AirflowException
+from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook
+from airflow.providers.yandex.utils.user_agent import provider_user_agent
+
+
+class YQHook(YandexCloudBaseHook):
+ """A hook for Yandex Query."""
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ config = YQHttpClientConfig(
+ token=self._get_iam_token(), project=self.default_folder_id,
user_agent=provider_user_agent()
+ )
+
+ self.client: YQHttpClient = YQHttpClient(config=config)
+
+ def close(self):
+ """Release all resources."""
+ self.client.close()
+
+ def create_query(self, query_text: str | None, name: str | None = None) ->
str:
+ """Create and run query.
+
+ :param query_text: SQL text.
+ :param name: name for the query
+ """
+ return self.client.create_query(
+ name=name,
+ query_text=query_text,
+ )
+
+ def wait_results(self, query_id: str, execution_timeout: timedelta =
timedelta(minutes=30)) -> Any:
+ """Wait for the query to complete and get its results.
+
+ :param query_id: ID of query.
+ :param execution_timeout: how long to wait for the query to complete.
+ """
+ result_set_count = self.client.wait_query_to_succeed(
+ query_id, execution_timeout=execution_timeout, stop_on_timeout=True
+ )
+
+ return self.client.get_query_all_result_sets(query_id=query_id,
result_set_count=result_set_count)
+
+ def stop_query(self, query_id: str) -> None:
+ """Stop the query.
+
+ :param query_id: ID of the query.
+ """
+ self.client.stop_query(query_id)
+
+ def get_query(self, query_id: str) -> Any:
+ """Get query info.
+
+ :param query_id: ID of the query.
+ """
+ return self.client.get_query(query_id)
+
+ def get_query_status(self, query_id: str) -> str:
+ """Get status of the query.
+
+ :param query_id: ID of query.
+ """
+ return self.client.get_query_status(query_id)
+
+ def compose_query_web_link(self, query_id: str):
+ """Compose web link to query in Yandex Query UI.
+
+ :param query_id: ID of query.
+ """
+ return self.client.compose_query_web_link(query_id)
+
+ def _get_iam_token(self) -> str:
+ if "token" in self.credentials:
+ return self.credentials["token"]
+ if "service_account_key" in self.credentials:
+ return YQHook._resolve_service_account_key(self.credentials["service_account_key"])
+ raise AirflowException(f"Unknown credentials type, available keys {self.credentials.keys()}")
+
+ @staticmethod
+ def _resolve_service_account_key(sa_info: dict) -> str:
+ with YQHook._create_session() as session:
+ api = "https://iam.api.cloud.yandex.net/iam/v1/tokens"
+ now = int(time.time())
+ payload = {"aud": api, "iss": sa_info["service_account_id"],
"iat": now, "exp": now + 360}
+
+ encoded_token = jwt.encode(
+ payload, sa_info["private_key"], algorithm="PS256",
headers={"kid": sa_info["id"]}
+ )
+
+ data = {"jwt": encoded_token}
+ iam_response = session.post(api, json=data)
+ iam_response.raise_for_status()
+
+ return iam_response.json()["iamToken"]
+
+ @staticmethod
+ def _create_session() -> requests.Session:
+ session = requests.Session()
+ session.verify = False
+ retry = Retry(backoff_factor=0.3, total=10)
+ session.mount("http://",
requests.adapters.HTTPAdapter(max_retries=retry))
+ session.mount("https://",
requests.adapters.HTTPAdapter(max_retries=retry))
+
+ return session
diff --git a/airflow/providers/yandex/links/__init__.py
b/airflow/providers/yandex/links/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/yandex/links/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/yandex/links/yq.py
b/airflow/providers/yandex/links/yq.py
new file mode 100644
index 0000000000..b168c5b0cf
--- /dev/null
+++ b/airflow/providers/yandex/links/yq.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.models import BaseOperatorLink, XCom
+
+if TYPE_CHECKING:
+ from airflow.models import BaseOperator
+ from airflow.models.taskinstancekey import TaskInstanceKey
+ from airflow.utils.context import Context
+
+XCOM_WEBLINK_KEY = "web_link"
+
+
+class YQLink(BaseOperatorLink):
+ """Web link to query in Yandex Query UI."""
+
+ name = "Yandex Query"
+
+ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey):
+ return XCom.get_value(key=XCOM_WEBLINK_KEY, ti_key=ti_key) or
"https://yq.cloud.yandex.ru"
+
+ @staticmethod
+ def persist(context: Context, task_instance: BaseOperator, web_link: str)
-> None:
+ task_instance.xcom_push(context, key=XCOM_WEBLINK_KEY, value=web_link)
diff --git a/airflow/providers/yandex/operators/yq.py
b/airflow/providers/yandex/operators/yq.py
new file mode 100644
index 0000000000..52261edd31
--- /dev/null
+++ b/airflow/providers/yandex/operators/yq.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.models import BaseOperator
+from airflow.providers.yandex.hooks.yq import YQHook
+from airflow.providers.yandex.links.yq import YQLink
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class YQExecuteQueryOperator(BaseOperator):
+ """
+ Executes SQL code using the Yandex Query service.
+
+ :param sql: the SQL code to be executed as a single string
+ :param name: name of the query in YandexQuery
+ :param folder_id: cloud folder id where to create query
+ :param yandex_conn_id: Airflow connection ID to get parameters from
+ """
+
+ operator_extra_links = (YQLink(),)
+ template_fields: Sequence[str] = ("sql",)
+ template_fields_renderers = {"sql": "sql"}
+ template_ext: Sequence[str] = (".sql",)
+ ui_color = "#ededed"
+
+ def __init__(
+ self,
+ *,
+ name: str | None = None,
+ folder_id: str | None = None,
+ yandex_conn_id: str | None = None,
+ public_ssh_key: str | None = None,
+ service_account_id: str | None = None,
+ sql: str,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.name = name
+ self.folder_id = folder_id
+ self.yandex_conn_id = yandex_conn_id
+ self.public_ssh_key = public_ssh_key
+ self.service_account_id = service_account_id
+ self.sql = sql
+
+ self.query_id: str | None = None
+
+ @cached_property
+ def hook(self) -> YQHook:
+ """Get valid hook."""
+ return YQHook(
+ yandex_conn_id=self.yandex_conn_id,
+ default_folder_id=self.folder_id,
+ default_public_ssh_key=self.public_ssh_key,
+ default_service_account_id=self.service_account_id,
+ )
+
+ def execute(self, context: Context) -> Any:
+ self.query_id = self.hook.create_query(query_text=self.sql,
name=self.name)
+
+ # pass to YQLink
+ web_link = self.hook.compose_query_web_link(self.query_id)
+ YQLink.persist(context, self, web_link)
+
+ results = self.hook.wait_results(self.query_id)
+ # forget query to avoid 'stop_query' in on_kill
+ self.query_id = None
+ return results
+
+ def on_kill(self) -> None:
+ if self.hook is not None and self.query_id is not None:
+ self.hook.stop_query(self.query_id)
+ self.hook.close()
diff --git a/airflow/providers/yandex/provider.yaml
b/airflow/providers/yandex/provider.yaml
index 08c31f88d2..0135ac3fb4 100644
--- a/airflow/providers/yandex/provider.yaml
+++ b/airflow/providers/yandex/provider.yaml
@@ -49,6 +49,10 @@ versions:
dependencies:
- apache-airflow>=2.6.0
- yandexcloud>=0.228.0
+ - yandex-query-client>=0.1.2
+ - python-dateutil>=2.8.0
+ # Requests 3, if it is ever released, will be heavily breaking.
+ - requests>=2.27.0,<3
integrations:
- integration-name: Yandex.Cloud
@@ -63,11 +67,22 @@ integrations:
logo: /integration-logos/yandex/Yandex-Cloud.png
tags: [service]
+ - integration-name: Yandex.Cloud YQ
+ external-doc-url: https://cloud.yandex.com/en/services/query
+ how-to-guide:
+ - /docs/apache-airflow-providers-yandex/operators.rst
+ logo: /integration-logos/yandex/Yandex-Cloud.png
+ tags: [service]
+
operators:
- integration-name: Yandex.Cloud Dataproc
python-modules:
- airflow.providers.yandex.operators.yandexcloud_dataproc
+ - integration-name: Yandex.Cloud YQ
+ python-modules:
+ - airflow.providers.yandex.operators.yq
+
hooks:
- integration-name: Yandex.Cloud
python-modules:
@@ -75,6 +90,9 @@ hooks:
- integration-name: Yandex.Cloud Dataproc
python-modules:
- airflow.providers.yandex.hooks.yandexcloud_dataproc
+ - integration-name: Yandex.Cloud YQ
+ python-modules:
+ - airflow.providers.yandex.hooks.yq
connection-types:
- hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook
@@ -83,6 +101,9 @@ connection-types:
secrets-backends:
- airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend
+extra-links:
+ - airflow.providers.yandex.links.yq.YQLink
+
config:
yandex:
description: This section contains settings for Yandex Cloud integration.
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 87a6c051f1..66d19d2563 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -1180,6 +1180,9 @@
"yandex": {
"deps": [
"apache-airflow>=2.6.0",
+ "python-dateutil>=2.8.0",
+ "requests>=2.27.0,<3",
+ "yandex-query-client>=0.1.2",
"yandexcloud>=0.228.0"
],
"devel-deps": [],
diff --git a/pyproject.toml b/pyproject.toml
index d8c63ba595..84cb37b91c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -975,6 +975,9 @@ weaviate = [ # source:
airflow/providers/weaviate/provider.yaml
"weaviate-client>=3.24.2",
]
yandex = [ # source: airflow/providers/yandex/provider.yaml
+ "python-dateutil>=2.8.0",
+ "requests>=2.27.0,<3",
+ "yandex-query-client>=0.1.2",
"yandexcloud>=0.228.0",
]
zendesk = [ # source: airflow/providers/zendesk/provider.yaml
diff --git a/tests/providers/yandex/hooks/test_yq.py
b/tests/providers/yandex/hooks/test_yq.py
new file mode 100644
index 0000000000..3b3db91dd1
--- /dev/null
+++ b/tests/providers/yandex/hooks/test_yq.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import timedelta
+from unittest import mock
+
+import responses
+from responses import matchers
+
+from airflow.models import Connection
+from airflow.providers.yandex.hooks.yq import YQHook
+
+OAUTH_TOKEN = "my_oauth_token"
+SERVICE_ACCOUNT_AUTH_KEY_JSON = """{"id":"my_id",
"service_account_id":"my_sa1", "private_key":"my_pk"}"""
+
+
+class DummySDK:
+ def __init__(self) -> None:
+ self.client = None
+
+
+class TestYandexCloudYqHook:
+ def _init_hook(self):
+ with mock.patch("airflow.hooks.base.BaseHook.get_connection") as
mock_get_connection:
+ mock_get_connection.return_value = self.connection
+ self.hook = YQHook(default_folder_id="my_folder_id")
+
+ def setup_method(self):
+ self.connection = Connection(extra={"service_account_json":
SERVICE_ACCOUNT_AUTH_KEY_JSON})
+
+ @responses.activate()
+ def test_oauth_token_usage(self):
+ responses.post(
+ "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries",
+ match=[
+ matchers.header_matcher(
+ {"Content-Type": "application/json", "Authorization":
f"Bearer {OAUTH_TOKEN}"}
+ ),
+ matchers.query_param_matcher({"project": "my_folder_id"}),
+ ],
+ json={"id": "query1"},
+ status=200,
+ )
+
+ self.connection = Connection(extra={"oauth": OAUTH_TOKEN})
+ self._init_hook()
+ query_id = self.hook.create_query(query_text="select 777", name="my
query")
+ assert query_id == "query1"
+
+ with
mock.patch("yandex_query_client.YQHttpClient.compose_query_web_link") as m:
+ m.return_value = "http://gg.zz"
+ assert self.hook.compose_query_web_link("query1") == "http://gg.zz"
+ m.assert_called_once_with("query1")
+
+ @responses.activate()
+ @mock.patch("yandexcloud.SDK")
+ @mock.patch("jwt.encode")
+ def test_select_results(self, mock_jwt, mock_sdk):
+ responses.post(
+ "https://iam.api.cloud.yandex.net/iam/v1/tokens",
+ json={"iamToken": "super_token"},
+ status=200,
+ )
+
+ mock_jwt.return_value = "zzzz"
+ mock_sdk.return_value = DummySDK()
+
+ with mock.patch.multiple(
+ "yandex_query_client.YQHttpClient",
+ create_query=mock.DEFAULT,
+ wait_query_to_succeed=mock.DEFAULT,
+ get_query_all_result_sets=mock.DEFAULT,
+ get_query_status=mock.DEFAULT,
+ get_query=mock.DEFAULT,
+ stop_query=mock.DEFAULT,
+ ) as mocks:
+ self._init_hook()
+ mocks["create_query"].return_value = "query1"
+ mocks["wait_query_to_succeed"].return_value = 2
+ mocks["get_query_all_result_sets"].return_value = {"x": 765}
+ mocks["get_query_status"].return_value = "COMPLETED"
+ mocks["get_query"].return_value = {"id": "my_q"}
+
+ query_id = self.hook.create_query(query_text="select 777",
name="my query")
+ assert query_id == "query1"
+ mocks["create_query"].assert_called_once_with(query_text="select
777", name="my query")
+
+ results = self.hook.wait_results(query_id,
execution_timeout=timedelta(minutes=10))
+ assert results == {"x": 765}
+ mocks["wait_query_to_succeed"].assert_called_once_with(
+ query_id, execution_timeout=timedelta(minutes=10),
stop_on_timeout=True
+ )
+
mocks["get_query_all_result_sets"].assert_called_once_with(query_id=query_id,
result_set_count=2)
+
+ assert self.hook.get_query_status(query_id) == "COMPLETED"
+ mocks["get_query_status"].assert_called_once_with(query_id)
+
+ assert self.hook.get_query(query_id) == {"id": "my_q"}
+ mocks["get_query"].assert_called_once_with(query_id)
+
+ self.hook.stop_query(query_id)
+ mocks["stop_query"].assert_called_once_with(query_id)
diff --git a/tests/providers/yandex/links/__init__.py
b/tests/providers/yandex/links/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/yandex/links/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/yandex/links/test_yq.py
b/tests/providers/yandex/links/test_yq.py
new file mode 100644
index 0000000000..82113fa44e
--- /dev/null
+++ b/tests/providers/yandex/links/test_yq.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.xcom import XCom
+from airflow.providers.yandex.links.yq import YQLink
+from tests.test_utils.mock_operators import MockOperator
+
+
+def test_persist():
+ mock_context = mock.MagicMock()
+
+ YQLink.persist(context=mock_context,
task_instance=MockOperator(task_id="test_task_id"), web_link="g.com")
+
+ ti = mock_context["ti"]
+ ti.xcom_push.assert_called_once_with(
+ execution_date=None,
+ key="web_link",
+ value="g.com",
+ )
+
+
+def test_default_link():
+ with mock.patch.object(XCom, "get_value") as m:
+ m.return_value = None
+ link = YQLink()
+
+ op = MockOperator(task_id="test_task_id")
+ ti = TaskInstance(task=op, run_id="run_id1")
+ assert link.get_link(op, ti_key=ti.key) == "https://yq.cloud.yandex.ru"
+
+
+def test_link():
+ with mock.patch.object(XCom, "get_value") as m:
+ m.return_value = "https://g.com"
+ link = YQLink()
+
+ op = MockOperator(task_id="test_task_id")
+ ti = TaskInstance(task=op, run_id="run_id1")
+ assert link.get_link(op, ti_key=ti.key) == "https://g.com"
diff --git a/tests/providers/yandex/operators/test_yq.py
b/tests/providers/yandex/operators/test_yq.py
new file mode 100644
index 0000000000..040f4089b4
--- /dev/null
+++ b/tests/providers/yandex/operators/test_yq.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import re
+from datetime import datetime, timedelta
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+import responses
+from responses import matchers
+
+from airflow.models import Connection
+from airflow.models.dag import DAG
+from airflow.providers.yandex.operators.yq import YQExecuteQueryOperator
+
+OAUTH_TOKEN = "my_oauth_token"
+FOLDER_ID = "my_folder_id"
+
+
+class TestYQExecuteQueryOperator:
+ def setup_method(self):
+ dag_id = "test_dag"
+ self.dag = DAG(
+ dag_id,
+ default_args={
+ "owner": "airflow",
+ "start_date": datetime.today(),
+ "end_date": datetime.today() + timedelta(days=1),
+ },
+ schedule="@once",
+ )
+
+ @responses.activate()
+ @patch("airflow.hooks.base.BaseHook.get_connection")
+ def test_execute_query(self, mock_get_connection):
+ mock_get_connection.return_value = Connection(extra={"oauth":
OAUTH_TOKEN})
+ operator = YQExecuteQueryOperator(task_id="simple_sql", sql="select
987", folder_id="my_folder_id")
+ context = {"ti": MagicMock()}
+
+ responses.post(
+ "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries",
+ match=[
+ matchers.header_matcher(
+ {"Content-Type": "application/json", "Authorization":
f"Bearer {OAUTH_TOKEN}"}
+ ),
+ matchers.query_param_matcher({"project": FOLDER_ID}),
+ matchers.json_params_matcher({"text": "select 987"}),
+ ],
+ json={"id": "query1"},
+ status=200,
+ )
+
+ responses.get(
+
"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status",
+ json={"status": "COMPLETED"},
+ status=200,
+ )
+
+ responses.get(
+
"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1",
+ json={"id": "query1", "result_sets": [{"rows_count": 1,
"truncated": False}]},
+ status=200,
+ )
+
+ responses.get(
+
"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/results/0",
+ json={"rows": [[777]], "columns": [{"name": "column0", "type":
"Int32"}]},
+ status=200,
+ )
+
+ results = operator.execute(context)
+ assert results == {"rows": [[777]], "columns": [{"name": "column0",
"type": "Int32"}]}
+
+ context["ti"].xcom_push.assert_has_calls(
+ [
+ call(
+ key="web_link",
+
value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1",
+ execution_date=None,
+ ),
+ ]
+ )
+
+ responses.get(
+
"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status",
+ json={"status": "ERROR"},
+ status=200,
+ )
+
+ responses.get(
+
"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1",
+ json={"id": "query1", "issues": ["some error"]},
+ status=200,
+ )
+
+ with pytest.raises(
+ RuntimeError, match=re.escape("""Query query1 failed with
issues=['some error']""")
+ ):
+ operator.execute(context)