This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 307ebeaa19 chore(Databricks): New Databricks driver (#28393)
307ebeaa19 is described below

commit 307ebeaa19941fb31e5f1296d6c7cabca85f8f0d
Author: Vitor Avila <[email protected]>
AuthorDate: Thu May 9 15:58:03 2024 -0300

    chore(Databricks): New Databricks driver (#28393)
---
 .../DatabaseConnectionForm/CommonParameters.tsx    |  63 +++++
 .../DatabaseModal/DatabaseConnectionForm/index.tsx |   9 +
 .../src/features/databases/DatabaseModal/index.tsx |  16 +-
 superset-frontend/src/features/databases/types.ts  |  15 +
 superset/db_engine_specs/databricks.py             | 315 +++++++++++++++------
 .../unit_tests/db_engine_specs/test_databricks.py  |   8 +-
 6 files changed, 330 insertions(+), 96 deletions(-)

diff --git 
a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
 
b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
index 529fc18419..4d864ba11d 100644
--- 
a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
+++ 
b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
@@ -116,6 +116,69 @@ export const databaseField = ({
     helpText={t('Copy the name of the database you are trying to connect to.')}
   />
 );
+export const defaultCatalogField = ({
+  required,
+  changeMethods,
+  getValidation,
+  validationErrors,
+  db,
+}: FieldPropTypes) => (
+  <ValidatedInput
+    id="default_catalog"
+    name="default_catalog"
+    required={required}
+    value={db?.parameters?.default_catalog}
+    validationMethods={{ onBlur: getValidation }}
+    errorMessage={validationErrors?.default_catalog}
+    placeholder={t('e.g. hive_metastore')}
+    label={t('Default Catalog')}
+    onChange={changeMethods.onParametersChange}
+    helpText={t('The default catalog that should be used for the connection.')}
+  />
+);
+export const defaultSchemaField = ({
+  required,
+  changeMethods,
+  getValidation,
+  validationErrors,
+  db,
+}: FieldPropTypes) => (
+  <ValidatedInput
+    id="default_schema"
+    name="default_schema"
+    required={required}
+    value={db?.parameters?.default_schema}
+    validationMethods={{ onBlur: getValidation }}
+    errorMessage={validationErrors?.default_schema}
+    placeholder={t('e.g. default')}
+    label={t('Default Schema')}
+    onChange={changeMethods.onParametersChange}
+    helpText={t('The default schema that should be used for the connection.')}
+  />
+);
+export const httpPathField = ({
+  required,
+  changeMethods,
+  getValidation,
+  validationErrors,
+  db,
+}: FieldPropTypes) => {
+  // Value and validation errors are both keyed by this field's own name.
+  return (
+    <ValidatedInput
+      id="http_path_field"
+      name="http_path_field"
+      required={required}
+      value={db?.parameters?.http_path_field}
+      validationMethods={{ onBlur: getValidation }}
+      errorMessage={validationErrors?.http_path_field}
+      placeholder={t('e.g. sql/protocolv1/o/12345')}
+      label={t('HTTP Path')}
+      onChange={changeMethods.onParametersChange}
+      helpText={t('Copy the name of the HTTP Path of your cluster.')}
+    />
+  );
+};
 export const usernameField = ({
   required,
   changeMethods,
diff --git 
a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
 
b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
index 509103ea29..aff755b955 100644
--- 
a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
+++ 
b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
@@ -27,10 +27,13 @@ import { Form } from 'src/components/Form';
 import {
   accessTokenField,
   databaseField,
+  defaultCatalogField,
+  defaultSchemaField,
   displayField,
   forceSSLField,
   hostField,
   httpPath,
+  httpPathField,
   passwordField,
   portField,
   queryField,
@@ -47,10 +50,13 @@ export const FormFieldOrder = [
   'host',
   'port',
   'database',
+  'default_catalog',
+  'default_schema',
   'username',
   'password',
   'access_token',
   'http_path',
+  'http_path_field',
   'database_name',
   'credentials_info',
   'service_account_info',
@@ -71,8 +77,11 @@ const SSHTunnelSwitchComponent =
 const FORM_FIELD_MAP = {
   host: hostField,
   http_path: httpPath,
+  http_path_field: httpPathField,
   port: portField,
   database: databaseField,
+  default_catalog: defaultCatalogField,
+  default_schema: defaultSchemaField,
   username: usernameField,
   password: passwordField,
   access_token: accessTokenField,
diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx 
b/superset-frontend/src/features/databases/DatabaseModal/index.tsx
index 47c9a8b658..4e1e58ebb8 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx
@@ -633,11 +633,23 @@ const DatabaseModal: 
FunctionComponent<DatabaseModalProps> = ({
   const history = useHistory();
 
   const dbModel: DatabaseForm =
+    // TODO: we need a centralized engine in one place
+
+    // first try to match both engine and driver
+    availableDbs?.databases?.find(
+      (available: {
+        engine: string | undefined;
+        default_driver: string | undefined;
+      }) =>
+        available.engine === (isEditMode ? db?.backend : db?.engine) &&
+        available.default_driver === db?.driver,
+    ) ||
+    // alternatively try to match only engine
     availableDbs?.databases?.find(
       (available: { engine: string | undefined }) =>
-        // TODO: we need a centralized engine in one place
         available.engine === (isEditMode ? db?.backend : db?.engine),
-    ) || {};
+    ) ||
+    {};
 
   // Test Connection logic
   const testConnection = () => {
diff --git a/superset-frontend/src/features/databases/types.ts 
b/superset-frontend/src/features/databases/types.ts
index a09ad174a0..c46296a2a4 100644
--- a/superset-frontend/src/features/databases/types.ts
+++ b/superset-frontend/src/features/databases/types.ts
@@ -63,6 +63,9 @@ export type DatabaseObject = {
     host?: string;
     port?: number;
     database?: string;
+    default_catalog?: string;
+    default_schema?: string;
+    http_path_field?: string;
     username?: string;
     password?: string;
     encryption?: boolean;
@@ -126,6 +129,18 @@ export type DatabaseForm = {
         description: string;
         type: string;
       };
+      default_catalog: {
+        description: string;
+        type: string;
+      };
+      default_schema: {
+        description: string;
+        type: string;
+      };
+      http_path_field: {
+        description: string;
+        type: string;
+      };
       host: {
         description: string;
         type: string;
diff --git a/superset/db_engine_specs/databricks.py 
b/superset/db_engine_specs/databricks.py
index 4b2f93ca5d..6fc753c00e 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import json
 from datetime import datetime
-from typing import Any, TYPE_CHECKING, TypedDict
+from typing import Any, TYPE_CHECKING, TypedDict, Union
 
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
@@ -40,10 +40,10 @@ if TYPE_CHECKING:
 
 
 #
-class DatabricksParametersSchema(Schema):
+class DatabricksBaseSchema(Schema):
     """
-    This is the list of fields that are expected
-    from the client in order to build the sqlalchemy string
+    Fields that are required for both Databricks drivers that uses a
+    dynamic form.
     """
 
     access_token = fields.Str(required=True)
@@ -53,44 +53,85 @@ class DatabricksParametersSchema(Schema):
         metadata={"description": __("Database port")},
         validate=Range(min=0, max=2**16, max_inclusive=False),
     )
-    database = fields.Str(required=True)
     encryption = fields.Boolean(
         required=False,
         metadata={"description": __("Use an encrypted connection to the 
database")},
     )
 
 
-class DatabricksPropertiesSchema(DatabricksParametersSchema):
+class DatabricksBaseParametersType(TypedDict):
+    """
+    The parameters are all the keys that do not exist on the Database model.
+    These are used to build the sqlalchemy uri.
+    """
+
+    access_token: str
+    host: str
+    port: int
+    encryption: bool
+
+
+class DatabricksNativeSchema(DatabricksBaseSchema):
     """
-    This is the list of fields expected
-    for successful database creation execution
+    Additional fields required only for the DatabricksNativeEngineSpec.
+    """
+
+    database = fields.Str(required=True)
+
+
+class DatabricksNativePropertiesSchema(DatabricksNativeSchema):
+    """
+    Properties required only for the DatabricksNativeEngineSpec.
     """
 
     http_path = fields.Str(required=True)
 
 
-class DatabricksParametersType(TypedDict):
+class DatabricksNativeParametersType(DatabricksBaseParametersType):
     """
-    The parameters are all the keys that do
-    not exist on the Database model.
-    These are used to build the sqlalchemy uri
+    Additional parameters required only for the DatabricksNativeEngineSpec.
     """
 
-    access_token: str
-    host: str
-    port: int
     database: str
-    encryption: bool
 
 
-class DatabricksPropertiesType(TypedDict):
+class DatabricksNativePropertiesType(TypedDict):
+    """
+    All properties that need to be available to the DatabricksNativeEngineSpec
+    in order to create a connection if the dynamic form is used.
+    """
+
+    parameters: DatabricksNativeParametersType
+    extra: str
+
+
+class DatabricksPythonConnectorSchema(DatabricksBaseSchema):
+    """
+    Additional fields required only for the 
DatabricksPythonConnectorEngineSpec.
+    """
+
+    http_path_field = fields.Str(required=True)
+    default_catalog = fields.Str(required=True)
+    default_schema = fields.Str(required=True)
+
+
+class DatabricksPythonConnectorParametersType(DatabricksBaseParametersType):
     """
-    All properties that need to be available to
-    this engine in order to create a connection
-    if the dynamic form is used
+    Additional parameters required only for the 
DatabricksPythonConnectorEngineSpec.
     """
 
-    parameters: DatabricksParametersType
+    http_path_field: str
+    default_catalog: str
+    default_schema: str
+
+
+class DatabricksPythonConnectorPropertiesType(TypedDict):
+    """
+    All properties that need to be available to the 
DatabricksPythonConnectorEngineSpec
+    in order to create a connection if the dynamic form is used.
+    """
+
+    parameters: DatabricksPythonConnectorParametersType
     extra: str
 
 
@@ -125,13 +166,7 @@ class DatabricksHiveEngineSpec(HiveEngineSpec):
     _time_grain_expressions = time_grain_expressions
 
 
-class DatabricksODBCEngineSpec(BaseEngineSpec):
-    engine_name = "Databricks SQL Endpoint"
-
-    engine = "databricks"
-    drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
-    default_driver = "pyodbc"
-
+class DatabricksBaseEngineSpec(BaseEngineSpec):
     _time_grain_expressions = time_grain_expressions
 
     @classmethod
@@ -145,20 +180,23 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
         return HiveEngineSpec.epoch_to_dttm()
 
 
-class DatabricksNativeEngineSpec(BasicParametersMixin, 
DatabricksODBCEngineSpec):
-    engine_name = "Databricks"
+class DatabricksODBCEngineSpec(DatabricksBaseEngineSpec):
+    engine_name = "Databricks SQL Endpoint"
 
     engine = "databricks"
-    drivers = {"connector": "Native all-purpose driver"}
-    default_driver = "connector"
+    drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
+    default_driver = "pyodbc"
 
-    parameters_schema = DatabricksParametersSchema()
-    properties_schema = DatabricksPropertiesSchema()
 
-    sqlalchemy_uri_placeholder = (
-        
"databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
-    )
+class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, 
DatabricksBaseEngineSpec):
+    default_driver = ""
     encryption_parameters = {"ssl": "1"}
+    required_parameters = {"access_token", "host", "port"}
+    context_key_mapping = {
+        "access_token": "password",
+        "host": "hostname",
+        "port": "port",
+    }
 
     @staticmethod
     def get_extra_params(database: Database) -> dict[str, Any]:
@@ -190,30 +228,6 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, 
DatabricksODBCEngineSpec)
             database, inspector, schema
         ) - cls.get_view_names(database, inspector, schema)
 
-    @classmethod
-    def build_sqlalchemy_uri(  # type: ignore
-        cls, parameters: DatabricksParametersType, *_
-    ) -> str:
-        query = {}
-        if parameters.get("encryption"):
-            if not cls.encryption_parameters:
-                raise Exception(  # pylint: disable=broad-exception-raised
-                    "Unable to build a URL with encryption enabled"
-                )
-            query.update(cls.encryption_parameters)
-
-        return str(
-            URL.create(
-                f"{cls.engine}+{cls.default_driver}".rstrip("+"),
-                username="token",
-                password=parameters.get("access_token"),
-                host=parameters["host"],
-                port=parameters["port"],
-                database=parameters["database"],
-                query=query,
-            )
-        )
-
     @classmethod
     def extract_errors(
         cls, ex: Exception, context: dict[str, Any] | None = None
@@ -224,13 +238,10 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, 
DatabricksODBCEngineSpec)
         # access_token isn't currently parseable from the
         # databricks error response, but adding it in here
         # for reference if their error message changes
-        context = {
-            "host": context.get("hostname"),
-            "access_token": context.get("password"),
-            "port": context.get("port"),
-            "username": context.get("username"),
-            "database": context.get("database"),
-        }
+
+        for key, value in cls.context_key_mapping.items():
+            context[key] = context.get(value)
+
         for regex, (message, error_type, extra) in cls.custom_errors.items():
             match = regex.search(raw_message)
             if match:
@@ -254,32 +265,18 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, 
DatabricksODBCEngineSpec)
             )
         ]
 
-    @classmethod
-    def get_parameters_from_uri(  # type: ignore
-        cls, uri: str, *_, **__
-    ) -> DatabricksParametersType:
-        url = make_url_safe(uri)
-        encryption = all(
-            item in url.query.items() for item in 
cls.encryption_parameters.items()
-        )
-        return {
-            "access_token": url.password,
-            "host": url.host,
-            "port": url.port,
-            "database": url.database,
-            "encryption": encryption,
-        }
-
     @classmethod
     def validate_parameters(  # type: ignore
         cls,
-        properties: DatabricksPropertiesType,
+        properties: Union[
+            DatabricksNativePropertiesType,
+            DatabricksPythonConnectorPropertiesType,
+        ],
     ) -> list[SupersetError]:
         errors: list[SupersetError] = []
-        required = {"access_token", "host", "port", "database", "extra"}
-        extra = json.loads(properties.get("extra", "{}"))
-        engine_params = extra.get("engine_params", {})
-        connect_args = engine_params.get("connect_args", {})
+        if extra := json.loads(properties.get("extra")):  # type: ignore
+            engine_params = extra.get("engine_params", {})
+            connect_args = engine_params.get("connect_args", {})
         parameters = {
             **properties,
             **properties.get("parameters", {}),
@@ -289,7 +286,7 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, 
DatabricksODBCEngineSpec)
 
         present = {key for key in parameters if parameters.get(key, ())}
 
-        if missing := sorted(required - present):
+        if missing := sorted(cls.required_parameters - present):
             errors.append(
                 SupersetError(
                     message=f'One or more parameters are missing: {", 
".join(missing)}',
@@ -351,6 +348,69 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, 
DatabricksODBCEngineSpec)
             )
         return errors
 
+
+class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
+    engine = "databricks"
+    engine_name = "Databricks"
+    drivers = {"connector": "Native all-purpose driver"}
+    default_driver = "connector"
+
+    parameters_schema = DatabricksNativeSchema()
+    properties_schema = DatabricksNativePropertiesSchema()
+
+    sqlalchemy_uri_placeholder = (
+        
"databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
+    )
+    context_key_mapping = {
+        **DatabricksDynamicBaseEngineSpec.context_key_mapping,
+        "database": "database",
+        "username": "username",
+    }
+    required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters 
| {
+        "database",
+        "extra",
+    }
+
+    @classmethod
+    def build_sqlalchemy_uri(  # type: ignore
+        cls, parameters: DatabricksNativeParametersType, *_
+    ) -> str:
+        query = {}
+        if parameters.get("encryption"):
+            if not cls.encryption_parameters:
+                raise Exception(  # pylint: disable=broad-exception-raised
+                    "Unable to build a URL with encryption enabled"
+                )
+            query.update(cls.encryption_parameters)
+
+        return str(
+            URL.create(
+                f"{cls.engine}+{cls.default_driver}".rstrip("+"),
+                username="token",
+                password=parameters.get("access_token"),
+                host=parameters["host"],
+                port=parameters["port"],
+                database=parameters["database"],
+                query=query,
+            )
+        )
+
+    @classmethod
+    def get_parameters_from_uri(  # type: ignore
+        cls, uri: str, *_, **__
+    ) -> DatabricksNativeParametersType:
+        url = make_url_safe(uri)
+        encryption = all(
+            item in url.query.items() for item in 
cls.encryption_parameters.items()
+        )
+        return {
+            "access_token": url.password,
+            "host": url.host,
+            "port": url.port,
+            "database": url.database,
+            "encryption": encryption,
+        }
+
     @classmethod
     def parameters_json_schema(cls) -> Any:
         """
@@ -367,3 +427,78 @@ class DatabricksNativeEngineSpec(BasicParametersMixin, 
DatabricksODBCEngineSpec)
         )
         spec.components.schema(cls.__name__, schema=cls.properties_schema)
         return spec.to_dict()["components"]["schemas"][cls.__name__]
+
+
+class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
+    engine = "databricks"
+    engine_name = "Databricks Python Connector"
+    default_driver = "databricks-sql-python"
+    drivers = {"databricks-sql-python": "Databricks SQL Python"}
+
+    parameters_schema = DatabricksPythonConnectorSchema()
+
+    sqlalchemy_uri_placeholder = (
+        "databricks://token:{access_token}@{host}:{port}?http_path={http_path}"
+        "&catalog={default_catalog}&schema={default_schema}"
+    )
+
+    context_key_mapping = {
+        **DatabricksDynamicBaseEngineSpec.context_key_mapping,
+        "default_catalog": "catalog",
+        "default_schema": "schema",
+        "http_path_field": "http_path",
+    }
+
+    required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters 
| {
+        "default_catalog",
+        "default_schema",
+        "http_path_field",
+    }
+
+    @classmethod
+    def build_sqlalchemy_uri(  # type: ignore
+        cls, parameters: DatabricksPythonConnectorParametersType, *_
+    ) -> str:
+        query = {}
+        if http_path := parameters.get("http_path_field"):
+            query["http_path"] = http_path
+        if catalog := parameters.get("default_catalog"):
+            query["catalog"] = catalog
+        if schema := parameters.get("default_schema"):
+            query["schema"] = schema
+        if parameters.get("encryption"):
+            query.update(cls.encryption_parameters)
+
+        return str(
+            URL.create(
+                cls.engine,
+                username="token",
+                password=parameters.get("access_token"),
+                host=parameters["host"],
+                port=parameters["port"],
+                query=query,
+            )
+        )
+
+    @classmethod
+    def get_parameters_from_uri(  # type: ignore
+        cls, uri: str, *_: Any, **__: Any
+    ) -> DatabricksPythonConnectorParametersType:
+        url = make_url_safe(uri)
+        query = {
+            key: value
+            for (key, value) in url.query.items()
+            if (key, value) not in cls.encryption_parameters.items()
+        }
+        encryption = all(
+            item in url.query.items() for item in 
cls.encryption_parameters.items()
+        )
+        return {
+            "access_token": url.password,
+            "host": url.host,
+            "port": url.port,
+            "http_path_field": query["http_path"],
+            "default_catalog": query["catalog"],
+            "default_schema": query["schema"],
+            "encryption": encryption,
+        }
diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py 
b/tests/unit_tests/db_engine_specs/test_databricks.py
index de06f919be..8709833d3f 100644
--- a/tests/unit_tests/db_engine_specs/test_databricks.py
+++ b/tests/unit_tests/db_engine_specs/test_databricks.py
@@ -35,13 +35,13 @@ def test_get_parameters_from_uri() -> None:
     """
     from superset.db_engine_specs.databricks import (
         DatabricksNativeEngineSpec,
-        DatabricksParametersType,
+        DatabricksNativeParametersType,
     )
 
     parameters = DatabricksNativeEngineSpec.get_parameters_from_uri(
         "databricks+connector://token:abc12345@my_hostname:1234/test"
     )
-    assert parameters == DatabricksParametersType(
+    assert parameters == DatabricksNativeParametersType(
         {
             "access_token": "abc12345",
             "host": "my_hostname",
@@ -60,10 +60,10 @@ def test_build_sqlalchemy_uri() -> None:
     """
     from superset.db_engine_specs.databricks import (
         DatabricksNativeEngineSpec,
-        DatabricksParametersType,
+        DatabricksNativeParametersType,
     )
 
-    parameters = DatabricksParametersType(
+    parameters = DatabricksNativeParametersType(
         {
             "access_token": "abc12345",
             "host": "my_hostname",

Reply via email to