github-advanced-security[bot] commented on code in PR #34599: URL: https://github.com/apache/superset/pull/34599#discussion_r2372522941
########## superset/llms/api.py: ########## @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import logging +from typing import Any, Dict + +from flask import request, Response +from flask_appbuilder.api import expose, protect, rison, safe +from flask_appbuilder.models.sqla.interface import SQLAInterface +from marshmallow import fields, Schema, post_load, ValidationError + +from superset import db +from superset.models.core import CustomLlmProvider +from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics +from flask_appbuilder.api.schemas import get_list_schema + +logger = logging.getLogger(__name__) + + +class CustomLlmProviderSchema(Schema): + id = fields.Integer(dump_only=True) + name = fields.String(required=True) + endpoint_url = fields.String(required=True) + request_template = fields.String(required=True) + response_path = fields.String(required=True) + headers = fields.String(allow_none=True) + models = fields.String(required=True) + system_instructions = fields.String(allow_none=True) + timeout = fields.Integer(allow_none=True, missing=30) + enabled = fields.Boolean(missing=True) + created_on = fields.DateTime(dump_only=True) + changed_on = fields.DateTime(dump_only=True) + + @post_load + def validate_json_fields(self, data, **kwargs): + """Validate JSON fields.""" + # Validate request_template + try: + json.loads(data["request_template"]) + except json.JSONDecodeError: + raise ValidationError("request_template must be valid JSON") + + # Validate headers if provided + if data.get("headers"): + try: + json.loads(data["headers"]) + except json.JSONDecodeError: + raise ValidationError("headers must be valid JSON") + + # Validate models + try: + models = json.loads(data["models"]) + if not isinstance(models, dict): + raise ValidationError("models must be a JSON object") + except json.JSONDecodeError: + raise ValidationError("models must be valid JSON") + + return data + + +class CustomLlmProviderRestApi(BaseSupersetModelRestApi): + datamodel = SQLAInterface(CustomLlmProvider) + resource_name = "custom_llm_provider" + allow_browser_login = True + + class_permission_name = "CustomLlmProvider" + method_permission_name = { + "get": "read", + "get_list": "read", + "post": "write", + "put": "write", + "delete": "write", + } + + add_columns = [ + "name", + "endpoint_url", + "request_template", + "response_path", + "headers", + "models", + "system_instructions", + "timeout", + "enabled", + ] + + edit_columns = add_columns + + list_columns = [ + "id", + "name", + "endpoint_url", + "enabled", + "created_on", + "changed_on", + ] + + show_columns = [ + "id", + "name", + "endpoint_url", + "request_template", + "response_path", + "headers", + "models", + "system_instructions", + "timeout", + "enabled", + "created_on", + "changed_on", + ] + + openapi_spec_tag = "Custom LLM Providers" + + add_model_schema = CustomLlmProviderSchema() + edit_model_schema = CustomLlmProviderSchema() + show_model_schema = CustomLlmProviderSchema() + + @expose("/test", methods=("POST",)) + @protect() + @safe + @statsd_metrics + def test_connection(self) -> Response: + """Test connection to a custom LLM provider.""" + try: + data = request.get_json() + + # Validate required fields + required_fields = ["endpoint_url", "request_template", "response_path"] + for field in required_fields: + if field not in data: + return self.response_400(f"Missing required field: {field}") + + # Validate JSON fields + try: + request_template = json.loads(data["request_template"]) + except json.JSONDecodeError: + return self.response_400("request_template must be valid JSON") + + headers = {"Content-Type": "application/json"} + if data.get("headers"): + try: + custom_headers = json.loads(data["headers"]) + headers.update(custom_headers) + except json.JSONDecodeError: + return self.response_400("headers must be valid JSON") + + # Create a simple test request + test_request = { + "model": "test", + "messages": [{"role": "user", "content": "SELECT 1"}] + } + + # Substitute template variables if needed + test_data = request_template.copy() + for key, value in test_data.items(): + if isinstance(value, str) and "{" in value: + test_data[key] = value.format( + model="test", + messages=test_request["messages"], + api_key="test" + ) + + import requests + timeout = data.get("timeout", 30) + + try: + response = requests.post( + data["endpoint_url"], + json=test_data, + headers=headers, + timeout=timeout + ) Review Comment: ## Full server-side request forgery The full URL of this request depends on a [user-provided value](1). [Show more details](https://github.com/apache/superset/security/code-scanning/2056) -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
