This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new c90e45a373 feat: make user agent customizable (#32506)
c90e45a373 is described below
commit c90e45a373b9b4f05e6f5b4cc06c4f98f0f0287b
Author: Ville Brofeldt <[email protected]>
AuthorDate: Wed Mar 5 16:33:24 2025 -0800
feat: make user agent customizable (#32506)
---
superset/config.py | 6 +++++-
superset/constants.py | 2 +-
superset/db_engine_specs/base.py | 13 ++++++-------
superset/db_engine_specs/bigquery.py | 3 ++-
superset/db_engine_specs/databricks.py | 14 +++++++++-----
superset/db_engine_specs/druid.py | 6 +++++-
superset/db_engine_specs/duckdb.py | 11 +++++++----
superset/db_engine_specs/hive.py | 2 --
superset/db_engine_specs/parseable.py | 5 ++++-
superset/db_engine_specs/postgres.py | 6 ++++--
superset/db_engine_specs/presto.py | 1 -
superset/db_engine_specs/snowflake.py | 12 +++++++++---
superset/db_engine_specs/sqlite.py | 1 -
superset/db_engine_specs/trino.py | 17 +++++++++++------
superset/models/core.py | 6 +++---
superset/utils/core.py | 9 +++++++++
tests/unit_tests/db_engine_specs/test_trino.py | 7 ++++---
tests/unit_tests/utils/test_core.py | 22 ++++++++++++++++++++++
18 files changed, 101 insertions(+), 42 deletions(-)
diff --git a/superset/config.py b/superset/config.py
index 500627063b..44e91c956c 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -58,7 +58,7 @@ from superset.stats_logger import DummyStatsLogger
from superset.superset_typing import CacheConfig
from superset.tasks.types import ExecutorType
from superset.utils import core as utils
-from superset.utils.core import NO_TIME_RANGE, parse_boolean_string
+from superset.utils.core import NO_TIME_RANGE, parse_boolean_string, QuerySource
from superset.utils.encrypt import SQLAlchemyUtilsAdapter
from superset.utils.log import DBEventLogger
from superset.utils.logging_configurator import DefaultLoggingConfigurator
@@ -595,6 +595,10 @@ DEFAULT_FEATURE_FLAGS.update(
}
)
+# This function can be overridden to customize the name of the user agent
+# triggering the query.
+USER_AGENT_FUNC: Callable[[Database, QuerySource | None], str] | None = None
+
# This is merely a default.
FEATURE_FLAGS: dict[str, bool] = {}
diff --git a/superset/constants.py b/superset/constants.py
index b55048463e..f60b79a961 100644
--- a/superset/constants.py
+++ b/superset/constants.py
@@ -22,7 +22,7 @@ from enum import Enum
from superset.utils.backports import StrEnum
-USER_AGENT = "Apache Superset"
+DEFAULT_USER_AGENT = "Apache Superset"
NULL_STRING = "<NULL>"
EMPTY_STRING = "<empty string>"
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 248f8bd0a7..084891894b 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -73,7 +73,7 @@ from superset.superset_typing import (
SQLAColumnType,
)
from superset.utils import core as utils, json
-from superset.utils.core import ColumnSpec, GenericDataType
+from superset.utils.core import ColumnSpec, GenericDataType, QuerySource
from superset.utils.hashing import md5_sha_from_str
from superset.utils.json import redact_sensitive, reveal_sensitive
from superset.utils.network import is_hostname_valid, is_port_open
@@ -1023,7 +1023,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
For instance special column like `__time` for Druid can be
set to is_dttm=True. Note that this only gets called when new
columns are detected/created"""
- # TODO: Fix circular import caused by importing TableColumn
@classmethod
def epoch_to_dttm(cls) -> str:
@@ -1128,7 +1127,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: Database instance
:return: SQL query with limit clause
"""
- # TODO: Fix circular import caused by importing Database
if cls.limit_method == LimitMethod.WRAP_SQL:
sql = sql.strip("\t\n ;")
qry = (
@@ -1321,7 +1319,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
The flow works without this method doing anything, but it allows
for handling the cursor and updating progress information in the
query object"""
- # TODO: Fix circular import error caused by importing sql_lab.Query
@classmethod
# pylint: disable=consider-using-transaction
@@ -1637,7 +1634,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:return: SqlAlchemy query with additional where clause referencing the latest
partition
"""
- # TODO: Fix circular import caused by importing Database, TableColumn
return None
@classmethod
@@ -1768,7 +1764,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param sql: SQL query with possibly multiple statements
:param source: Source of the query (eg, "sql_lab")
"""
- extra = database.get_extra() or {}
+ extra = database.get_extra(source) or {}
if not cls.get_allow_cost_estimate(extra):
raise Exception( # pylint: disable=broad-exception-raised
"Database does not support cost estimation"
@@ -2019,12 +2015,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@staticmethod
- def get_extra_params(database: Database) -> dict[str, Any]:
+ def get_extra_params(
+ database: Database, source: QuerySource | None = None
+ ) -> dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.
:param database: database instance from which to extract extras
+ :param source: in which context is the connection needed
:raises CertificateException: If certificate is not valid/unparseable
"""
extra: dict[str, Any] = {}
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index ef0ce31b58..c6c57b1624 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -459,7 +459,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
:param sql: SQL query with possibly multiple statements
:param source: Source of the query (eg, "sql_lab")
"""
- extra = database.get_extra() or {}
+ extra = database.get_extra(source) or {}
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")
@@ -469,6 +469,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
database,
catalog=catalog,
schema=schema,
+ source=source,
) as engine:
client = cls._get_client(engine, database)
return [
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 1805682607..7574989abf 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -27,12 +27,13 @@ from marshmallow.validate import Range
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
-from superset.constants import TimeGrain, USER_AGENT
+from superset.constants import TimeGrain
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.utils import json
+from superset.utils.core import get_user_agent, QuerySource
from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
@@ -198,17 +199,20 @@ class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngine
}
@staticmethod
- def get_extra_params(database: Database) -> dict[str, Any]:
+ def get_extra_params(
+ database: Database, source: QuerySource | None = None
+ ) -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
Trim whitespace from connect_args to avoid databricks driver errors
"""
- extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+ extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database, source)
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
- connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)])
- connect_args.setdefault("_user_agent_entry", USER_AGENT)
+ user_agent = get_user_agent(database, source)
+ connect_args.setdefault("http_headers", [("User-Agent", user_agent)])
+ connect_args.setdefault("_user_agent_entry", user_agent)
# trim whitespace from http_path to avoid databricks errors on connecting
if http_path := connect_args.get("http_path"):
diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py
index 9dc366aa5b..8692e77c1e 100644
--- a/superset/db_engine_specs/druid.py
+++ b/superset/db_engine_specs/druid.py
@@ -29,6 +29,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.exceptions import SupersetException
from superset.utils import core as utils, json
+from superset.utils.core import QuerySource
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
@@ -78,11 +79,14 @@ class DruidEngineSpec(BaseEngineSpec):
orm_col.is_dttm = True
@staticmethod
- def get_extra_params(database: Database) -> dict[str, Any]:
+ def get_extra_params(
+ database: Database, source: QuerySource | None = None
+ ) -> dict[str, Any]:
"""
For Druid, the path to a SSL certificate is placed in `connect_args`.
:param database: database instance from which to extract extras
+ :param source: in which context is the connection needed
:raises CertificateException: If certificate is not valid/unparseable
:raises SupersetException: If database extra json payload is unparseable
"""
diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py
index 8fd5ab62b8..a5eba3642f 100644
--- a/superset/db_engine_specs/duckdb.py
+++ b/superset/db_engine_specs/duckdb.py
@@ -30,13 +30,13 @@ from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from superset.config import VERSION_STRING
-from superset.constants import TimeGrain, USER_AGENT
+from superset.constants import TimeGrain
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.utils.core import get_user_agent, QuerySource
if TYPE_CHECKING:
- # prevent circular imports
from superset.models.core import Database
@@ -237,7 +237,9 @@ class DuckDBEngineSpec(DuckDBParametersMixin, BaseEngineSpec):
return set(inspector.get_table_names(schema))
@staticmethod
- def get_extra_params(database: Database) -> dict[str, Any]:
+ def get_extra_params(
+ database: Database, source: QuerySource | None = None
+ ) -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
"""
@@ -247,7 +249,8 @@ class DuckDBEngineSpec(DuckDBParametersMixin, BaseEngineSpec):
config: dict[str, Any] = connect_args.setdefault("config", {})
custom_user_agent = config.pop("custom_user_agent", "")
delim = " " if custom_user_agent else ""
- user_agent = USER_AGENT.replace(" ", "-").lower()
+ user_agent = get_user_agent(database, source)
+ user_agent = user_agent.replace(" ", "-").lower()
user_agent = f"{user_agent}/{VERSION_STRING}{delim}{custom_user_agent}"
config.setdefault("custom_user_agent", user_agent)
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 9b7afab1cc..0a9b817804 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -49,8 +49,6 @@ from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING:
- # prevent circular imports
-
from superset.models.core import Database
logger = logging.getLogger(__name__)
diff --git a/superset/db_engine_specs/parseable.py b/superset/db_engine_specs/parseable.py
index dbca211627..9be780c949 100644
--- a/superset/db_engine_specs/parseable.py
+++ b/superset/db_engine_specs/parseable.py
@@ -23,6 +23,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec
+from superset.utils.core import QuerySource
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
@@ -73,7 +74,9 @@ class ParseableEngineSpec(BaseEngineSpec):
orm_col.is_dttm = True
@classmethod
- def get_extra_params(cls, database: Database) -> dict[str, Any]:
+ def get_extra_params(
+ cls, database: Database, source: QuerySource | None = None
+ ) -> dict[str, Any]:
"""Additional parameters for Parseable connections."""
return {
"engine_params": {
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index a2e6a3fe12..16117062b0 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -37,7 +37,7 @@ from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.sql_lab import Query
from superset.sql.parse import SQLScript
from superset.utils import core as utils, json
-from superset.utils.core import GenericDataType
+from superset.utils.core import GenericDataType, QuerySource
if TYPE_CHECKING:
from superset.models.core import Database # pragma: no cover
@@ -411,7 +411,9 @@ WHERE datistemplate = false;
)
@staticmethod
- def get_extra_params(database: Database) -> dict[str, Any]:
+ def get_extra_params(
+ database: Database, source: QuerySource | None = None
+ ) -> dict[str, Any]:
"""
For Postgres, the path to a SSL certificate is placed in `connect_args`.
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index e351c5b307..c3c8f61829 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -63,7 +63,6 @@ from superset.utils import core as utils, json
from superset.utils.core import GenericDataType
if TYPE_CHECKING:
- # prevent circular imports
from superset.models.core import Database
from superset.sql_parse import Table
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index a99f1e641d..5c9c42ea76 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import logging
import re
from datetime import datetime
@@ -32,13 +34,14 @@ from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
-from superset.constants import TimeGrain, USER_AGENT
+from superset.constants import TimeGrain
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils import json
+from superset.utils.core import get_user_agent, QuerySource
if TYPE_CHECKING:
from superset.models.core import Database
@@ -130,15 +133,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
}
@staticmethod
- def get_extra_params(database: "Database") -> dict[str, Any]:
+ def get_extra_params(
+ database: Database, source: QuerySource | None = None
+ ) -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
"""
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
+ user_agent = get_user_agent(database, source)
- connect_args.setdefault("application", USER_AGENT)
+ connect_args.setdefault("application", user_agent)
return extra
diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py
index 50451e9bf5..a7e9a96e33 100644
--- a/superset/db_engine_specs/sqlite.py
+++ b/superset/db_engine_specs/sqlite.py
@@ -31,7 +31,6 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
if TYPE_CHECKING:
- # prevent circular imports
from superset.models.core import Database
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index beb7c0a604..79fdef19bf 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -29,7 +29,7 @@ from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
from superset import db
-from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
from superset.db_engine_specs.exceptions import (
@@ -42,7 +42,8 @@ from superset.db_engine_specs.presto import PrestoBaseEngineSpec
from superset.models.sql_lab import Query
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
-from superset.utils import core as utils, json
+from superset.utils import json
+from superset.utils.core import create_ssl_cert_file, get_user_agent, QuerySource
if TYPE_CHECKING:
from superset.models.core import Database
@@ -330,23 +331,27 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
return True
@staticmethod
- def get_extra_params(database: Database) -> dict[str, Any]:
+ def get_extra_params(
+ database: Database, source: QuerySource | None = None
+ ) -> dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.
:param database: database instance from which to extract extras
+ :param source: in which context is the connection needed
:raises CertificateException: If certificate is not valid/unparseable
"""
- extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+ extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database, source)
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
+ user_agent = get_user_agent(database, source)
- connect_args.setdefault("source", USER_AGENT)
+ connect_args.setdefault("source", user_agent)
if database.server_cert:
connect_args["http_scheme"] = "https"
- connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert)
+ connect_args["verify"] = create_ssl_cert_file(database.server_cert)
return extra
diff --git a/superset/models/core.py b/superset/models/core.py
index dae9e8a0ec..dc8df7ca13 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -478,7 +478,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)
- extra = self.get_extra()
+ extra = self.get_extra(source)
params = extra.get("engine_params", {})
if nullpool:
params["poolclass"] = NullPool
@@ -955,8 +955,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
"""
return self.db_engine_spec.get_time_grains()
- def get_extra(self) -> dict[str, Any]:
- return self.db_engine_spec.get_extra_params(self)
+ def get_extra(self, source: utils.QuerySource | None = None) -> dict[str, Any]:
+ return self.db_engine_spec.get_extra_params(self, source)
def get_encrypted_extra(self) -> dict[str, Any]:
encrypted_extra = {}
diff --git a/superset/utils/core.py b/superset/utils/core.py
index e3e1a87dec..2aaf912229 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -74,6 +74,7 @@ from sqlalchemy.types import TypeEngine
from typing_extensions import TypeGuard
from superset.constants import (
+ DEFAULT_USER_AGENT,
EXTRA_FORM_DATA_APPEND_KEYS,
EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS,
EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS,
@@ -103,6 +104,7 @@ from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str
if TYPE_CHECKING:
from superset.connectors.sqla.models import BaseDatasource, TableColumn
+ from superset.models.core import Database
from superset.models.sql_lab import Query
logging.getLogger("MARKDOWN").setLevel(logging.INFO)
@@ -1795,3 +1797,10 @@ def to_int(v: Any, value_if_invalid: int = 0) -> int:
return int(v)
except (ValueError, TypeError):
return value_if_invalid
+
+
+def get_user_agent(database: Database, source: QuerySource | None) -> str:
+ if user_agent_func := current_app.config["USER_AGENT_FUNC"]:
+ return user_agent_func(database, source)
+
+ return DEFAULT_USER_AGENT
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index cbf3bdd5de..31e70494b7 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -37,7 +37,7 @@ from trino.sqlalchemy import datatype
from trino.sqlalchemy.dialect import TrinoDialect
import superset.config
-from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
from superset.db_engine_specs.exceptions import (
SupersetDBAPIConnectionError,
SupersetDBAPIDatabaseError,
@@ -81,7 +81,7 @@ def _assert_columns_equal(actual_cols, expected_cols) -> None:
@pytest.mark.parametrize(
"extra,expected",
[
- ({}, {"engine_params": {"connect_args": {"source": USER_AGENT}}}),
+ ({}, {"engine_params": {"connect_args": {"source": "Apache Superset"}}}),
(
{
"first": 1,
@@ -110,7 +110,7 @@ def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> No
assert TrinoEngineSpec.get_extra_params(database) == expected
-@patch("superset.utils.core.create_ssl_cert_file")
+@patch("superset.db_engine_specs.trino.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
@@ -118,6 +118,7 @@ def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> N
database.extra = json.dumps({})
database.server_cert = "TEST_CERT"
+ database.db_engine_spec = TrinoEngineSpec
mock_create_ssl_cert_file.return_value = "/path/to/tls.crt"
extra = TrinoEngineSpec.get_extra_params(database)
diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py
index cdeb58a50b..74b9373392 100644
--- a/tests/unit_tests/utils/test_core.py
+++ b/tests/unit_tests/utils/test_core.py
@@ -21,6 +21,7 @@ from unittest.mock import MagicMock, patch
import pandas as pd
import pytest
+from pytest_mock import MockerFixture
from superset.exceptions import SupersetException
from superset.utils.core import (
@@ -30,10 +31,12 @@ from superset.utils.core import (
generic_find_constraint_name,
generic_find_fk_constraint_name,
get_datasource_full_name,
+ get_user_agent,
is_test,
normalize_dttm_col,
parse_boolean_string,
QueryObjectFilterClause,
+ QuerySource,
remove_extra_adhoc_filters,
)
@@ -396,3 +399,22 @@ def test_get_datasource_full_name():
get_datasource_full_name("db", "table", "catalog", None)
== "[db].[catalog].[table]"
)
+
+
+def test_get_user_agent(mocker: MockerFixture) -> None:
+ database_mock = mocker.MagicMock()
+ database_mock.database_name = "mydb"
+
+ current_app_mock = mocker.patch("superset.utils.core.current_app")
+ current_app_mock.config = {"USER_AGENT_FUNC": None}
+
+ assert get_user_agent(database_mock, QuerySource.DASHBOARD) == "Apache Superset", (
+ "The default user agent should be returned"
+ )
+ current_app_mock.config["USER_AGENT_FUNC"] = (
+ lambda database, source: f"{database.database_name} {source.name}"
+ )
+
+ assert get_user_agent(database_mock, QuerySource.DASHBOARD) == "mydb DASHBOARD", (
+ "the custom user agent function result should have been returned"
+ )