This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 739e6b5d77 Add SnowflakeSqlApiOperator operator (#30698)
739e6b5d77 is described below

commit 739e6b5d775412f987a3ff5fb71c51fbb7051a89
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Wed May 24 12:40:03 2023 +0530

    Add SnowflakeSqlApiOperator operator (#30698)
    
    * Add SnowflakeSqlApiOperatorAsyn
    
    * Remove unrelated change
    
    * Saving work
    
    * Add SnowflakeSqlApiOperator operator
    
    * Remove snowflake_trigger
    
    * Remove snowflake_trigger
    
    * Remove unwanted code
    
    * Remove unwanted code
    
    * Remove unwanted code
    
    * Amend test cases
    
    * Remove unwanted code
    
    * Fix static checks
    
    * Fix docs
    
    * Update airflow/providers/snowflake/hooks/snowflake_sql_api.py
    
    * Remove unwanted code
    
    * Remove unwanted code
    
    * Fix static check
    
    * Add docs and example
    
    * Fix static check
---
 .../providers/snowflake/hooks/snowflake_sql_api.py | 256 ++++++++++++
 .../snowflake/hooks/sql_api_generate_jwt.py        | 174 ++++++++
 airflow/providers/snowflake/operators/snowflake.py | 167 +++++++-
 airflow/providers/snowflake/provider.yaml          |   2 +
 .../operators/snowflake.rst                        |  43 ++
 .../snowflake/hooks/test_snowflake_sql_api.py      | 458 +++++++++++++++++++++
 .../snowflake/hooks/test_sql_api_generate_jwt.py   |  54 +++
 .../snowflake/operators/test_snowflake.py          |  69 ++++
 .../providers/snowflake/example_snowflake.py       |  11 +-
 9 files changed, 1232 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py 
b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
new file mode 100644
index 0000000000..e77b7607b7
--- /dev/null
+++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -0,0 +1,256 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import uuid
+from datetime import timedelta
+from pathlib import Path
+from typing import Any
+
+import requests
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+
+from airflow import AirflowException
+from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
+from airflow.providers.snowflake.hooks.sql_api_generate_jwt import JWTGenerator
+
+
class SnowflakeSqlApiHook(SnowflakeHook):
    """
    A client to interact with Snowflake using the SQL API, which allows submitting
    multiple SQL statements in a single request.

    SQL statements are submitted for asynchronous execution with an HTTP POST request
    (made with ``requests``); the status and output of each submitted statement can then
    be retrieved with HTTP GET requests.

    This hook requires the snowflake_conn_id connection. It mainly uses the account,
    region, schema, database, warehouse and role of the connection, and either the
    ``private_key_file`` or the ``private_key_content`` extra field must be set up,
    because SQL API requests are authenticated with a key pair JWT. Other inputs can be
    defined in the connection or at hook instantiation.

    :param snowflake_conn_id: Reference to
        :ref:`Snowflake connection id<howto/connection:snowflake>`
    :param account: snowflake account name
    :param authenticator: authenticator for Snowflake.
        'snowflake' (default) to use the internal Snowflake authenticator
        'externalbrowser' to authenticate using your web browser and
        Okta, ADFS or any other SAML 2.0-compliant identity provider
        (IdP) that has been defined for your account
        'https://<your_okta_account_name>.okta.com' to authenticate
        through native Okta.
        Note that the SQL API requests made by this hook always use key pair JWT
        authentication.
    :param warehouse: name of snowflake warehouse
    :param database: name of snowflake database
    :param region: name of snowflake region
    :param role: name of snowflake role
    :param schema: name of snowflake schema
    :param session_parameters: You can set session-level parameters at
        the time you connect to Snowflake
    :param token_life_time: lifetime of the JWT Token in timedelta
    :param token_renewal_delta: Renewal time of the JWT Token in timedelta
    """

    LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minute lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes

    def __init__(
        self,
        snowflake_conn_id: str,
        token_life_time: timedelta = LIFETIME,
        token_renewal_delta: timedelta = RENEWAL_DELTA,
        *args: Any,
        **kwargs: Any,
    ):
        self.snowflake_conn_id = snowflake_conn_id
        self.token_life_time = token_life_time
        self.token_renewal_delta = token_renewal_delta
        super().__init__(snowflake_conn_id, *args, **kwargs)
        # Loaded lazily by get_private_key() the first time request headers are built.
        self.private_key: Any = None

    @staticmethod
    def _statements_url(conn_config: dict[str, Any]) -> str:
        """
        Build the base URL of the SQL API ``statements`` endpoint.

        The region is only inserted into the host name when one is configured;
        interpolating an empty/missing region unconditionally would produce an invalid
        host such as ``https://<account>..snowflakecomputing.com``.

        :param conn_config: connection parameters as returned by ``_get_conn_params``.
        """
        host = conn_config["account"]
        region = conn_config.get("region")
        if region:
            host = f"{host}.{region}"
        return f"https://{host}.snowflakecomputing.com/api/v2/statements"

    def get_private_key(self) -> None:
        """
        Load the private key configured in the Snowflake connection into ``self.private_key``.

        The key is read either from the ``private_key_file`` extra (path to a PEM file) or
        from the ``private_key_content`` extra (PEM text); both the
        ``extra__snowflake__``-prefixed and the plain field names are accepted. When the
        connection password is set it is used as the passphrase of an encrypted key; leave
        it empty for an unencrypted key (not recommended). If neither extra is set,
        ``self.private_key`` is left unchanged.

        :raises AirflowException: if both ``private_key_file`` and ``private_key_content`` are set.
        """
        conn = self.get_connection(self.snowflake_conn_id)

        private_key_file = conn.extra_dejson.get(
            "extra__snowflake__private_key_file"
        ) or conn.extra_dejson.get("private_key_file")
        private_key_content = conn.extra_dejson.get(
            "extra__snowflake__private_key_content"
        ) or conn.extra_dejson.get("private_key_content")

        private_key_pem = None
        if private_key_content and private_key_file:
            raise AirflowException(
                "The private_key_file and private_key_content extra fields are mutually exclusive. "
                "Please remove one."
            )
        elif private_key_file:
            private_key_pem = Path(private_key_file).read_bytes()
        elif private_key_content:
            private_key_pem = private_key_content.encode()

        if private_key_pem:
            # The connection password doubles as the passphrase of an encrypted key.
            passphrase = None
            if conn.password:
                passphrase = conn.password.strip().encode()

            self.private_key = serialization.load_pem_private_key(
                private_key_pem, password=passphrase, backend=default_backend()
            )

    def execute_query(
        self, sql: str, statement_count: int, query_tag: str = "", bindings: dict[str, Any] | None = None
    ) -> list[str]:
        """
        Submit ``sql`` to the Snowflake SQL API for asynchronous execution.

        :param sql: the sql string to be executed with possibly multiple statements
        :param statement_count: set the MULTI_STATEMENT_COUNT field to the number of SQL statements
            in the request
        :param query_tag: (Optional) Query tag that you want to associate with the SQL statement.
            For details, see https://docs.snowflake.com/en/sql-reference/parameters.html#label-query-tag
            parameter.
        :param bindings: (Optional) Values of bind variables in the SQL statement.
            When executing the statement, Snowflake replaces placeholders (? and :name) in
            the statement with these specified values.
        :return: the statement handles (query ids) returned by the API.
        :raises AirflowException: if the request fails or the response contains no statement handle.
        """
        conn_config = self._get_conn_params()

        req_id = uuid.uuid4()
        url = self._statements_url(conn_config)
        params: dict[str, Any] = {"requestId": str(req_id), "async": True, "pageSize": 10}
        headers = self.get_headers()
        if bindings is None:
            bindings = {}
        data = {
            "statement": sql,
            "resultSetMetaData": {"format": "json"},
            "database": conn_config["database"],
            "schema": conn_config["schema"],
            "warehouse": conn_config["warehouse"],
            "role": conn_config["role"],
            "bindings": bindings,
            "parameters": {
                "MULTI_STATEMENT_COUNT": statement_count,
                "query_tag": query_tag,
            },
        }
        response = requests.post(url, json=data, headers=headers, params=params)
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:  # pragma: no cover
            raise AirflowException(
                f"Response: {e.response.content} Status Code: {e.response.status_code}"
            ) from e
        json_response = response.json()
        self.log.info("Snowflake SQL POST API response: %s", json_response)
        # NOTE(review): self.query_ids is presumably initialised by SnowflakeHook.__init__.
        if "statementHandles" in json_response:
            # Multi-statement request: the API returns one handle per statement.
            self.query_ids = json_response["statementHandles"]
        elif "statementHandle" in json_response:
            self.query_ids.append(json_response["statementHandle"])
        else:
            raise AirflowException("No statementHandle/statementHandles present in response")
        return self.query_ids

    def get_headers(self) -> dict[str, Any]:
        """
        Build the HTTP headers for a SQL API request, including a key pair JWT bearer token.

        The private key is loaded from the connection on first use.

        :raises AirflowException: if no private key is configured in the connection.
        """
        if not self.private_key:
            self.get_private_key()
        if self.private_key is None:
            # Fail with an actionable message instead of an AttributeError inside JWTGenerator.
            raise AirflowException(
                "The Snowflake SQL API requires key pair authentication: set the private_key_file "
                "or private_key_content extra field in the Snowflake connection."
            )
        conn_config = self._get_conn_params()

        # Get the JWT token from the connection details and the private key
        token = JWTGenerator(
            conn_config["account"],  # type: ignore[arg-type]
            conn_config["user"],  # type: ignore[arg-type]
            private_key=self.private_key,
            lifetime=self.token_life_time,
            renewal_delay=self.token_renewal_delta,
        ).get_token()

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
            "Accept": "application/json",
            "User-Agent": "snowflakeSQLAPI/1.0",
            "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
        }
        return headers

    def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, Any], dict[str, Any], str]:
        """
        Build the headers, query parameters and URL for a request about a single statement.

        :param query_id: statement handle (query id) of the individual statement.
        :return: a ``(headers, params, url)`` tuple.
        """
        conn_config = self._get_conn_params()
        req_id = uuid.uuid4()
        header = self.get_headers()
        params = {"requestId": str(req_id)}
        url = f"{self._statements_url(conn_config)}/{query_id}"
        return header, params, url

    def check_query_output(self, query_ids: list[str]) -> None:
        """
        Fetch the result of each statement from the Snowflake SQL API and log it.

        :param query_ids: statement handles (query ids) of the individual statements.
        :raises AirflowException: if any of the HTTP requests fails.
        """
        for query_id in query_ids:
            header, params, url = self.get_request_url_header_params(query_id)
            try:
                response = requests.get(url, headers=header, params=params)
                response.raise_for_status()
                self.log.info(response.json())
            except requests.exceptions.HTTPError as e:
                raise AirflowException(
                    f"Response: {e.response.content}, Status Code: {e.response.status_code}"
                ) from e

    def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
        """
        Request the execution status of a single statement from the Snowflake SQL API.

        HTTP 202 maps to ``{"status": "running", ...}``; HTTP 200 maps to
        ``{"status": "success", ...}`` with the returned ``statement_handles``; HTTP 422 and
        any other status code map to ``{"status": "error", ...}``.

        :param query_id: statement handle id for the individual statements.
        """
        self.log.info("Retrieving status for query id %s", query_id)
        header, params, url = self.get_request_url_header_params(query_id)
        response = requests.get(url, params=params, headers=header)
        status_code = response.status_code
        resp = response.json()
        self.log.info("Snowflake SQL GET statements status API response: %s", resp)
        if status_code == 202:
            return {"status": "running", "message": "Query statements are still running"}
        elif status_code == 422:
            return {"status": "error", "message": resp["message"]}
        elif status_code == 200:
            statement_handles = []
            if "statementHandles" in resp and resp["statementHandles"]:
                statement_handles = resp["statementHandles"]
            elif "statementHandle" in resp and resp["statementHandle"]:
                statement_handles.append(resp["statementHandle"])
            return {
                "status": "success",
                "message": resp["message"],
                "statement_handles": statement_handles,
            }
        else:
            return {"status": "error", "message": resp["message"]}
diff --git a/airflow/providers/snowflake/hooks/sql_api_generate_jwt.py 
b/airflow/providers/snowflake/hooks/sql_api_generate_jwt.py
new file mode 100644
index 0000000000..883ef3aae0
--- /dev/null
+++ b/airflow/providers/snowflake/hooks/sql_api_generate_jwt.py
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import base64
+import hashlib
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+# This class relies on the PyJWT module (https://pypi.org/project/PyJWT/).
+import jwt
+from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
+
logger = logging.getLogger(__name__)


# JWT claim names used as payload keys by JWTGenerator.get_token().
ISSUER = "iss"  # issuer: qualified username + "." + public key fingerprint
EXPIRE_TIME = "exp"  # expiration time: issue time + configured lifetime
ISSUE_TIME = "iat"  # issued-at time
SUBJECT = "sub"  # subject: qualified username

# JWTGenerator expects an already-loaded private key object: it calls
# ``private_key.public_key()`` and signs with the key directly. An encrypted key must
# therefore be decrypted with its passphrase before it is passed in.
+
+
+class JWTGenerator:
+    """
+    Creates and signs a JWT with the specified private key file, username, and 
account identifier.
+    The JWTGenerator keeps the generated token and only regenerates the token 
if a specified period
+    of time has passed.
+
+    Creates an object that generates JWTs for the specified user, account 
identifier, and private key
+
+    :param account: Your Snowflake account identifier.
+        See 
https://docs.snowflake.com/en/user-guide/admin-account-identifier.html. Note 
that if you are using
+        the account locator, exclude any region information from the account 
locator.
+    :param user: The Snowflake username.
+    :param private_key: Private key from the file path for signing the JWTs.
+    :param lifetime: The number of minutes (as a timedelta) during which the 
key will be valid.
+    :param renewal_delay: The number of minutes (as a timedelta) from now 
after which the JWT
+        generator should renew the JWT.
+    """
+
+    LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minute 
lifetime
+    RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 
minutes
+    ALGORITHM = "RS256"  # Tokens will be generated using RSA with SHA256
+
+    def __init__(
+        self,
+        account: str,
+        user: str,
+        private_key: Any,
+        lifetime: timedelta = LIFETIME,
+        renewal_delay: timedelta = RENEWAL_DELTA,
+    ):
+        logger.info(
+            """Creating JWTGenerator with arguments
+            account : %s, user : %s, lifetime : %s, renewal_delay : %s""",
+            account,
+            user,
+            lifetime,
+            renewal_delay,
+        )
+
+        # Construct the fully qualified name of the user in uppercase.
+        self.account = self.prepare_account_name_for_jwt(account)
+        self.user = user.upper()
+        self.qualified_username = self.account + "." + self.user
+
+        self.lifetime = lifetime
+        self.renewal_delay = renewal_delay
+        self.private_key = private_key
+        self.renew_time = datetime.now(timezone.utc)
+        self.token: str | None = None
+
+    def prepare_account_name_for_jwt(self, raw_account: str) -> str:
+        """
+        Prepare the account identifier for use in the JWT.
+        For the JWT, the account identifier must not include the subdomain or 
any region or cloud provider
+        information.
+
+        :param raw_account: The specified account identifier.
+        """
+        account = raw_account
+        if ".global" not in account:
+            # Handle the general case.
+            idx = account.find(".")
+            if idx > 0:
+                account = account[0:idx]
+        else:
+            # Handle the replication case.
+            idx = account.find("-")
+            if idx > 0:
+                account = account[0:idx]  # pragma: no cover
+        # Use uppercase for the account identifier.
+        return account.upper()
+
+    def get_token(self) -> str | None:
+        """
+        Generates a new JWT. If a JWT has been already been generated earlier, 
return the previously
+        generated token unless the specified renewal time has passed.
+        """
+        now = datetime.now(timezone.utc)  # Fetch the current time
+
+        # If the token has expired or doesn't exist, regenerate the token.
+        if self.token is None or self.renew_time <= now:
+            logger.info(
+                "Generating a new token because the present time (%s) is later 
than the renewal time (%s)",
+                now,
+                self.renew_time,
+            )
+            # Calculate the next time we need to renew the token.
+            self.renew_time = now + self.renewal_delay
+
+            # Prepare the fields for the payload.
+            # Generate the public key fingerprint for the issuer in the 
payload.
+            public_key_fp = 
self.calculate_public_key_fingerprint(self.private_key)
+
+            # Create our payload
+            payload = {
+                # Set the issuer to the fully qualified username concatenated 
with the public key fingerprint.
+                ISSUER: self.qualified_username + "." + public_key_fp,
+                # Set the subject to the fully qualified username.
+                SUBJECT: self.qualified_username,
+                # Set the issue time to now.
+                ISSUE_TIME: now,
+                # Set the expiration time, based on the lifetime specified for 
this object.
+                EXPIRE_TIME: now + self.lifetime,
+            }
+
+            # Regenerate the actual token
+            token = jwt.encode(payload, key=self.private_key, 
algorithm=JWTGenerator.ALGORITHM)
+
+            if isinstance(token, bytes):
+                token = token.decode("utf-8")
+            self.token = token
+
+        return self.token
+
+    def calculate_public_key_fingerprint(self, private_key: Any) -> str:
+        """
+        Given a private key in PEM format, return the public key fingerprint.
+
+        :param private_key: private key
+        """
+        # Get the raw bytes of public key.
+        public_key_raw = private_key.public_key().public_bytes(
+            Encoding.DER, PublicFormat.SubjectPublicKeyInfo
+        )
+
+        # Get the sha256 hash of the raw bytes.
+        sha256hash = hashlib.sha256()
+        sha256hash.update(public_key_raw)
+
+        # Base64-encode the value and prepend the prefix 'SHA256:'.
+        public_key_fp = "SHA256:" + 
base64.b64encode(sha256hash.digest()).decode("utf-8")
+
+        return public_key_fp
diff --git a/airflow/providers/snowflake/operators/snowflake.py 
b/airflow/providers/snowflake/operators/snowflake.py
index f09039baaf..2257ec759e 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -17,9 +17,12 @@
 # under the License.
 from __future__ import annotations
 
+import time
 import warnings
-from typing import Any, Iterable, Mapping, Sequence, SupportsAbs
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, SupportsAbs
 
+from airflow import AirflowException
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.common.sql.operators.sql import (
     SQLCheckOperator,
@@ -27,6 +30,12 @@ from airflow.providers.common.sql.operators.sql import (
     SQLIntervalCheckOperator,
     SQLValueCheckOperator,
 )
+from airflow.providers.snowflake.hooks.snowflake_sql_api import (
+    SnowflakeSqlApiHook,
+)
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
 
 
 class SnowflakeOperator(SQLExecuteQueryOperator):
@@ -359,3 +368,159 @@ class 
SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
         self.authenticator = authenticator
         self.session_parameters = session_parameters
         self.query_ids: list[str] = []
+
+
class SnowflakeSqlApiOperator(SnowflakeOperator):
    """
    Implemented Snowflake SQL API Operator to support multiple SQL statements in a single request.

    Unlike the SnowflakeOperator, which executes multiple SQL statements sequentially, the
    Snowflake SQL API allows submitting multiple SQL statements in a single request. This
    operator makes a POST request to submit the SQL statements for execution, polls the
    status of every submitted statement until none of them is still running, and then
    fetches and logs the output of each statement.
    This Operator currently uses key pair authentication, so you need to provide private key raw content or
    private key file path in the snowflake connection along with other details

    .. seealso::

        `Snowflake SQL API key pair Authentication <https://docs.snowflake.com/en/developer-guide/sql-api/authenticating.html#label-sql-api-authenticating-key-pair>`_

    Where can this operator fit in?
         - To execute multiple SQL statements in a single request
         - To execute the SQL statement asynchronously and to execute standard queries and most DDL and DML statements
         - To develop custom applications and integrations that perform queries
         - To create provision users and roles, create table, etc.

    The following commands are not supported:
        - The PUT command (in Snowflake SQL)
        - The GET command (in Snowflake SQL)
        - The CALL command with stored procedures that return a table(stored procedures with the RETURNS TABLE clause).

    .. seealso::

        - `Snowflake SQL API <https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#introduction-to-the-sql-api>`_
        - `API Reference <https://docs.snowflake.com/en/developer-guide/sql-api/reference.html#snowflake-sql-api-reference>`_
        - `Limitation on snowflake SQL API <https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#limitations-of-the-sql-api>`_

    :param snowflake_conn_id: Reference to Snowflake connection id
    :param sql: the sql code to be executed. (templated)
    :param autocommit: if True, each command is automatically committed.
        (default value: True). Accepted for compatibility with SnowflakeOperator; it is not
        sent with the SQL API request.
    :param parameters: (optional) the parameters to render the SQL query with. Accepted for
        compatibility with SnowflakeOperator; it is not sent with the SQL API request, use
        ``bindings`` instead.
    :param warehouse: name of warehouse (will overwrite any warehouse
        defined in the connection's extra JSON)
    :param database: name of database (will overwrite database defined
        in connection)
    :param schema: name of schema (will overwrite schema defined in
        connection)
    :param role: name of role (will overwrite any role defined in
        connection's extra JSON)
    :param authenticator: authenticator for Snowflake.
        'snowflake' (default) to use the internal Snowflake authenticator
        'externalbrowser' to authenticate using your web browser and
        Okta, ADFS or any other SAML 2.0-compliant identity provider
        (IdP) that has been defined for your account
        'https://<your_okta_account_name>.okta.com' to authenticate
        through native Okta.
        It is forwarded to the hook, but SQL API requests always use key pair JWT authentication.
    :param session_parameters: You can set session-level parameters at
        the time you connect to Snowflake
    :param poll_interval: the interval in seconds between two polls of the statements' status
    :param statement_count: Number of SQL statement to be executed
    :param token_life_time: lifetime of the JWT Token
    :param token_renewal_delta: Renewal time of the JWT Token
    :param bindings: (Optional) Values of bind variables in the SQL statement.
            When executing the statement, Snowflake replaces placeholders (? and :name) in
            the statement with these specified values.
    """  # noqa

    LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minutes lifetime
    RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes

    def __init__(
        self,
        *,
        snowflake_conn_id: str = "snowflake_default",
        warehouse: str | None = None,
        database: str | None = None,
        role: str | None = None,
        schema: str | None = None,
        authenticator: str | None = None,
        session_parameters: dict[str, Any] | None = None,
        poll_interval: int = 5,
        statement_count: int = 0,
        token_life_time: timedelta = LIFETIME,
        token_renewal_delta: timedelta = RENEWAL_DELTA,
        bindings: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        self.snowflake_conn_id = snowflake_conn_id
        self.poll_interval = poll_interval
        self.statement_count = statement_count
        self.token_life_time = token_life_time
        self.token_renewal_delta = token_renewal_delta
        self.bindings = bindings
        self.execute_async = False
        if any([warehouse, database, role, schema, authenticator, session_parameters]):
            # Explicit arguments act as defaults; keys of a user-supplied hook_params win.
            hook_params = kwargs.pop("hook_params", None) or {}
            kwargs["hook_params"] = {
                "warehouse": warehouse,
                "database": database,
                "role": role,
                "schema": schema,
                "authenticator": authenticator,
                "session_parameters": session_parameters,
                **hook_params,
            }
        # execute() builds SnowflakeSqlApiHook directly, so keep the overrides for it here;
        # previously they were dropped and never reached the SQL API request.
        self._sql_api_hook_params: dict[str, Any] = dict(kwargs.get("hook_params") or {})
        # Compare this class's own parent (not type(self)'s) so subclasses take the same branch.
        # A string check is used because SnowflakeOperator is deprecated and may be removed.
        if SnowflakeSqlApiOperator.__base__.__name__ != "SnowflakeOperator":
            super().__init__(conn_id=snowflake_conn_id, **kwargs)  # pragma: no cover
        else:
            super().__init__(snowflake_conn_id=snowflake_conn_id, **kwargs)

    def execute(self, context: Context) -> None:
        """
        Submit the SQL to the Snowflake SQL API, wait for all statements to finish and log their output.

        The statement handles returned by the API are pushed to XCom under the ``query_ids``
        key when ``do_xcom_push`` is enabled.

        :raises AirflowException: if any of the statements finished with an error.
        """
        self.log.info("Executing: %s", self.sql)
        self._hook = SnowflakeSqlApiHook(
            snowflake_conn_id=self.snowflake_conn_id,
            token_life_time=self.token_life_time,
            token_renewal_delta=self.token_renewal_delta,
            **self._sql_api_hook_params,
        )
        self.query_ids = self._hook.execute_query(
            self.sql, statement_count=self.statement_count, bindings=self.bindings  # type: ignore[arg-type]
        )
        self.log.info("List of query ids %s", self.query_ids)

        if self.do_xcom_push:
            context["ti"].xcom_push(key="query_ids", value=self.query_ids)

        statement_status = self.poll_on_queries()
        if statement_status["error"]:
            raise AirflowException(statement_status["error"])
        self._hook.check_query_output(self.query_ids)

    def poll_on_queries(self) -> dict[str, dict[str, Any]]:
        """
        Poll the status of every submitted statement until none of them is still running.

        Statements still running are re-checked every ``poll_interval`` seconds. Polling stops
        early as soon as any statement reports an error, because :meth:`execute` fails the task
        in that case anyway.

        :return: ``{"success": {...}, "error": {...}}`` mapping query ids to the status dicts
            returned by ``SnowflakeSqlApiHook.get_sql_api_query_status``.
        :raises ValueError: if requesting the status of a statement fails.
        """
        pending = list(self.query_ids)
        statement_success_status: dict[str, Any] = {}
        statement_error_status: dict[str, Any] = {}
        while pending:
            still_running = []
            for query_id in pending:
                self.log.info("checking : %s", query_id)
                try:
                    statement_status = self._hook.get_sql_api_query_status(query_id)
                except Exception as e:
                    raise ValueError({"status": "error", "message": str(e)}) from e
                status = statement_status.get("status")
                if status == "error":
                    statement_error_status[query_id] = statement_status
                elif status == "success":
                    statement_success_status[query_id] = statement_status
                else:
                    # Anything else (e.g. HTTP 202 -> "running") must be checked again.
                    still_running.append(query_id)
            pending = still_running
            if statement_error_status:
                break
            if pending:
                time.sleep(self.poll_interval)
        return {"success": statement_success_status, "error": statement_error_status}
diff --git a/airflow/providers/snowflake/provider.yaml 
b/airflow/providers/snowflake/provider.yaml
index 62af89dfc3..5e4d04a3f5 100644
--- a/airflow/providers/snowflake/provider.yaml
+++ b/airflow/providers/snowflake/provider.yaml
@@ -75,6 +75,8 @@ hooks:
   - integration-name: Snowflake
     python-modules:
       - airflow.providers.snowflake.hooks.snowflake
+      - airflow.providers.snowflake.hooks.snowflake_sql_api
+      - airflow.providers.snowflake.hooks.sql_api_generate_jwt
 
 transfers:
   - source-integration-name: Amazon Simple Storage Service (S3)
diff --git a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst 
b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
index f823430fa8..1e80f3af29 100644
--- a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
+++ b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
@@ -58,3 +58,46 @@ An example usage of the SnowflakeOperator is as follows:
 
   Parameters that can be passed onto the operator will be given priority over 
the parameters already given
   in the Airflow connection metadata (such as ``schema``, ``role``, 
``database`` and so forth).
+
+
+SnowflakeSqlApiOperator
+=======================
+
+Use the :class:`SnowflakeSqlApiOperator <airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator>` to execute
+SQL commands in a `Snowflake <https://docs.snowflake.com/en/>`__ database through the Snowflake SQL API.
+
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+Use the ``snowflake_conn_id`` argument to connect to your Snowflake instance where
+the connection metadata is structured as follows:
+
+.. list-table:: Snowflake Airflow Connection Metadata
+   :widths: 25 25
+   :header-rows: 1
+
+   * - Parameter
+     - Input
+   * - Login: string
+     - Snowflake user name
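+   * - Password: string
+     - Password for the Snowflake user; when a private key is configured, it is used as the key's passphrase
+   * - Schema: string
+     - Set schema to execute SQL operations on by default
+   * - Extra: dictionary
+     - ``warehouse``, ``account``, ``database``, ``region``, ``role``, ``authenticator``,
+       and either ``private_key_file`` or ``private_key_content`` (mutually exclusive)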
+
+An example usage of the SnowflakeSqlApiOperator is as follows:
+
+.. exampleinclude:: /../../tests/system/providers/snowflake/example_snowflake.py
+    :language: python
+    :start-after: [START howto_snowflake_sql_api_operator]
+    :end-before: [END howto_snowflake_sql_api_operator]
+    :dedent: 4
+
+
+.. note::
+
+  Parameters passed to the operator take priority over the parameters already set
+  in the Airflow connection metadata (such as ``schema``, ``role``, ``database`` and so forth).
diff --git a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py 
b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
new file mode 100644
index 0000000000..8f68f446c8
--- /dev/null
+++ b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
@@ -0,0 +1,458 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import unittest
+import uuid
+from pathlib import Path
+from typing import Any
+from unittest import mock
+
+import pytest
+import requests
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+
+from airflow import AirflowException
+from airflow.models import Connection
+from airflow.providers.snowflake.hooks.snowflake_sql_api import (
+    SnowflakeSqlApiHook,
+)
+
+# Four statements submitted in a single SQL API request.
+SQL_MULTIPLE_STMTS = (
+    "create or replace table user_test (i int); insert into user_test (i) "
+    "values (200); insert into user_test (i) values (300); select i from user_test order by i;"
+)
+# Passphrase used to encrypt the temporary private key in the key-pair tests.
+_PASSWORD = "snowflake42"
+
+SINGLE_STMT = "select i from user_test order by i;"
+# Keyword arguments for an Airflow ``Connection`` describing a Snowflake account.
+BASE_CONNECTION_KWARGS: dict = {
+    "login": "user",
+    "conn_type": "snowflake",
+    "password": "pw",
+    "schema": "public",
+    "extra": {
+        "database": "db",
+        "account": "airflow",
+        "warehouse": "af_wh",
+        "region": "af_region",
+        "role": "af_role",
+    },
+}
+# Value returned by the mocked ``_get_conn_params``.
+CONN_PARAMS = {
+    "account": "airflow",
+    "application": "AIRFLOW",
+    "authenticator": "snowflake",
+    "database": "db",
+    "password": "pw",
+    "region": "af_region",
+    "role": "af_role",
+    "schema": "public",
+    "session_parameters": None,
+    "user": "user",
+    "warehouse": "af_wh",
+}
+# Headers expected from ``get_headers`` when the JWT token is mocked to "newT0k3n".
+HEADERS = {
+    "Content-Type": "application/json",
+    "Authorization": "Bearer newT0k3n",
+    "Accept": "application/json",
+    "User-Agent": "snowflakeSQLAPI/1.0",
+    "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT",
+}
+
+# Sample statement-status payload returned by the mocked ``requests.get``.
+GET_RESPONSE = {
+    "resultSetMetaData": {
+        "numRows": 10000,
+        "format": "jsonv2",
+        "rowType": [
+            {
+                "name": "SEQ8()",
+                "database": "",
+                "schema": "",
+                "table": "",
+                "scale": 0,
+                "precision": 19,
+            },
+            {
+                "name": "RANDSTR(1000, RANDOM())",
+                "database": "",
+                "schema": "",
+                "table": "",
+            },
+        ],
+        "partitionInfo": [
+            {
+                "rowCount": 12344,
+                "uncompressedSize": 14384873,
+            },
+            {"rowCount": 43746, "uncompressedSize": 43748274, "compressedSize": 746323},
+        ],
+    },
+    "code": "090001",
+    "statementStatusUrl": "/api/v2/statements/{handle}?requestId={id5}&partition=10",
+    "sqlState": "00000",
+    "statementHandle": "{handle}",
+    "message": "Statement executed successfully.",
+    "createdOn": 1620151693299,
+}
+
+
+def create_successful_response_mock(content):
+    """Build a fake HTTP 200 response whose ``json()`` returns *content*."""
+    fake_response = mock.MagicMock(status_code=200)
+    fake_response.json.return_value = content
+    return fake_response
+
+
+def create_post_side_effect(status_code=429):
+    """
+    Create a mock HTTP response whose ``raise_for_status`` raises ``HTTPError``.
+
+    :param status_code: HTTP status code reported by the mock (defaults to 429).
+    :return: a ``MagicMock`` with ``status_code``, ``reason`` and a failing ``raise_for_status``.
+    """
+    response = mock.MagicMock()
+    response.status_code = status_code
+    response.reason = "test"
+    # The error carries the response so handlers can read its status code and reason.
+    response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=response)
+    return response
+
+
+class TestSnowflakeSqlApiHook:
+    @pytest.mark.parametrize(
+        "sql,statement_count,expected_response, expected_query_ids",
+        [
+            (SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"]),
+            (SQL_MULTIPLE_STMTS, 4, {"statementHandles": ["uuid", "uuid1"]}, ["uuid", "uuid1"]),
+        ],
+    )
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
+    def test_execute_query(
+        self,
+        mock_get_header,
+        mock_conn_param,
+        mock_requests,
+        sql,
+        statement_count,
+        expected_response,
+        expected_query_ids,
+    ):
+        """Test execute_query method, run query by mocking post request method and return the query ids"""
+        mock_requests.codes.ok = 200
+        # A list side_effect makes post() return its items in turn, so post.return_value is never
+        # used and needs no configuration.
+        mock_requests.post.side_effect = [
+            create_successful_response_mock(expected_response),
+        ]
+
+        hook = SnowflakeSqlApiHook("mock_conn_id")
+        query_ids = hook.execute_query(sql, statement_count)
+        assert query_ids == expected_query_ids
+
+    @pytest.mark.parametrize(
+        "sql,statement_count,expected_response, expected_query_ids",
+        [(SINGLE_STMT, 1, {"statementHandle": "uuid"}, ["uuid"])],
+    )
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
+    def test_execute_query_exception_without_statement_handel(
+        self,
+        mock_get_header,
+        mock_conn_param,
+        mock_requests,
+        sql,
+        statement_count,
+        expected_response,
+        expected_query_ids,
+    ):
+        """
+        Test execute_query method by mocking the exception response and raise airflow exception
+        without statementHandle in the response
+        """
+        side_effect = create_post_side_effect()
+        mock_requests.post.side_effect = side_effect
+        hook = SnowflakeSqlApiHook("mock_conn_id")
+
+        with pytest.raises(AirflowException):
+            hook.execute_query(sql, statement_count)
+
+    @pytest.mark.parametrize(
+        "query_ids",
+        [
+            (["uuid", "uuid1"]),
+        ],
+    )
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook."
+        "get_request_url_header_params"
+    )
+    def test_check_query_output(self, mock_geturl_header_params, mock_requests, query_ids):
+        """Test check_query_output by passing query ids as params and mock get_request_url_header_params"""
+        req_id = uuid.uuid4()
+        params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
+        mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/"
+        mock_requests.get.return_value.json.return_value = GET_RESPONSE
+        hook = SnowflakeSqlApiHook("mock_conn_id")
+        with mock.patch.object(hook.log, "info") as mock_log_info:
+            hook.check_query_output(query_ids)
+        mock_log_info.assert_called_with(GET_RESPONSE)
+
+    @pytest.mark.parametrize(
+        "query_ids",
+        [
+            (["uuid", "uuid1"]),
+        ],
+    )
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests.get")
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook."
+        "get_request_url_header_params"
+    )
+    def test_check_query_output_exception(self, mock_geturl_header_params, mock_requests, query_ids):
+        """
+        Test check_query_output by passing query ids as params and mock get_request_url_header_params
+        to raise airflow exception and mock with http error
+        """
+        req_id = uuid.uuid4()
+        params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
+        mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/"
+        # A real Response with a 404 status so raise_for_status() raises a genuine HTTPError.
+        mock_resp = requests.models.Response()
+        mock_resp.status_code = 404
+        mock_requests.return_value = mock_resp
+        hook = SnowflakeSqlApiHook("mock_conn_id")
+        with mock.patch.object(hook.log, "error"), pytest.raises(AirflowException):
+            hook.check_query_output(query_ids)
+
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_headers")
+    def test_get_request_url_header_params(self, mock_get_header, mock_conn_param):
+        """Test get_request_url_header_params by mocking _get_conn_params and get_headers"""
+        mock_conn_param.return_value = CONN_PARAMS
+        mock_get_header.return_value = HEADERS
+        hook = SnowflakeSqlApiHook("mock_conn_id")
+        header, params, url = hook.get_request_url_header_params("uuid")
+        assert header == HEADERS
+        assert url == "https://airflow.af_region.snowflakecomputing.com/api/v2/statements/uuid"
+
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_private_key")
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook._get_conn_params")
+    @mock.patch("airflow.providers.snowflake.hooks.sql_api_generate_jwt.JWTGenerator.get_token")
+    def test_get_headers(self, mock_get_token, mock_conn_param, mock_private_key):
+        """Test get_headers method by mocking get_private_key and _get_conn_params method"""
+        mock_get_token.return_value = "newT0k3n"
+        mock_conn_param.return_value = CONN_PARAMS
+        hook = SnowflakeSqlApiHook(snowflake_conn_id="mock_conn_id")
+        result = hook.get_headers()
+        assert result == HEADERS
+
+    @pytest.fixture()
+    def non_encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
+        """Write an unencrypted PKCS#8 PEM private key to a temporary file and return its path"""
+        key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
+        private_key = key.private_bytes(
+            serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption()
+        )
+        test_key_file = tmp_path / "test_key.pem"
+        test_key_file.write_bytes(private_key)
+        return test_key_file
+
+    @pytest.fixture()
+    def encrypted_temporary_private_key(self, tmp_path: Path) -> Path:
+        """Write a PKCS#8 PEM private key encrypted with ``_PASSWORD`` to a temporary file and return its path"""
+        key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
+        private_key = key.private_bytes(
+            serialization.Encoding.PEM,
+            serialization.PrivateFormat.PKCS8,
+            encryption_algorithm=serialization.BestAvailableEncryption(_PASSWORD.encode()),
+        )
+        test_key_file: Path = tmp_path / "test_key.p8"
+        test_key_file.write_bytes(private_key)
+        return test_key_file
+
+    def test_get_private_key_should_support_private_auth_in_connection(
+        self, encrypted_temporary_private_key: Path
+    ):
+        """Test get_private_key function with private_key_content in connection"""
+        connection_kwargs: Any = {
+            **BASE_CONNECTION_KWARGS,
+            "password": _PASSWORD,
+            "extra": {
+                "database": "db",
+                "account": "airflow",
+                "warehouse": "af_wh",
+                "region": "af_region",
+                "role": "af_role",
+                "private_key_content": str(encrypted_temporary_private_key.read_text()),
+            },
+        }
+        with unittest.mock.patch.dict(
+            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+        ):
+            hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+            hook.get_private_key()
+            assert hook.private_key is not None
+
+    def test_get_private_key_raise_exception(self, encrypted_temporary_private_key: Path):
+        """
+        Test get_private_key function with private_key_content and private_key_file in connection
+        and raise airflow exception
+        """
+        connection_kwargs: Any = {
+            **BASE_CONNECTION_KWARGS,
+            "password": _PASSWORD,
+            "extra": {
+                "database": "db",
+                "account": "airflow",
+                "warehouse": "af_wh",
+                "region": "af_region",
+                "role": "af_role",
+                "private_key_content": str(encrypted_temporary_private_key.read_text()),
+                "private_key_file": str(encrypted_temporary_private_key),
+            },
+        }
+        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+        with unittest.mock.patch.dict(
+            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+        ), pytest.raises(
+            AirflowException,
+            match="The private_key_file and private_key_content extra fields are mutually "
+            "exclusive. Please remove one.",
+        ):
+            hook.get_private_key()
+
+    def test_get_private_key_should_support_private_auth_with_encrypted_key(
+        self, encrypted_temporary_private_key
+    ):
+        """Test get_private_key method by supporting for private auth encrypted_key"""
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "password": _PASSWORD,
+            "extra": {
+                "database": "db",
+                "account": "airflow",
+                "warehouse": "af_wh",
+                "region": "af_region",
+                "role": "af_role",
+                "private_key_file": str(encrypted_temporary_private_key),
+            },
+        }
+        with unittest.mock.patch.dict(
+            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+        ):
+            hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+            hook.get_private_key()
+            assert hook.private_key is not None
+
+    def test_get_private_key_should_support_private_auth_with_unencrypted_key(
+        self,
+        non_encrypted_temporary_private_key,
+    ):
+        """
+        Test get_private_key with an unencrypted key file: no password or an empty password loads the key,
+        while a non-empty password raises TypeError
+        """
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "password": None,
+            "extra": {
+                "database": "db",
+                "account": "airflow",
+                "warehouse": "af_wh",
+                "region": "af_region",
+                "role": "af_role",
+                "private_key_file": str(non_encrypted_temporary_private_key),
+            },
+        }
+        with unittest.mock.patch.dict(
+            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+        ):
+            hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+            hook.get_private_key()
+            assert hook.private_key is not None
+        connection_kwargs["password"] = ""
+        with unittest.mock.patch.dict(
+            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+        ):
+            hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+            hook.get_private_key()
+            assert hook.private_key is not None
+        connection_kwargs["password"] = _PASSWORD
+        with unittest.mock.patch.dict(
+            "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
+        ), pytest.raises(TypeError, match="Password was given but private key is not encrypted."):
+            SnowflakeSqlApiHook(snowflake_conn_id="test_conn").get_private_key()
+
+    # Synchronous test: no ``pytest.mark.asyncio`` marker (it only applies to coroutine tests).
+    @pytest.mark.parametrize(
+        "status_code,response,expected_response",
+        [
+            (
+                200,
+                {
+                    "status": "success",
+                    "message": "Statement executed successfully.",
+                    "statementHandle": "uuid",
+                },
+                {
+                    "status": "success",
+                    "message": "Statement executed successfully.",
+                    "statement_handles": ["uuid"],
+                },
+            ),
+            (
+                200,
+                {
+                    "status": "success",
+                    "message": "Statement executed successfully.",
+                    "statementHandles": ["uuid", "uuid1"],
+                },
+                {
+                    "status": "success",
+                    "message": "Statement executed successfully.",
+                    "statement_handles": ["uuid", "uuid1"],
+                },
+            ),
+            (202, {}, {"status": "running", "message": "Query statements are still running"}),
+            (422, {"status": "error", "message": "test"}, {"status": "error", "message": "test"}),
+            (404, {"status": "error", "message": "test"}, {"status": "error", "message": "test"}),
+        ],
+    )
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook."
+        "get_request_url_header_params"
+    )
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.requests")
+    def test_get_sql_api_query_status(
+        self, mock_requests, mock_geturl_header_params, status_code, response, expected_response
+    ):
+        """Test get_sql_api_query_status function by mocking the status, response and expected
+        response"""
+        req_id = uuid.uuid4()
+        params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
+        mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/"
+
+        class MockResponse:
+            def __init__(self, status_code, data):
+                self.status_code = status_code
+                self.data = data
+
+            def json(self):
+                return self.data
+
+        mock_requests.get.return_value = MockResponse(status_code, response)
+        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+        assert hook.get_sql_api_query_status("uuid") == expected_response
diff --git a/tests/providers/snowflake/hooks/test_sql_api_generate_jwt.py 
b/tests/providers/snowflake/hooks/test_sql_api_generate_jwt.py
new file mode 100644
index 0000000000..5fa434d9b3
--- /dev/null
+++ b/tests/providers/snowflake/hooks/test_sql_api_generate_jwt.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+from cryptography.hazmat.backends import default_backend as 
crypto_default_backend
+from cryptography.hazmat.primitives import serialization as 
crypto_serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+
+from airflow.providers.snowflake.hooks.sql_api_generate_jwt import JWTGenerator
+
+# NOTE(review): not referenced by the tests in this module.
+_PASSWORD = "snowflake42"
+
+# RSA key generated once at import time and shared by every test in this module.
+key = rsa.generate_private_key(backend=crypto_default_backend(), public_exponent=65537, key_size=2048)
+
+# Unencrypted PKCS#8 PEM serialization of ``key``.
+private_key = key.private_bytes(
+    crypto_serialization.Encoding.PEM,
+    crypto_serialization.PrivateFormat.PKCS8,
+    crypto_serialization.NoEncryption(),
+)
+
+
+class TestJWTGenerator:
+    @pytest.mark.parametrize(
+        "account_name, expected_account_name",
+        [("test.us-east-1", "TEST"), ("test.global", "TEST.GLOBAL"), ("test", "TEST")],
+    )
+    def test_prepare_account_name_for_jwt(self, account_name, expected_account_name):
+        """
+        Test that prepare_account_name_for_jwt upper-cases the account identifier,
+        dropping a region suffix but keeping a ``.global`` suffix
+        """
+        # Pass the RSA key object, as test_calculate_public_key_fingerprint does, not its PEM bytes.
+        jwt_generator = JWTGenerator(account_name, "test_user", key)
+        response = jwt_generator.prepare_account_name_for_jwt(account_name)
+        assert response == expected_account_name
+
+    def test_calculate_public_key_fingerprint(self):
+        """Assert that get_token, which relies on calculate_public_key_fingerprint, yields a token for the key"""
+        jwt_generator = JWTGenerator("test.us-east-1", "test_user", key)
+        assert jwt_generator.get_token()
diff --git a/tests/providers/snowflake/operators/test_snowflake.py 
b/tests/providers/snowflake/operators/test_snowflake.py
index dde2f9adde..8f32c6e62d 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -21,11 +21,13 @@ from unittest import mock
 
 import pytest
 
+from airflow.exceptions import AirflowException
 from airflow.models.dag import DAG
 from airflow.providers.snowflake.operators.snowflake import (
     SnowflakeCheckOperator,
     SnowflakeIntervalCheckOperator,
     SnowflakeOperator,
+    SnowflakeSqlApiOperator,
     SnowflakeValueCheckOperator,
 )
 from airflow.utils import timezone
@@ -35,6 +37,17 @@ DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
 DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
 TEST_DAG_ID = "unit_test_dag"
 
+TASK_ID = "snowflake_check"
+CONN_ID = "my_snowflake_conn"
+TEST_SQL = "select * from any;"
+
+# Four statements in one request; the SnowflakeSqlApiOperator tests pass statement_count=4 for it.
+SQL_MULTIPLE_STMTS = (
+    "create or replace table user_test (i int); insert into user_test (i) "
+    "values (200); insert into user_test (i) values (300); select i from user_test order by i;"
+)
+
+SINGLE_STMT = "select i from user_test order by i;"
+
 
 class TestSnowflakeOperator:
     def setup_method(self):
@@ -73,3 +86,59 @@ class TestSnowflakeCheckOperators:
         operator = operator_class(task_id="snowflake_check", 
snowflake_conn_id="snowflake_default", **kwargs)
         operator.get_db_hook()
         mock_get_db_hook.assert_called_once()
+
+
+class TestSnowflakeSqlApiOperator:
+    @pytest.fixture(autouse=True)
+    def mock_sleep(self):
+        """Skip the real ``time.sleep`` between status polls so the tests run fast"""
+        with mock.patch("airflow.providers.snowflake.operators.snowflake.time.sleep") as sleep:
+            yield sleep
+
+    @pytest.fixture
+    def mock_execute_query(self):
+        with mock.patch(
+            "airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiHook.execute_query"
+        ) as execute_query:
+            yield execute_query
+
+    @pytest.fixture
+    def mock_get_sql_api_query_status(self):
+        with mock.patch(
+            "airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiHook.get_sql_api_query_status"
+        ) as get_sql_api_query_status:
+            yield get_sql_api_query_status
+
+    @pytest.fixture
+    def mock_check_query_output(self):
+        with mock.patch(
+            "airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiHook.check_query_output"
+        ) as check_query_output:
+            yield check_query_output
+
+    def test_snowflake_sql_api_to_succeed_when_no_query_fails(
+        self, mock_execute_query, mock_get_sql_api_query_status, mock_check_query_output
+    ):
+        """Tests SnowflakeSqlApiOperator passes and fetches results if poll_on_queries reports no error"""
+
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id="snowflake_default",
+            sql=SQL_MULTIPLE_STMTS,
+            statement_count=4,
+            do_xcom_push=False,
+        )
+        mock_execute_query.return_value = ["uuid1", "uuid2"]
+        mock_get_sql_api_query_status.side_effect = [{"status": "success"}, {"status": "success"}]
+        operator.execute(context=None)
+        mock_check_query_output.assert_called_once_with(["uuid1", "uuid2"])
+
+    def test_snowflake_sql_api_to_fails_when_one_query_fails(
+        self, mock_execute_query, mock_get_sql_api_query_status, mock_check_query_output
+    ):
+        """Tests SnowflakeSqlApiOperator fails if poll_on_queries reports one or more errors"""
+
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id="snowflake_default",
+            sql=SQL_MULTIPLE_STMTS,
+            statement_count=4,
+            do_xcom_push=False,
+        )
+        mock_execute_query.return_value = ["uuid1", "uuid2"]
+        mock_get_sql_api_query_status.side_effect = [{"status": "error"}, {"status": "success"}]
+        with pytest.raises(AirflowException):
+            operator.execute(context=None)
+        # The failure must be raised before any query output is fetched.
+        mock_check_query_output.assert_not_called()
diff --git a/tests/system/providers/snowflake/example_snowflake.py 
b/tests/system/providers/snowflake/example_snowflake.py
index ac174f7b59..007dcdb4ab 100644
--- a/tests/system/providers/snowflake/example_snowflake.py
+++ b/tests/system/providers/snowflake/example_snowflake.py
@@ -24,7 +24,7 @@ import os
 from datetime import datetime
 
 from airflow import DAG
-from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
+from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator, 
SnowflakeSqlApiOperator
 
 SNOWFLAKE_CONN_ID = "my_snowflake_conn"
 SNOWFLAKE_SAMPLE_TABLE = "sample_table"
@@ -72,6 +72,14 @@ with DAG(
 
     # [END howto_operator_snowflake]
 
+    # [START howto_snowflake_sql_api_operator]
+    snowflake_sql_api_op_sql_multiple_stmt = SnowflakeSqlApiOperator(
+        task_id="snowflake_op_sql_multiple_stmt",
+        sql=SQL_MULTIPLE_STMTS,
+        statement_count=len(SQL_LIST),
+    )
+    # [END howto_snowflake_sql_api_operator]
+
     (
         snowflake_op_sql_str
         >> [
@@ -79,6 +87,7 @@ with DAG(
             snowflake_op_sql_list,
             snowflake_op_template_file,
             snowflake_op_sql_multiple_stmts,
+            snowflake_sql_api_op_sql_multiple_stmt,
         ]
     )
 

Reply via email to