This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new c90e45a373 feat: make user agent customizable (#32506)
c90e45a373 is described below

commit c90e45a373b9b4f05e6f5b4cc06c4f98f0f0287b
Author: Ville Brofeldt <[email protected]>
AuthorDate: Wed Mar 5 16:33:24 2025 -0800

    feat: make user agent customizable (#32506)
---
 superset/config.py                             |  6 +++++-
 superset/constants.py                          |  2 +-
 superset/db_engine_specs/base.py               | 13 ++++++-------
 superset/db_engine_specs/bigquery.py           |  3 ++-
 superset/db_engine_specs/databricks.py         | 14 +++++++++-----
 superset/db_engine_specs/druid.py              |  6 +++++-
 superset/db_engine_specs/duckdb.py             | 11 +++++++----
 superset/db_engine_specs/hive.py               |  2 --
 superset/db_engine_specs/parseable.py          |  5 ++++-
 superset/db_engine_specs/postgres.py           |  6 ++++--
 superset/db_engine_specs/presto.py             |  1 -
 superset/db_engine_specs/snowflake.py          | 12 +++++++++---
 superset/db_engine_specs/sqlite.py             |  1 -
 superset/db_engine_specs/trino.py              | 17 +++++++++++------
 superset/models/core.py                        |  6 +++---
 superset/utils/core.py                         |  9 +++++++++
 tests/unit_tests/db_engine_specs/test_trino.py |  7 ++++---
 tests/unit_tests/utils/test_core.py            | 22 ++++++++++++++++++++++
 18 files changed, 101 insertions(+), 42 deletions(-)

diff --git a/superset/config.py b/superset/config.py
index 500627063b..44e91c956c 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -58,7 +58,7 @@ from superset.stats_logger import DummyStatsLogger
 from superset.superset_typing import CacheConfig
 from superset.tasks.types import ExecutorType
 from superset.utils import core as utils
-from superset.utils.core import NO_TIME_RANGE, parse_boolean_string
+from superset.utils.core import NO_TIME_RANGE, parse_boolean_string, QuerySource
 from superset.utils.encrypt import SQLAlchemyUtilsAdapter
 from superset.utils.log import DBEventLogger
 from superset.utils.logging_configurator import DefaultLoggingConfigurator
@@ -595,6 +595,10 @@ DEFAULT_FEATURE_FLAGS.update(
     }
 )
 
+# This function can be overridden to customize the name of the user agent
+# triggering the query.
+USER_AGENT_FUNC: Callable[[Database, QuerySource | None], str] | None = None
+
 # This is merely a default.
 FEATURE_FLAGS: dict[str, bool] = {}
 
diff --git a/superset/constants.py b/superset/constants.py
index b55048463e..f60b79a961 100644
--- a/superset/constants.py
+++ b/superset/constants.py
@@ -22,7 +22,7 @@ from enum import Enum
 
 from superset.utils.backports import StrEnum
 
-USER_AGENT = "Apache Superset"
+DEFAULT_USER_AGENT = "Apache Superset"
 
 NULL_STRING = "<NULL>"
 EMPTY_STRING = "<empty string>"
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 248f8bd0a7..084891894b 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -73,7 +73,7 @@ from superset.superset_typing import (
     SQLAColumnType,
 )
 from superset.utils import core as utils, json
-from superset.utils.core import ColumnSpec, GenericDataType
+from superset.utils.core import ColumnSpec, GenericDataType, QuerySource
 from superset.utils.hashing import md5_sha_from_str
 from superset.utils.json import redact_sensitive, reveal_sensitive
 from superset.utils.network import is_hostname_valid, is_port_open
@@ -1023,7 +1023,6 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         For instance special column like `__time` for Druid can be
         set to is_dttm=True. Note that this only gets called when new
         columns are detected/created"""
-        # TODO: Fix circular import caused by importing TableColumn
 
     @classmethod
     def epoch_to_dttm(cls) -> str:
@@ -1128,7 +1127,6 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param database: Database instance
         :return: SQL query with limit clause
         """
-        # TODO: Fix circular import caused by importing Database
         if cls.limit_method == LimitMethod.WRAP_SQL:
             sql = sql.strip("\t\n ;")
             qry = (
@@ -1321,7 +1319,6 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         The flow works without this method doing anything, but it allows
         for handling the cursor and updating progress information in the
         query object"""
-        # TODO: Fix circular import error caused by importing sql_lab.Query
 
     @classmethod
     # pylint: disable=consider-using-transaction
@@ -1637,7 +1634,6 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :return: SqlAlchemy query with additional where clause referencing the latest
         partition
         """
-        # TODO: Fix circular import caused by importing Database, TableColumn
         return None
 
     @classmethod
@@ -1768,7 +1764,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param sql: SQL query with possibly multiple statements
         :param source: Source of the query (eg, "sql_lab")
         """
-        extra = database.get_extra() or {}
+        extra = database.get_extra(source) or {}
         if not cls.get_allow_cost_estimate(extra):
             raise Exception(  # pylint: disable=broad-exception-raised
                 "Database does not support cost estimation"
@@ -2019,12 +2015,15 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         return None
 
     @staticmethod
-    def get_extra_params(database: Database) -> dict[str, Any]:
+    def get_extra_params(
+        database: Database, source: QuerySource | None = None
+    ) -> dict[str, Any]:
         """
         Some databases require adding elements to connection parameters,
         like passing certificates to `extra`. This can be done here.
 
         :param database: database instance from which to extract extras
+        :param source: in which context is the connection needed
         :raises CertificateException: If certificate is not valid/unparseable
         """
         extra: dict[str, Any] = {}
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index ef0ce31b58..c6c57b1624 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -459,7 +459,7 @@ class BigQueryEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-met
         :param sql: SQL query with possibly multiple statements
         :param source: Source of the query (eg, "sql_lab")
         """
-        extra = database.get_extra() or {}
+        extra = database.get_extra(source) or {}
         if not cls.get_allow_cost_estimate(extra):
             raise SupersetException("Database does not support cost estimation")
 
@@ -469,6 +469,7 @@ class BigQueryEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-met
             database,
             catalog=catalog,
             schema=schema,
+            source=source,
         ) as engine:
             client = cls._get_client(engine, database)
             return [
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 1805682607..7574989abf 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -27,12 +27,13 @@ from marshmallow.validate import Range
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
 
-from superset.constants import TimeGrain, USER_AGENT
+from superset.constants import TimeGrain
 from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
 from superset.db_engine_specs.hive import HiveEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.utils import json
+from superset.utils.core import get_user_agent, QuerySource
 from superset.utils.network import is_hostname_valid, is_port_open
 
 if TYPE_CHECKING:
@@ -198,17 +199,20 @@ class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngine
     }
 
     @staticmethod
-    def get_extra_params(database: Database) -> dict[str, Any]:
+    def get_extra_params(
+        database: Database, source: QuerySource | None = None
+    ) -> dict[str, Any]:
         """
         Add a user agent to be used in the requests.
         Trim whitespace from connect_args to avoid databricks driver errors
         """
-        extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+        extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database, source)
         engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
         connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
 
