This is an automated email from the ASF dual-hosted git repository. vavila pushed a commit to branch feat/list-only-allowed-schemas-for-upload in repository https://gitbox.apache.org/repos/asf/superset.git
commit e2559552b185bc056e10afda7357e841f86dfe8b Author: Vitor Avila <[email protected]> AuthorDate: Mon Mar 17 13:15:59 2025 -0300 feat(file uploads): List only allowed schemas in the file uploads dialog --- superset/databases/api.py | 24 +++++++++- superset/databases/schemas.py | 1 + tests/integration_tests/databases/api_tests.py | 64 ++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 2 deletions(-) diff --git a/superset/databases/api.py b/superset/databases/api.py index 5c3e024e73..a573823fc5 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -776,18 +776,38 @@ class DatabaseRestApi(BaseSupersetModelRestApi): if not database: return self.response_404() try: - catalog = kwargs["rison"].get("catalog") + params = kwargs["rison"] + catalog = params.get("catalog") schemas = database.get_all_schema_names( catalog=catalog, cache=database.schema_cache_enabled, cache_timeout=database.schema_cache_timeout or None, - force=kwargs["rison"].get("force", False), + force=params.get("force", False), ) schemas = security_manager.get_schemas_accessible_by_user( database, catalog, schemas, ) + if params.get("upload_allowed"): + if not database.allow_file_upload: + return self.response( + 400, + message="File upload is disabled on this database connection", + ) + if allowed_schemas := database.get_schema_access_for_file_upload(): + # some databases might return the list of schemas in uppercase, + # while the list of allowed schemas is manually inputted so + # could be lowercase + allowed_schemas = {schema.lower() for schema in allowed_schemas} + return self.response( + 200, + result=[ + schema + for schema in schemas + if schema.lower() in allowed_schemas + ], + ) return self.response(200, result=list(schemas)) except OperationalError: return self.response( diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index c9df9fcbad..765f9dde83 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -64,6 +64,7 @@ database_schemas_query_schema = { "type": "object", "properties": { "force": {"type": "boolean"}, + "upload_allowed": {"type": "boolean"}, "catalog": {"type": "string"}, }, } diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 76e76dca18..d9081d72b1 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -2083,6 +2083,70 @@ class TestDatabaseApi(SupersetTestCase): ) assert rv.status_code == 400 + @pytest.mark.parametrize( + "all_schemas,schemas_allowed_for_csv,result", + [ + ( + ["schema_1", "schema_2", "schema_3"], + [], + ["schema_1", "schema_2", "schema_3"], + ), + (["schema_1", "schema_2", "schema_3"], ["schema_2"], ["schema_2"]), + ], + ) + def test_database_schemas_upload_allowed_filter( + self, + all_schemas: list[str], + schemas_allowed_for_csv: list[str], + result: list[str], + ): + """ + Database API: Test database schemas when filtering for upload allowed + """ + with self.create_app().app_context(): + example_db = get_example_database() + + extra = { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_file_upload": schemas_allowed_for_csv, + } + self.login(ADMIN_USERNAME) + database = self.insert_database( + "database_with_upload", + example_db.sqlalchemy_uri_decrypted, + extra=json.dumps(extra), + allow_file_upload=True, + ) + db.session.commit() + yield database + + mock.patch.object( + database, "get_all_schema_names", return_value=all_schemas + ) + arguments = {"upload_allowed": True} + uri = f"api/v1/database/{database.id}/schemas/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + data = json.loads(rv.data.decode("utf-8")) + assert data["result"] == result + db.session.delete(database) + db.session.commit() + + def test_database_schemas_upload_allowed_filter_disabled(self): + """ + Database API: Test database schemas when filtering for upload allowed + for a DB connection that has file uploads disabled + """ + database = db.session.query(Database).filter_by(database_name="examples").one() + self.login(ADMIN_USERNAME) + arguments = {"upload_allowed": True} + uri = f"api/v1/database/{database.id}/schemas/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + assert rv.status_code == 400 + data = json.loads(rv.data.decode("utf-8")) + assert data["message"] == "File upload is disabled on this database connection" + def test_database_tables(self): """ Database API: Test database tables
