This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 739e6b5d77 Add SnowflakeSqlApiOperator operator (#30698)
739e6b5d77 is described below
commit 739e6b5d775412f987a3ff5fb71c51fbb7051a89
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Wed May 24 12:40:03 2023 +0530
Add SnowflakeSqlApiOperator operator (#30698)
* Add SnowflakeSqlApiOperatorAsyn
* Remove unrelated change
* Saving work
* Add SnowflakeSqlApiOperator operator
* Remove snowflake_trigger
* Remove snowflake_trigger
* Remove unwanted code
* Remove unwanted code
* Remove unwanted code
* Amend test cases
* Remove unwanted code
* Fix static checks
* Fix docs
* Update airflow/providers/snowflake/hooks/snowflake_sql_api.py
* Remove unwanted code
* Remove unwanted code
* Fix static check
* Add docs and example
* Fix static check
---
.../providers/snowflake/hooks/snowflake_sql_api.py | 256 ++++++++++++
.../snowflake/hooks/sql_api_generate_jwt.py | 174 ++++++++
airflow/providers/snowflake/operators/snowflake.py | 167 +++++++-
airflow/providers/snowflake/provider.yaml | 2 +
.../operators/snowflake.rst | 43 ++
.../snowflake/hooks/test_snowflake_sql_api.py | 458 +++++++++++++++++++++
.../snowflake/hooks/test_sql_api_generate_jwt.py | 54 +++
.../snowflake/operators/test_snowflake.py | 69 ++++
.../providers/snowflake/example_snowflake.py | 11 +-
9 files changed, 1232 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
new file mode 100644
index 0000000000..e77b7607b7
--- /dev/null
+++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -0,0 +1,256 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import uuid
+from datetime import timedelta
+from pathlib import Path
+from typing import Any
+
+import requests
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+
+from airflow import AirflowException
+from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
+from airflow.providers.snowflake.hooks.sql_api_generate_jwt import JWTGenerator
+
+
+class SnowflakeSqlApiHook(SnowflakeHook):
+ """
+ A client to interact with Snowflake using SQL API and allows submitting
+ multiple SQL statements in a single request. In combination with aiohttp,
makes a POST request to submit SQL
+ statements for execution, poll to check the status of the execution of a
statement. Fetch query results
+ asynchronously.
+
+ This hook requires the snowflake_conn_id connection. This hook mainly
uses account, schema, database,
+ warehouse, private_key_file or private_key_content field must be setup in
the connection. Other inputs
+ can be defined in the connection or hook instantiation.
+
+ :param snowflake_conn_id: Reference to
+ :ref:`Snowflake connection id<howto/connection:snowflake>`
+ :param account: snowflake account name
+ :param authenticator: authenticator for Snowflake.
+ 'snowflake' (default) to use the internal Snowflake authenticator
+ 'externalbrowser' to authenticate using your web browser and
+ Okta, ADFS or any other SAML 2.0-compliant identity provider
+ (IdP) that has been defined for your account
+ 'https://<your_okta_account_name>.okta.com' to authenticate
+ through native Okta.
+ :param warehouse: name of snowflake warehouse
+ :param database: name of snowflake database
+ :param region: name of snowflake region
+ :param role: name of snowflake role
+ :param schema: name of snowflake schema
+ :param session_parameters: You can set session-level parameters at
+ the time you connect to Snowflake
+ :param token_life_time: lifetime of the JWT Token in timedelta
+ :param token_renewal_delta: Renewal time of the JWT Token in timedelta
+ """
+
+ LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute
lifetime
+ RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54
minutes
+
+ def __init__(
+ self,
+ snowflake_conn_id: str,
+ token_life_time: timedelta = LIFETIME,
+ token_renewal_delta: timedelta = RENEWAL_DELTA,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ self.snowflake_conn_id = snowflake_conn_id
+ self.token_life_time = token_life_time
+ self.token_renewal_delta = token_renewal_delta
+ super().__init__(snowflake_conn_id, *args, **kwargs)
+ self.private_key: Any = None
+
+ def get_private_key(self) -> None:
+ """Gets the private key from snowflake connection"""
+ conn = self.get_connection(self.snowflake_conn_id)
+
+ # If private_key_file is specified in the extra json, load the
contents of the file as a private key.
+ # If private_key_content is specified in the extra json, use it as a
private key.
+ # As a next step, specify this private key in the connection
configuration.
+ # The connection password then becomes the passphrase for the private
key.
+ # If your private key is not encrypted (not recommended), then leave
the password empty.
+
+ private_key_file = conn.extra_dejson.get(
+ "extra__snowflake__private_key_file"
+ ) or conn.extra_dejson.get("private_key_file")
+ private_key_content = conn.extra_dejson.get(
+ "extra__snowflake__private_key_content"
+ ) or conn.extra_dejson.get("private_key_content")
+
+ private_key_pem = None
+ if private_key_content and private_key_file:
+ raise AirflowException(
+ "The private_key_file and private_key_content extra fields are
mutually exclusive. "
+ "Please remove one."
+ )
+ elif private_key_file:
+ private_key_pem = Path(private_key_file).read_bytes()
+ elif private_key_content:
+ private_key_pem = private_key_content.encode()
+
+ if private_key_pem:
+ passphrase = None
+ if conn.password:
+ passphrase = conn.password.strip().encode()
+
+ self.private_key = serialization.load_pem_private_key(
+ private_key_pem, password=passphrase, backend=default_backend()
+ )
+
+ def execute_query(
+ self, sql: str, statement_count: int, query_tag: str = "", bindings:
dict[str, Any] | None = None
+ ) -> list[str]:
+ """
+ Using SnowflakeSQL API, run the query in snowflake by making API
request
+
+ :param sql: the sql string to be executed with possibly multiple
statements
+ :param statement_count: set the MULTI_STATEMENT_COUNT field to the
number of SQL statements
+ in the request
+ :param query_tag: (Optional) Query tag that you want to associate with
the SQL statement.
+ For details, see
https://docs.snowflake.com/en/sql-reference/parameters.html#label-query-tag
+ parameter.
+ :param bindings: (Optional) Values of bind variables in the SQL
statement.
+ When executing the statement, Snowflake replaces placeholders (?
and :name) in
+ the statement with these specified values.
+ """
+ conn_config = self._get_conn_params()
+
+ req_id = uuid.uuid4()
+ url =
f"https://{conn_config['account']}.{conn_config['region']}.snowflakecomputing.com/api/v2/statements"
+ params: dict[str, Any] | None = {"requestId": str(req_id), "async":
True, "pageSize": 10}
+ headers = self.get_headers()
+ if bindings is None:
+ bindings = {}
+ data = {
+ "statement": sql,
+ "resultSetMetaData": {"format": "json"},
+ "database": conn_config["database"],
+ "schema": conn_config["schema"],
+ "warehouse": conn_config["warehouse"],
+ "role": conn_config["role"],
+ "bindings": bindings,
+ "parameters": {
+ "MULTI_STATEMENT_COUNT": statement_count,
+ "query_tag": query_tag,
+ },
+ }
+ response = requests.post(url, json=data, headers=headers,
params=params)
+ try:
+ response.raise_for_status()
+ except requests.exceptions.HTTPError as e: # pragma: no cover
+ raise AirflowException(f"Response: {e.response.content} Status
Code: {e.response.status_code}")
+ json_response = response.json()
+ self.log.info("Snowflake SQL POST API response: %s", json_response)
+ if "statementHandles" in json_response:
+ self.query_ids = json_response["statementHandles"]
+ elif "statementHandle" in json_response:
+ self.query_ids.append(json_response["statementHandle"])
+ else:
+ raise AirflowException("No statementHandle/statementHandles
present in response")
+ return self.query_ids
+
+ def get_headers(self) -> dict[str, Any]:
+ """Based on the private key, and with connection details JWT Token is
generated and header
+ is formed
+ """
+ if not self.private_key:
+ self.get_private_key()
+ conn_config = self._get_conn_params()
+
+ # Get the JWT token from the connection details and the private key
+ token = JWTGenerator(
+ conn_config["account"], # type: ignore[arg-type]
+ conn_config["user"], # type: ignore[arg-type]
+ private_key=self.private_key,
+ lifetime=self.token_life_time,
+ renewal_delay=self.token_renewal_delta,
+ ).get_token()
+
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {token}",
+ "Accept": "application/json",
+ "User-Agent": "snowflakeSQLAPI/1.0",
+ "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
+ }
+ return headers
+
+ def get_request_url_header_params(self, query_id: str) -> tuple[dict[str,
Any], dict[str, Any], str]:
+ """
+ Build the request header Url with account name identifier and query id
from the connection params
+
+ :param query_id: statement handles query ids for the individual
statements.
+ """
+ conn_config = self._get_conn_params()
+ req_id = uuid.uuid4()
+ header = self.get_headers()
+ params = {"requestId": str(req_id)}
+ url =
f"https://{conn_config['account']}.{conn_config['region']}.snowflakecomputing.com/api/v2/statements/{query_id}"
+ return header, params, url
+
+ def check_query_output(self, query_ids: list[str]) -> None:
+ """
+ Based on the query ids passed as the parameter make HTTP request to
snowflake SQL API and logs
+ the response
+
+ :param query_ids: statement handles query id for the individual
statements.
+ """
+ for query_id in query_ids:
+ header, params, url = self.get_request_url_header_params(query_id)
+ try:
+ response = requests.get(url, headers=header, params=params)
+ response.raise_for_status()
+ self.log.info(response.json())
+ except requests.exceptions.HTTPError as e:
+ raise AirflowException(
+ f"Response: {e.response.content}, Status Code:
{e.response.status_code}"
+ )
+
+ def get_sql_api_query_status(self, query_id: str) -> dict[str, str |
list[str]]:
+ """
+ Based on the query id async HTTP request is made to snowflake SQL API
and return response.
+
+ :param query_id: statement handle id for the individual statements.
+ """
+ self.log.info("Retrieving status for query id %s", {query_id})
+ header, params, url = self.get_request_url_header_params(query_id)
+ response = requests.get(url, params=params, headers=header)
+ status_code = response.status_code
+ resp = response.json()
+ self.log.info("Snowflake SQL GET statements status API response: %s",
resp)
+ if status_code == 202:
+ return {"status": "running", "message": "Query statements are
still running"}
+ elif status_code == 422:
+ return {"status": "error", "message": resp["message"]}
+ elif status_code == 200:
+ statement_handles = []
+ if "statementHandles" in resp and resp["statementHandles"]:
+ statement_handles = resp["statementHandles"]
+ elif "statementHandle" in resp and resp["statementHandle"]:
+ statement_handles.append(resp["statementHandle"])
+ return {
+ "status": "success",
+ "message": resp["message"],
+ "statement_handles": statement_handles,
+ }
+ else:
+ return {"status": "error", "message": resp["message"]}
diff --git a/airflow/providers/snowflake/hooks/sql_api_generate_jwt.py
b/airflow/providers/snowflake/hooks/sql_api_generate_jwt.py
new file mode 100644
index 0000000000..883ef3aae0
--- /dev/null
+++ b/airflow/providers/snowflake/hooks/sql_api_generate_jwt.py
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import base64
+import hashlib
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+# This class relies on the PyJWT module (https://pypi.org/project/PyJWT/).
+import jwt
+from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
+
+logger = logging.getLogger(__name__)
+
+
+ISSUER = "iss"
+EXPIRE_TIME = "exp"
+ISSUE_TIME = "iat"
+SUBJECT = "sub"
+
+# If you generated an encrypted private key, implement this method to return
+# the passphrase for decrypting your private key. As an example, this function
+# prompts the user for the passphrase.
+
+
+class JWTGenerator:
+ """
+ Creates and signs a JWT with the specified private key file, username, and
account identifier.
+ The JWTGenerator keeps the generated token and only regenerates the token
if a specified period
+ of time has passed.
+
+ Creates an object that generates JWTs for the specified user, account
identifier, and private key
+
+ :param account: Your Snowflake account identifier.
+ See
https://docs.snowflake.com/en/user-guide/admin-account-identifier.html. Note
that if you are using
+ the account locator, exclude any region information from the account
locator.
+ :param user: The Snowflake username.
+ :param private_key: Private key from the file path for signing the JWTs.
+ :param lifetime: The number of minutes (as a timedelta) during which the
key will be valid.
+ :param renewal_delay: The number of minutes (as a timedelta) from now
after which the JWT
+ generator should renew the JWT.
+ """
+
+ LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute
lifetime
+ RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54
minutes
+ ALGORITHM = "RS256" # Tokens will be generated using RSA with SHA256
+
+ def __init__(
+ self,
+ account: str,
+ user: str,
+ private_key: Any,
+ lifetime: timedelta = LIFETIME,
+ renewal_delay: timedelta = RENEWAL_DELTA,
+ ):
+ logger.info(
+ """Creating JWTGenerator with arguments
+ account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
+ account,
+ user,
+ lifetime,
+ renewal_delay,
+ )
+
+ # Construct the fully qualified name of the user in uppercase.
+ self.account = self.prepare_account_name_for_jwt(account)
+ self.user = user.upper()
+ self.qualified_username = self.account + "." + self.user
+
+ self.lifetime = lifetime
+ self.renewal_delay = renewal_delay
+ self.private_key = private_key
+ self.renew_time = datetime.now(timezone.utc)
+ self.token: str | None = None
+
+ def prepare_account_name_for_jwt(self, raw_account: str) -> str:
+ """
+ Prepare the account identifier for use in the JWT.
+ For the JWT, the account identifier must not include the subdomain or
any region or cloud provider
+ information.
+
+ :param raw_account: The specified account identifier.
+ """
+ account = raw_account
+ if ".global" not in account:
+ # Handle the general case.
+ idx = account.find(".")
+ if idx > 0:
+ account = account[0:idx]
+ else:
+ # Handle the replication case.
+ idx = account.find("-")
+ if idx > 0:
+ account = account[0:idx] # pragma: no cover
+ # Use uppercase for the account identifier.
+ return account.upper()
+
+ def get_token(self) -> str | None:
+ """
+ Generates a new JWT. If a JWT has already been generated earlier,
return the previously
+ generated token unless the specified renewal time has passed.
+ """
+ now = datetime.now(timezone.utc) # Fetch the current time
+
+ # If the token has expired or doesn't exist, regenerate the token.
+ if self.token is None or self.renew_time <= now:
+ logger.info(
+ "Generating a new token because the present time (%s) is later
than the renewal time (%s)",
+ now,
+ self.renew_time,
+ )
+ # Calculate the next time we need to renew the token.
+ self.renew_time = now + self.renewal_delay
+
+ # Prepare the fields for the payload.
+ # Generate the public key fingerprint for the issuer in the
payload.
+ public_key_fp =
self.calculate_public_key_fingerprint(self.private_key)
+
+ # Create our payload
+ payload = {
+ # Set the issuer to the fully qualified username concatenated
with the public key fingerprint.
+ ISSUER: self.qualified_username + "." + public_key_fp,
+ # Set the subject to the fully qualified username.
+ SUBJECT: self.qualified_username,
+ # Set the issue time to now.
+ ISSUE_TIME: now,
+ # Set the expiration time, based on the lifetime specified for
this object.
+ EXPIRE_TIME: now + self.lifetime,
+ }
+
+ # Regenerate the actual token
+ token = jwt.encode(payload, key=self.private_key,
algorithm=JWTGenerator.ALGORITHM)
+
+ if isinstance(token, bytes):
+ token = token.decode("utf-8")
+ self.token = token
+
+ return self.token
+
+ def calculate_public_key_fingerprint(self, private_key: Any) -> str:
+ """
+ Given a private key in PEM format, return the public key fingerprint.
+
+ :param private_key: private key
+ """
+ # Get the raw bytes of public key.
+ public_key_raw = private_key.public_key().public_bytes(
+ Encoding.DER, PublicFormat.SubjectPublicKeyInfo
+ )
+
+ # Get the sha256 hash of the raw bytes.
+ sha256hash = hashlib.sha256()
+ sha256hash.update(public_key_raw)
+
+ # Base64-encode the value and prepend the prefix 'SHA256:'.
+ public_key_fp = "SHA256:" +
base64.b64encode(sha256hash.digest()).decode("utf-8")
+
+ return public_key_fp
diff --git a/airflow/providers/snowflake/operators/snowflake.py
b/airflow/providers/snowflake/operators/snowflake.py
index f09039baaf..2257ec759e 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -17,9 +17,12 @@
# under the License.
from __future__ import annotations
+import time
import warnings
-from typing import Any, Iterable, Mapping, Sequence, SupportsAbs
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, SupportsAbs
+from airflow import AirflowException
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.sql.operators.sql import (
SQLCheckOperator,
@@ -27,6 +30,12 @@ from airflow.providers.common.sql.operators.sql import (
SQLIntervalCheckOperator,
SQLValueCheckOperator,
)
+from airflow.providers.snowflake.hooks.snowflake_sql_api import (
+ SnowflakeSqlApiHook,
+)
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
class SnowflakeOperator(SQLExecuteQueryOperator):
@@ -359,3 +368,159 @@ class
SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
self.authenticator = authenticator
self.session_parameters = session_parameters
self.query_ids: list[str] = []
+
+
+class SnowflakeSqlApiOperator(SnowflakeOperator):
+ """
+ Implemented Snowflake SQL API Operator to support multiple SQL statements
sequentially,
+ which is the behavior of the SnowflakeOperator, the Snowflake SQL API
allows submitting
+ multiple SQL statements in a single request. It makes a POST request to
submit SQL
+ statements for execution, poll to check the status of the execution of a
statement. Fetch query results
+ concurrently.
+ This Operator currently uses key pair authentication, so you need to
provide private key raw content or
+ private key file path in the snowflake connection along with other details
+
+ .. seealso::
+
+ `Snowflake SQL API key pair Authentication
<https://docs.snowflake.com/en/developer-guide/sql-api/authenticating.html#label-sql-api-authenticating-key-pair>`_
+
+ Where can this operator fit in?
+ - To execute multiple SQL statements in a single request
+ - To execute the SQL statement asynchronously and to execute standard
queries and most DDL and DML statements
+ - To develop custom applications and integrations that perform queries
+ - To create provision users and roles, create table, etc.
+
+ The following commands are not supported:
+ - The PUT command (in Snowflake SQL)
+ - The GET command (in Snowflake SQL)
+ - The CALL command with stored procedures that return a table(stored
procedures with the RETURNS TABLE clause).
+
+ .. seealso::
+
+ - `Snowflake SQL API
<https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#introduction-to-the-sql-api>`_
+ - `API Reference
<https://docs.snowflake.com/en/developer-guide/sql-api/reference.html#snowflake-sql-api-reference>`_
+ - `Limitation on snowflake SQL API
<https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#limitations-of-the-sql-api>`_
+
+ :param snowflake_conn_id: Reference to Snowflake connection id
+ :param sql: the sql code to be executed. (templated)
+ :param autocommit: if True, each command is automatically committed.
+ (default value: True)
+ :param parameters: (optional) the parameters to render the SQL query with.
+ :param warehouse: name of warehouse (will overwrite any warehouse
+ defined in the connection's extra JSON)
+ :param database: name of database (will overwrite database defined
+ in connection)
+ :param schema: name of schema (will overwrite schema defined in
+ connection)
+ :param role: name of role (will overwrite any role defined in
+ connection's extra JSON)
+ :param authenticator: authenticator for Snowflake.
+ 'snowflake' (default) to use the internal Snowflake authenticator
+ 'externalbrowser' to authenticate using your web browser and
+ Okta, ADFS or any other SAML 2.0-compliant identity provider
+ (IdP) that has been defined for your account
+ 'https://<your_okta_account_name>.okta.com' to authenticate
+ through native Okta.
+ :param session_parameters: You can set session-level parameters at
+ the time you connect to Snowflake
+ :param poll_interval: the interval in seconds to poll the query
+ :param statement_count: Number of SQL statement to be executed
+ :param token_life_time: lifetime of the JWT Token
+ :param token_renewal_delta: Renewal time of the JWT Token
+ :param bindings: (Optional) Values of bind variables in the SQL statement.
+ When executing the statement, Snowflake replaces placeholders (?
and :name) in
+ the statement with these specified values.
+ """ # noqa
+
+ LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes
lifetime
+ RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54
minutes
+
+ def __init__(
+ self,
+ *,
+ snowflake_conn_id: str = "snowflake_default",
+ warehouse: str | None = None,
+ database: str | None = None,
+ role: str | None = None,
+ schema: str | None = None,
+ authenticator: str | None = None,
+ session_parameters: dict[str, Any] | None = None,
+ poll_interval: int = 5,
+ statement_count: int = 0,
+ token_life_time: timedelta = LIFETIME,
+ token_renewal_delta: timedelta = RENEWAL_DELTA,
+ bindings: dict[str, Any] | None = None,
+ **kwargs: Any,
+ ) -> None:
+ self.snowflake_conn_id = snowflake_conn_id
+ self.poll_interval = poll_interval
+ self.statement_count = statement_count
+ self.token_life_time = token_life_time
+ self.token_renewal_delta = token_renewal_delta
+ self.bindings = bindings
+ self.execute_async = False
+ if self.__class__.__base__.__name__ != "SnowflakeOperator":
+ # It's better to do str check of the parent class name because
currently SnowflakeOperator
+ # is deprecated and in future OSS SnowflakeOperator may be removed
+ if any(
+ [warehouse, database, role, schema, authenticator,
session_parameters]
+ ): # pragma: no cover
+ hook_params = kwargs.pop("hook_params", {}) # pragma: no cover
+ kwargs["hook_params"] = {
+ "warehouse": warehouse,
+ "database": database,
+ "role": role,
+ "schema": schema,
+ "authenticator": authenticator,
+ "session_parameters": session_parameters,
+ **hook_params,
+ }
+ super().__init__(conn_id=snowflake_conn_id, **kwargs) # pragma:
no cover
+ else:
+ super().__init__(**kwargs)
+
+ def execute(self, context: Context) -> None:
+ """
+ Make a POST API request to snowflake by using SnowflakeSQL and execute
the query to get the ids.
+ By deferring the SnowflakeSqlApiTrigger class passed along with query
ids.
+ """
+ self.log.info("Executing: %s", self.sql)
+ self._hook = SnowflakeSqlApiHook(
+ snowflake_conn_id=self.snowflake_conn_id,
+ token_life_time=self.token_life_time,
+ token_renewal_delta=self.token_renewal_delta,
+ )
+ self.query_ids = self._hook.execute_query(
+ self.sql, statement_count=self.statement_count,
bindings=self.bindings # type: ignore[arg-type]
+ )
+ self.log.info("List of query ids %s", self.query_ids)
+
+ if self.do_xcom_push:
+ context["ti"].xcom_push(key="query_ids", value=self.query_ids)
+
+ statement_status = self.poll_on_queries()
+ if statement_status["error"]:
+ raise AirflowException(statement_status["error"])
+ self._hook.check_query_output(self.query_ids)
+
+ def poll_on_queries(self):
+ """Poll on requested queries"""
+ queries_in_progress = set(self.query_ids)
+ statement_success_status = {}
+ statement_error_status = {}
+ for query_id in self.query_ids:
+ if not len(queries_in_progress):
+ break
+ self.log.info("checking : %s", query_id)
+ try:
+ statement_status =
self._hook.get_sql_api_query_status(query_id)
+ except Exception as e:
+ raise ValueError({"status": "error", "message": str(e)})
+ if statement_status.get("status") == "error":
+ queries_in_progress.remove(query_id)
+ statement_error_status[query_id] = statement_status
+ if statement_status.get("status") == "success":
+ statement_success_status[query_id] = statement_status
+ queries_in_progress.remove(query_id)
+ time.sleep(self.poll_interval)
+ return {"success": statement_success_status, "error":
statement_error_status}
diff --git a/airflow/providers/snowflake/provider.yaml
b/airflow/providers/snowflake/provider.yaml
index 62af89dfc3..5e4d04a3f5 100644
--- a/airflow/providers/snowflake/provider.yaml
+++ b/airflow/providers/snowflake/provider.yaml
@@ -75,6 +75,8 @@ hooks:
- integration-name: Snowflake
python-modules:
- airflow.providers.snowflake.hooks.snowflake
+ - airflow.providers.snowflake.hooks.snowflake_sql_api
+ - airflow.providers.snowflake.hooks.sql_api_generate_jwt
transfers:
- source-integration-name: Amazon Simple Storage Service (S3)
diff --git a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
index f823430fa8..1e80f3af29 100644
--- a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
+++ b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
@@ -58,3 +58,46 @@ An example usage of the SnowflakeOperator is as follows:
Parameters that can be passed onto the operator will be given priority over
the parameters already given
in the Airflow connection metadata (such as ``schema``, ``role``,
``database`` and so forth).
+
+
+SnowflakeSqlApiOperator
+=======================
+
+Use the :class:`SnowflakeSqlApiOperator
<airflow.providers.snowflake.operators.snowflake>` to execute
+SQL commands in a `Snowflake <https://docs.snowflake.com/en/>`__ database.
+
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+Use the ``snowflake_conn_id`` argument to connect to your Snowflake instance
where
+the connection metadata is structured as follows:
+
+.. list-table:: Snowflake Airflow Connection Metadata
+ :widths: 25 25
+ :header-rows: 1
+
+ * - Parameter
+ - Input
+ * - Login: string
+ - Snowflake user name
+ * - Password: string
+ - Password for Snowflake user
+ * - Schema: string
+ - Set schema to execute SQL operations on by default
+ * - Extra: dictionary
+ - ``warehouse``, ``account``, ``database``, ``region``, ``role``,
``authenticator``
+
+An example usage of the SnowflakeSqlApiOperator is as follows:
+
+.. exampleinclude::
/../../tests/system/providers/snowflake/example_snowflake.py
+ :language: python
+ :start-after: [START howto_snowflake_sql_api_operator]
+ :end-before: [END howto_snowflake_sql_api_operator]
+ :dedent: 4
+
+
+.. note::
+
+ Parameters that can be passed onto the operator will be given priority over
the parameters already given
+ in the Airflow connection metadata (such as ``schema``, ``role``,
``database`` and so forth).
diff --git a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
new file mode 100644
index 0000000000..8f68f446c8
--- /dev/null
+++ b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
@@ -0,0 +1,458 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import unittest
+import uuid
+from pathlib import Path
+from typing import Any
+from unittest import mock
+
+import pytest
+import requests
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+
+from airflow import AirflowException
+from airflow.models import Connection
+from airflow.providers.snowflake.hooks.snowflake_sql_api import (
+ SnowflakeSqlApiHook,
+)
+
+SQL_MULTIPLE_STMTS = (
+ "create or replace table user_test (i int); insert into user_test (i) "
+ "values (200); insert into user_test (i) values (300); select i from
user_test order by i;"
+)
+_PASSWORD = "snowflake42"
+
+SINGLE_STMT = "select i from user_test order by i;"
+BASE_CONNECTION_KWARGS: dict = {
+ "login": "user",
+ "conn_type": "snowflake",
+ "password": "pw",
+ "schema": "public",
+ "extra": {
+ "database": "db",
+ "account": "airflow",
+ "warehouse": "af_wh",
+ "region": "af_region",
+ "role": "af_role",
+ },
+}
+CONN_PARAMS = {
+ "account": "airflow",
+ "application": "AIRFLOW",
+ "authenticator": "snowflake",
+ "database": "db",
+ "password": "pw",
+ "region": "af_region",
+ "role": "af_role",
+ "schema": "public",
+ "session_parameters": None,
+ "user": "user",
+ "warehouse": "af_wh",
+}
+HEADERS = {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer newT0k3n",
+ "Accept": "application/json",
+ "User-Agent": "snowflakeSQLAPI/1.0",
+ "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
+}
+
+GET_RESPONSE = {
+ "resultSetMetaData": {
+ "numRows": 10000,
+ "format": "jsonv2",
+ "rowType": [
+ {
+ "name": "SEQ8()",
+ "database": "",
+ "schema": "",
+ "table": "",
+ "scale": 0,
+ "precision": 19,
+ },
+ {
+ "name": "RANDSTR(1000, RANDOM())",
+ "database": "",
+ "schema": "",
+ "table": "",
+ },
+ ],
+ "partitionInfo": [
+ {
+ "rowCount": 12344,
+ "uncompressedSize": 14384873,
+ },
+ {"rowCount": 43746, "uncompressedSize": 43748274,
"compressedSize": 746323},
+ ],
+ },
+ "code": "090001",
+ "statementStatusUrl":
"/api/v2/statements/{handle}?requestId={id5}&partition=10",
+ "sqlState": "00000",
+ "statementHandle": "{handle}",
+ "message": "Statement executed successfully.",
+ "createdOn": 1620151693299,
+}
+
+
+def create_successful_response_mock(content):
+ """Create mock response for success state"""
+ response = mock.MagicMock()
+ response.json.return_value = content
+ response.status_code = 200
+ return response
+
+
+def create_post_side_effect(status_code=429):
+ """create mock response for post side effect"""
+ response = mock.MagicMock()
+ response.status_code = status_code
+ response.reason = "test"
+ response.raise_for_status.side_effect =
requests.exceptions.HTTPError(response=response)
+ return response
+
+
class TestSnowflakeSqlApiHook:
    """Unit tests for ``SnowflakeSqlApiHook``.

    All HTTP traffic is mocked (either the ``requests`` module or individual
    hook methods are patched), so no test contacts a real Snowflake account.

    NOTE: stacked ``mock.patch`` decorators are applied bottom-up, so each test
    method receives its mock arguments in the reverse order of the decorators.
    """

    @pytest.mark.parametrize(
        "sql,statement_count,expected_response, expected_query_ids",
        [
            (SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"]),
            (SQL_MULTIPLE_STMTS, 4, {"statementHandles": ["uuid", "uuid1"]}, ["uuid", "uuid1"]),
        ],
    )
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
    def test_execute_query(
        self,
        mock_get_header,
        mock_conn_param,
        mock_requests,
        sql,
        statement_count,
        expected_response,
        expected_query_ids,
    ):
        """Test execute_query method, run query by mocking post request method and return the query ids"""
        # Simulate a successful POST: the response body carries either a single
        # "statementHandle" or a "statementHandles" list, both of which the hook
        # must normalize into a flat list of query ids.
        mock_requests.codes.ok = 200
        mock_requests.post.side_effect = [
            create_successful_response_mock(expected_response),
        ]
        status_code_mock = mock.PropertyMock(return_value=200)
        type(mock_requests.post.return_value).status_code = status_code_mock

        hook = SnowflakeSqlApiHook("mock_conn_id")
        query_ids = hook.execute_query(sql, statement_count)
        assert query_ids == expected_query_ids

    @pytest.mark.parametrize(
        "sql,statement_count,expected_response, expected_query_ids",
        [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])],
    )
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
    def test_execute_query_exception_without_statement_handel(
        self,
        mock_get_header,
        mock_conn_param,
        mock_requests,
        sql,
        statement_count,
        expected_response,
        expected_query_ids,
    ):
        """
        Test execute_query method by mocking the exception response and raise airflow exception
        without statementHandle in the response
        """
        # NOTE(review): "handel" in the test name is a typo for "handle";
        # kept as-is since renaming a test is cosmetic churn.
        side_effect = create_post_side_effect()
        mock_requests.post.side_effect = side_effect
        hook = SnowflakeSqlApiHook("mock_conn_id")

        with pytest.raises(AirflowException) as exception_info:
            hook.execute_query(sql, statement_count)
        # NOTE(review): trivially true — pytest.raises already guarantees the
        # exception was raised; the assert adds no extra coverage.
        assert exception_info

    @pytest.mark.parametrize(
        "query_ids",
        [
            (["uuid", "uuid1"]),
        ],
    )
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
    @mock.patch(
        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook."
        "get_request_url_header_params"
    )
    def test_check_query_output(self, mock_geturl_header_params, mock_requests, query_ids):
        """Test check_query_output by passing query ids as params and mock get_request_url_header_params"""
        req_id = uuid.uuid4()
        params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
        mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/"
        mock_requests.get.return_value.json.return_value = GET_RESPONSE
        hook = SnowflakeSqlApiHook("mock_conn_id")
        # The hook logs each statement's response; asserting on the log call is
        # how this test observes that the GET response was consumed.
        with mock.patch.object(hook.log, "info") as mock_log_info:
            hook.check_query_output(query_ids)
        mock_log_info.assert_called_with(GET_RESPONSE)

    @pytest.mark.parametrize(
        "query_ids",
        [
            (["uuid", "uuid1"]),
        ],
    )
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests.get")
    @mock.patch(
        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook."
        "get_request_url_header_params"
    )
    def test_check_query_output_exception(self, mock_geturl_header_params, mock_requests, query_ids):
        """
        Test check_query_output by passing query ids as params and mock get_request_url_header_params
        to raise airflow exception and mock with http error
        """
        req_id = uuid.uuid4()
        params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
        mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/"
        # A real (unfilled) Response with a 404 status makes raise_for_status
        # raise inside the hook, which the hook converts to AirflowException.
        mock_resp = requests.models.Response()
        mock_resp.status_code = 404
        mock_requests.return_value = mock_resp
        hook = SnowflakeSqlApiHook("mock_conn_id")
        with mock.patch.object(hook.log, "error"):
            with pytest.raises(AirflowException) as airflow_exception:
                hook.check_query_output(query_ids)
            assert airflow_exception

    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
    def test_get_request_url_header_params(self, mock_get_header, mock_conn_param):
        """Test get_request_url_header_params by mocking _get_conn_params and get_headers"""
        mock_conn_param.return_value = CONN_PARAMS
        mock_get_header.return_value = HEADERS
        hook = SnowflakeSqlApiHook("mock_conn_id")
        # NOTE(review): `params` is unpacked but never asserted — consider also
        # checking the requestId query parameter here.
        header, params, url = hook.get_request_url_header_params("uuid")
        assert header == HEADERS
        assert url == "https://airflow.af_region.snowflakecomputing.com/api/v2/statements/uuid"

    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_private_key")
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
    @mock.patch("airflow.providers.snowflake.hooks.sql_api_generate_jwt.JWTGenerator.get_token")
    def test_get_headers(self, mock_get_token, mock_conn_param, mock_private_key):
        """Test get_headers method by mocking get_private_key and _get_conn_params method"""
        # The mocked token must match the Authorization value baked into HEADERS.
        mock_get_token.return_value = "newT0k3n"
        mock_conn_param.return_value = CONN_PARAMS
        hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id")
        result = hook.get_headers()
        assert result == HEADERS

    @pytest.fixture()
    def non_encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
        """Write an unencrypted PKCS#8 PEM private key to a temp file and return its path."""
        key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
        private_key = key.private_bytes(
            serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption()
        )
        test_key_file = tmp_path / "test_key.pem"
        test_key_file.write_bytes(private_key)
        return test_key_file

    @pytest.fixture()
    def encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
        """Write a password-encrypted PKCS#8 PEM private key (passphrase ``_PASSWORD``) to a temp file."""
        key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
        private_key = key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.BestAvailableEncryption(_PASSWORD.encode()),
        )
        test_key_file: Path = tmp_path / "test_key.p8"
        test_key_file.write_bytes(private_key)
        return test_key_file

    def test_get_private_key_should_support_private_auth_in_connection(
        self, encrypted_temporary_private_key: Path
    ):
        """Test get_private_key function with private_key_content in connection"""
        connection_kwargs: Any = {
            **BASE_CONNECTION_KWARGS,
            "password": _PASSWORD,
            "extra": {
                "database": "db",
                "account": "airflow",
                "warehouse": "af_wh",
                "region": "af_region",
                "role": "af_role",
                "private_key_content": str(encrypted_temporary_private_key.read_text()),
            },
        }
        # The connection is injected via an env var so no metadata DB is needed.
        with unittest.mock.patch.dict(
            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
        ):
            hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
            hook.get_private_key()
            assert hook.private_key is not None

    def test_get_private_key_raise_exception(self, encrypted_temporary_private_key: Path):
        """
        Test get_private_key function with private_key_content and private_key_file in connection
        and raise airflow exception
        """
        connection_kwargs: Any = {
            **BASE_CONNECTION_KWARGS,
            "password": _PASSWORD,
            "extra": {
                "database": "db",
                "account": "airflow",
                "warehouse": "af_wh",
                "region": "af_region",
                "role": "af_role",
                # Supplying BOTH key sources is the invalid configuration under test.
                "private_key_content": str(encrypted_temporary_private_key.read_text()),
                "private_key_file": str(encrypted_temporary_private_key),
            },
        }
        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
        with unittest.mock.patch.dict(
            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
        ), pytest.raises(
            AirflowException,
            match="The private_key_file and private_key_content extra fields are mutually "
            "exclusive. Please remove one.",
        ):
            hook.get_private_key()

    def test_get_private_key_should_support_private_auth_with_encrypted_key(
        self, encrypted_temporary_private_key
    ):
        """Test get_private_key method by supporting for private auth encrypted_key"""
        connection_kwargs = {
            **BASE_CONNECTION_KWARGS,
            # The connection password doubles as the key's decryption passphrase.
            "password": _PASSWORD,
            "extra": {
                "database": "db",
                "account": "airflow",
                "warehouse": "af_wh",
                "region": "af_region",
                "role": "af_role",
                "private_key_file": str(encrypted_temporary_private_key),
            },
        }
        with unittest.mock.patch.dict(
            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
        ):
            hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
            hook.get_private_key()
            assert hook.private_key is not None

    def test_get_private_key_should_support_private_auth_with_unencrypted_key(
        self,
        non_encrypted_temporary_private_key,
    ):
        # Unencrypted keys must load with password None or "" but must be
        # rejected (TypeError) when a password is supplied.
        connection_kwargs = {
            **BASE_CONNECTION_KWARGS,
            "password": None,
            "extra": {
                "database": "db",
                "account": "airflow",
                "warehouse": "af_wh",
                "region": "af_region",
                "role": "af_role",
                "private_key_file": str(non_encrypted_temporary_private_key),
            },
        }
        with unittest.mock.patch.dict(
            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
        ):
            hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
            hook.get_private_key()
            assert hook.private_key is not None
        connection_kwargs["password"] = ""
        with unittest.mock.patch.dict(
            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
        ):
            hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
            hook.get_private_key()
            assert hook.private_key is not None
        connection_kwargs["password"] = _PASSWORD
        with unittest.mock.patch.dict(
            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
        ), pytest.raises(TypeError, match="Password was given but private key is not encrypted."):
            SnowflakeSqlApiHook(snowflake_conn_id="test_conn").get_private_key()

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "status_code,response,expected_response",
        [
            (
                200,
                {
                    "status": "success",
                    "message": "Statement executed successfully.",
                    "statementHandle": "uuid",
                },
                {
                    "status": "success",
                    "message": "Statement executed successfully.",
                    "statement_handles": ["uuid"],
                },
            ),
            (
                200,
                {
                    "status": "success",
                    "message": "Statement executed successfully.",
                    "statementHandles": ["uuid", "uuid1"],
                },
                {
                    "status": "success",
                    "message": "Statement executed successfully.",
                    "statement_handles": ["uuid", "uuid1"],
                },
            ),
            (202, {}, {"status": "running", "message": "Query statements are still running"}),
            (422, {"status": "error", "message": "test"}, {"status": "error", "message": "test"}),
            (404, {"status": "error", "message": "test"}, {"status": "error", "message": "test"}),
        ],
    )
    @mock.patch(
        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook."
        "get_request_url_header_params"
    )
    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
    def test_get_sql_api_query_status(
        self, mock_requests, mock_geturl_header_params, status_code, response, expected_response
    ):
        """Test get_sql_api_query_status function by mocking the status, response and expected
        response"""
        req_id = uuid.uuid4()
        params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
        mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/"

        # Minimal stand-in for a requests.Response: only status_code and json()
        # are read by get_sql_api_query_status.
        class MockResponse:
            def __init__(self, status_code, data):
                self.status_code = status_code
                self.data = data

            def json(self):
                return self.data

        mock_requests.get.return_value = MockResponse(status_code, response)
        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
        assert hook.get_sql_api_query_status("uuid") == expected_response
diff --git a/tests/providers/snowflake/hooks/test_sql_api_generate_jwt.py
b/tests/providers/snowflake/hooks/test_sql_api_generate_jwt.py
new file mode 100644
index 0000000000..5fa434d9b3
--- /dev/null
+++ b/tests/providers/snowflake/hooks/test_sql_api_generate_jwt.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+from cryptography.hazmat.backends import default_backend as
crypto_default_backend
+from cryptography.hazmat.primitives import serialization as
crypto_serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+
+from airflow.providers.snowflake.hooks.sql_api_generate_jwt import JWTGenerator
+
# Passphrase constant kept alongside the other Snowflake test modules; the key
# below is serialized with NoEncryption, so no passphrase is actually applied.
_PASSWORD = "snowflake42"

# One RSA key pair generated at import time and shared by every test in this module.
key = rsa.generate_private_key(backend=crypto_default_backend(), public_exponent=65537, key_size=2048)

# Unencrypted PKCS#8 PEM serialization of `key`, passed to JWTGenerator as bytes.
private_key = key.private_bytes(
    crypto_serialization.Encoding.PEM,
    crypto_serialization.PrivateFormat.PKCS8,
    crypto_serialization.NoEncryption(),
)
+
+
class TestJWTGenerator:
    """Unit tests for the JWT helper used by the Snowflake SQL API hook."""

    @pytest.mark.parametrize(
        "account_name, expected_account_name",
        [("test.us-east-1", "TEST"), ("test.global", "TEST.GLOBAL"), ("test", "TEST")],
    )
    def test_prepare_account_name_for_jwt(self, account_name, expected_account_name):
        """
        Test prepare_account_name_for_jwt by passing the account identifier and
        get the proper account name in caps
        """
        generator = JWTGenerator(account_name, "test_user", private_key)
        assert generator.prepare_account_name_for_jwt(account_name) == expected_account_name

    def test_calculate_public_key_fingerprint(self):
        """Asserting get_token and calculate_public_key_fingerprint by passing key and generating token"""
        generator = JWTGenerator("test.us-east-1", "test_user", key)
        token = generator.get_token()
        assert token
diff --git a/tests/providers/snowflake/operators/test_snowflake.py
b/tests/providers/snowflake/operators/test_snowflake.py
index dde2f9adde..8f32c6e62d 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -21,11 +21,13 @@ from unittest import mock
import pytest
+from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
from airflow.providers.snowflake.operators.snowflake import (
SnowflakeCheckOperator,
SnowflakeIntervalCheckOperator,
SnowflakeOperator,
+ SnowflakeSqlApiOperator,
SnowflakeValueCheckOperator,
)
from airflow.utils import timezone
@@ -35,6 +37,17 @@ DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
TEST_DAG_ID = "unit_test_dag"
+TASK_ID = "snowflake_check"
+CONN_ID = "my_snowflake_conn"
+TEST_SQL = "select * from any;"
+
+SQL_MULTIPLE_STMTS = (
+ "create or replace table user_test (i int); insert into user_test (i) "
+ "values (200); insert into user_test (i) values (300); select i from
user_test order by i;"
+)
+
+SINGLE_STMT = "select i from user_test order by i;"
+
class TestSnowflakeOperator:
def setup_method(self):
@@ -73,3 +86,59 @@ class TestSnowflakeCheckOperators:
operator = operator_class(task_id="snowflake_check",
snowflake_conn_id="snowflake_default", **kwargs)
operator.get_db_hook()
mock_get_db_hook.assert_called_once()
+
+
class TestSnowflakeSqlApiOperator:
    """Unit tests for SnowflakeSqlApiOperator; all hook calls are patched out."""

    @pytest.fixture
    def mock_execute_query(self):
        """Patch SnowflakeSqlApiHook.execute_query for the duration of a test."""
        target = "airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiHook.execute_query"
        with mock.patch(target) as patched:
            yield patched

    @pytest.fixture
    def mock_get_sql_api_query_status(self):
        """Patch SnowflakeSqlApiHook.get_sql_api_query_status for the duration of a test."""
        target = (
            "airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiHook.get_sql_api_query_status"
        )
        with mock.patch(target) as patched:
            yield patched

    @pytest.fixture
    def mock_check_query_output(self):
        """Patch SnowflakeSqlApiHook.check_query_output for the duration of a test."""
        target = "airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiHook.check_query_output"
        with mock.patch(target) as patched:
            yield patched

    def test_snowflake_sql_api_to_succeed_when_no_query_fails(
        self, mock_execute_query, mock_get_sql_api_query_status, mock_check_query_output
    ):
        """Tests SnowflakeSqlApiOperator passed if poll_on_queries method gives no error"""
        # Both polled statuses report success, so execute() must return normally.
        mock_execute_query.return_value = ["uuid1", "uuid2"]
        mock_get_sql_api_query_status.side_effect = [{"status": "success"}, {"status": "success"}]
        task = SnowflakeSqlApiOperator(
            task_id=TASK_ID,
            snowflake_conn_id="snowflake_default",
            sql=SQL_MULTIPLE_STMTS,
            statement_count=4,
            do_xcom_push=False,
        )
        task.execute(context=None)

    def test_snowflake_sql_api_to_fails_when_one_query_fails(
        self, mock_execute_query, mock_get_sql_api_query_status
    ):
        """Tests SnowflakeSqlApiOperator passed if poll_on_queries method gives one or more error"""
        # The first polled status is an error, so execute() must raise.
        mock_execute_query.return_value = ["uuid1", "uuid2"]
        mock_get_sql_api_query_status.side_effect = [{"status": "error"}, {"status": "success"}]
        task = SnowflakeSqlApiOperator(
            task_id=TASK_ID,
            snowflake_conn_id="snowflake_default",
            sql=SQL_MULTIPLE_STMTS,
            statement_count=4,
            do_xcom_push=False,
        )
        with pytest.raises(AirflowException):
            task.execute(context=None)
diff --git a/tests/system/providers/snowflake/example_snowflake.py
b/tests/system/providers/snowflake/example_snowflake.py
index ac174f7b59..007dcdb4ab 100644
--- a/tests/system/providers/snowflake/example_snowflake.py
+++ b/tests/system/providers/snowflake/example_snowflake.py
@@ -24,7 +24,7 @@ import os
from datetime import datetime
from airflow import DAG
-from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
+from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator,
SnowflakeSqlApiOperator
SNOWFLAKE_CONN_ID = "my_snowflake_conn"
SNOWFLAKE_SAMPLE_TABLE = "sample_table"
@@ -72,6 +72,14 @@ with DAG(
# [END howto_operator_snowflake]
+ # [START howto_snowflake_sql_api_operator]
+ snowflake_sql_api_op_sql_multiple_stmt = SnowflakeSqlApiOperator(
+ task_id="snowflake_op_sql_multiple_stmt",
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=len(SQL_LIST),
+ )
+ # [END howto_snowflake_sql_api_operator]
+
(
snowflake_op_sql_str
>> [
@@ -79,6 +87,7 @@ with DAG(
snowflake_op_sql_list,
snowflake_op_template_file,
snowflake_op_sql_multiple_stmts,
+ snowflake_sql_api_op_sql_multiple_stmt,
]
)