-        connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)])
-        connect_args.setdefault("_user_agent_entry", USER_AGENT)
+        user_agent = get_user_agent(database, source)
+        connect_args.setdefault("http_headers", [("User-Agent", user_agent)])
+        connect_args.setdefault("_user_agent_entry", user_agent)
 
         # trim whitespace from http_path to avoid databricks errors on connecting
         if http_path := connect_args.get("http_path"):
diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py
index 9dc366aa5b..8692e77c1e 100644
--- a/superset/db_engine_specs/druid.py
+++ b/superset/db_engine_specs/druid.py
@@ -29,6 +29,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
 from superset.exceptions import SupersetException
 from superset.utils import core as utils, json
+from superset.utils.core import QuerySource
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import TableColumn
@@ -78,11 +79,14 @@ class DruidEngineSpec(BaseEngineSpec):
             orm_col.is_dttm = True
 
     @staticmethod
-    def get_extra_params(database: Database) -> dict[str, Any]:
+    def get_extra_params(
+        database: Database, source: QuerySource | None = None
+    ) -> dict[str, Any]:
         """
         For Druid, the path to a SSL certificate is placed in `connect_args`.
 
         :param database: database instance from which to extract extras
+        :param source: in which context is the connection needed
         :raises CertificateException: If certificate is not valid/unparseable
         :raises SupersetException: If database extra json payload is unparseable
         """
diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py
index 8fd5ab62b8..a5eba3642f 100644
--- a/superset/db_engine_specs/duckdb.py
+++ b/superset/db_engine_specs/duckdb.py
@@ -30,13 +30,13 @@ from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
 
 from superset.config import VERSION_STRING
-from superset.constants import TimeGrain, USER_AGENT
+from superset.constants import TimeGrain
 from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.utils.core import get_user_agent, QuerySource
 
 if TYPE_CHECKING:
-    # prevent circular imports
     from superset.models.core import Database
 
 
@@ -237,7 +237,9 @@ class DuckDBEngineSpec(DuckDBParametersMixin, BaseEngineSpec):
         return set(inspector.get_table_names(schema))
 
     @staticmethod
-    def get_extra_params(database: Database) -> dict[str, Any]:
+    def get_extra_params(
+        database: Database, source: QuerySource | None = None
+    ) -> dict[str, Any]:
         """
         Add a user agent to be used in the requests.
         """
@@ -247,7 +249,8 @@ class DuckDBEngineSpec(DuckDBParametersMixin, BaseEngineSpec):
         config: dict[str, Any] = connect_args.setdefault("config", {})
         custom_user_agent = config.pop("custom_user_agent", "")
         delim = " " if custom_user_agent else ""
-        user_agent = USER_AGENT.replace(" ", "-").lower()
+        user_agent = get_user_agent(database, source)
+        user_agent = user_agent.replace(" ", "-").lower()
         user_agent = f"{user_agent}/{VERSION_STRING}{delim}{custom_user_agent}"
         config.setdefault("custom_user_agent", user_agent)
 
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 9b7afab1cc..0a9b817804 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -49,8 +49,6 @@ from superset.sql_parse import Table
 from superset.superset_typing import ResultSetColumnType
 
 if TYPE_CHECKING:
-    # prevent circular imports
-
     from superset.models.core import Database
 
 logger = logging.getLogger(__name__)
diff --git a/superset/db_engine_specs/parseable.py b/superset/db_engine_specs/parseable.py
index dbca211627..9be780c949 100644
--- a/superset/db_engine_specs/parseable.py
+++ b/superset/db_engine_specs/parseable.py
@@ -23,6 +23,7 @@ from sqlalchemy import types
 
 from superset.constants import TimeGrain
 from superset.db_engine_specs.base import BaseEngineSpec
+from superset.utils.core import QuerySource
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import TableColumn
@@ -73,7 +74,9 @@ class ParseableEngineSpec(BaseEngineSpec):
             orm_col.is_dttm = True
 
     @classmethod
-    def get_extra_params(cls, database: Database) -> dict[str, Any]:
+    def get_extra_params(
+        cls, database: Database, source: QuerySource | None = None
+    ) -> dict[str, Any]:
         """Additional parameters for Parseable connections."""
         return {
             "engine_params": {
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index a2e6a3fe12..16117062b0 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -37,7 +37,7 @@ from superset.exceptions import SupersetException, SupersetSecurityException
 from superset.models.sql_lab import Query
 from superset.sql.parse import SQLScript
 from superset.utils import core as utils, json
-from superset.utils.core import GenericDataType
+from superset.utils.core import GenericDataType, QuerySource
 
 if TYPE_CHECKING:
     from superset.models.core import Database  # pragma: no cover
@@ -411,7 +411,9 @@ WHERE datistemplate = false;
         )
 
     @staticmethod
-    def get_extra_params(database: Database) -> dict[str, Any]:
+    def get_extra_params(
+        database: Database, source: QuerySource | None = None
+    ) -> dict[str, Any]:
         """
         For Postgres, the path to a SSL certificate is placed in `connect_args`.
 
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index e351c5b307..c3c8f61829 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -63,7 +63,6 @@ from superset.utils import core as utils, json
 from superset.utils.core import GenericDataType
 
 if TYPE_CHECKING:
-    # prevent circular imports
     from superset.models.core import Database
     from superset.sql_parse import Table
 
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index a99f1e641d..5c9c42ea76 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -14,6 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations
+
 import logging
 import re
 from datetime import datetime
@@ -32,13 +34,14 @@ from sqlalchemy import types
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
 
-from superset.constants import TimeGrain, USER_AGENT
+from superset.constants import TimeGrain
 from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
 from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.models.sql_lab import Query
 from superset.utils import json
+from superset.utils.core import get_user_agent, QuerySource
 
 if TYPE_CHECKING:
     from superset.models.core import Database
@@ -130,15 +133,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     }
 
     @staticmethod
-    def get_extra_params(database: "Database") -> dict[str, Any]:
+    def get_extra_params(
+        database: Database, source: QuerySource | None = None
+    ) -> dict[str, Any]:
         """
         Add a user agent to be used in the requests.
         """
         extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
         engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
         connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
+        user_agent = get_user_agent(database, source)
 
-        connect_args.setdefault("application", USER_AGENT)
+        connect_args.setdefault("application", user_agent)
 
         return extra
 
diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py
index 50451e9bf5..a7e9a96e33 100644
--- a/superset/db_engine_specs/sqlite.py
+++ b/superset/db_engine_specs/sqlite.py
@@ -31,7 +31,6 @@ from superset.db_engine_specs.base import BaseEngineSpec
 from superset.errors import SupersetErrorType
 
 if TYPE_CHECKING:
-    # prevent circular imports
     from superset.models.core import Database
 
 
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index beb7c0a604..79fdef19bf 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -29,7 +29,7 @@ from sqlalchemy.engine.url import URL
 from sqlalchemy.exc import NoSuchTableError
 
 from superset import db
-from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
 from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
 from superset.db_engine_specs.exceptions import (
@@ -42,7 +42,8 @@ from superset.db_engine_specs.presto import PrestoBaseEngineSpec
 from superset.models.sql_lab import Query
 from superset.sql_parse import Table
 from superset.superset_typing import ResultSetColumnType
-from superset.utils import core as utils, json
+from superset.utils import json
+from superset.utils.core import create_ssl_cert_file, get_user_agent, QuerySource
 
 if TYPE_CHECKING:
     from superset.models.core import Database
@@ -330,23 +331,27 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         return True
 
     @staticmethod
-    def get_extra_params(database: Database) -> dict[str, Any]:
+    def get_extra_params(
+        database: Database, source: QuerySource | None = None
+    ) -> dict[str, Any]:
         """
         Some databases require adding elements to connection parameters,
         like passing certificates to `extra`. This can be done here.
 
         :param database: database instance from which to extract extras
+        :param source: in which context is the connection needed
         :raises CertificateException: If certificate is not valid/unparseable
         """
-        extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+        extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database, source)
         engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
         connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
+        user_agent = get_user_agent(database, source)
 
-        connect_args.setdefault("source", USER_AGENT)
+        connect_args.setdefault("source", user_agent)
 
         if database.server_cert:
             connect_args["http_scheme"] = "https"
-            connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert)
+            connect_args["verify"] = create_ssl_cert_file(database.server_cert)
 
         return extra
 
diff --git a/superset/models/core.py b/superset/models/core.py
index dae9e8a0ec..dc8df7ca13 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -478,7 +478,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         )
         self.db_engine_spec.validate_database_uri(sqlalchemy_url)
 
-        extra = self.get_extra()
+        extra = self.get_extra(source)
         params = extra.get("engine_params", {})
         if nullpool:
             params["poolclass"] = NullPool
@@ -955,8 +955,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         """
         return self.db_engine_spec.get_time_grains()
 
-    def get_extra(self) -> dict[str, Any]:
-        return self.db_engine_spec.get_extra_params(self)
+    def get_extra(self, source: utils.QuerySource | None = None) -> dict[str, Any]:
+        return self.db_engine_spec.get_extra_params(self, source)
 
     def get_encrypted_extra(self) -> dict[str, Any]:
         encrypted_extra = {}
diff --git a/superset/utils/core.py b/superset/utils/core.py
index e3e1a87dec..2aaf912229 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -74,6 +74,7 @@ from sqlalchemy.types import TypeEngine
 from typing_extensions import TypeGuard
 
 from superset.constants import (
+    DEFAULT_USER_AGENT,
     EXTRA_FORM_DATA_APPEND_KEYS,
     EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS,
     EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS,
@@ -103,6 +104,7 @@ from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import BaseDatasource, TableColumn
+    from superset.models.core import Database
     from superset.models.sql_lab import Query
 
 logging.getLogger("MARKDOWN").setLevel(logging.INFO)
@@ -1795,3 +1797,10 @@ def to_int(v: Any, value_if_invalid: int = 0) -> int:
         return int(v)
     except (ValueError, TypeError):
         return value_if_invalid
+
+
+def get_user_agent(database: Database, source: QuerySource | None) -> str:
+    if user_agent_func := current_app.config["USER_AGENT_FUNC"]:
+        return user_agent_func(database, source)
+
+    return DEFAULT_USER_AGENT
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index cbf3bdd5de..31e70494b7 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -37,7 +37,7 @@ from trino.sqlalchemy import datatype
 from trino.sqlalchemy.dialect import TrinoDialect
 
 import superset.config
-from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
+from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
 from superset.db_engine_specs.exceptions import (
     SupersetDBAPIConnectionError,
     SupersetDBAPIDatabaseError,
@@ -81,7 +81,7 @@ def _assert_columns_equal(actual_cols, expected_cols) -> None:
 @pytest.mark.parametrize(
     "extra,expected",
     [
-        ({}, {"engine_params": {"connect_args": {"source": USER_AGENT}}}),
+        ({}, {"engine_params": {"connect_args": {"source": "Apache Superset"}}}),
         (
             {
                 "first": 1,
@@ -110,7 +110,7 @@ def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> No
     assert TrinoEngineSpec.get_extra_params(database) == expected
 
 
-@patch("superset.utils.core.create_ssl_cert_file")
+@patch("superset.db_engine_specs.trino.create_ssl_cert_file")
 def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None:
     from superset.db_engine_specs.trino import TrinoEngineSpec
 
@@ -118,6 +118,7 @@ def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> N
 
     database.extra = json.dumps({})
     database.server_cert = "TEST_CERT"
+    database.db_engine_spec = TrinoEngineSpec
     mock_create_ssl_cert_file.return_value = "/path/to/tls.crt"
     extra = TrinoEngineSpec.get_extra_params(database)
 
diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py
index cdeb58a50b..74b9373392 100644
--- a/tests/unit_tests/utils/test_core.py
+++ b/tests/unit_tests/utils/test_core.py
@@ -21,6 +21,7 @@ from unittest.mock import MagicMock, patch
 
 import pandas as pd
 import pytest
+from pytest_mock import MockerFixture
 
 from superset.exceptions import SupersetException
 from superset.utils.core import (
@@ -30,10 +31,12 @@ from superset.utils.core import (
     generic_find_constraint_name,
     generic_find_fk_constraint_name,
     get_datasource_full_name,
+    get_user_agent,
     is_test,
     normalize_dttm_col,
     parse_boolean_string,
     QueryObjectFilterClause,
+    QuerySource,
     remove_extra_adhoc_filters,
 )
 
@@ -396,3 +399,22 @@ def test_get_datasource_full_name():
         get_datasource_full_name("db", "table", "catalog", None)
         == "[db].[catalog].[table]"
     )
+
+
+def test_get_user_agent(mocker: MockerFixture) -> None:
+    database_mock = mocker.MagicMock()
+    database_mock.database_name = "mydb"
+
+    current_app_mock = mocker.patch("superset.utils.core.current_app")
+    current_app_mock.config = {"USER_AGENT_FUNC": None}
+
+    assert get_user_agent(database_mock, QuerySource.DASHBOARD) == "Apache Superset", (
+        "The default user agent should be returned"
+    )
+    current_app_mock.config["USER_AGENT_FUNC"] = (
+        lambda database, source: f"{database.database_name} {source.name}"
+    )
+
+    assert get_user_agent(database_mock, QuerySource.DASHBOARD) == "mydb DASHBOARD", (
+        "the custom user agent function result should have been returned"
+    )

Reply via email to