This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch pre-cost-estimation
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 992226527f78ab03fc1bfd5f774ad13e39026d23 Author: Beto Dealmeida <[email protected]> AuthorDate: Tue Jul 29 12:49:55 2025 -0400 feat: run pre-query cost estimation --- superset-frontend/src/SqlLab/actions/sqlLab.js | 43 ++++ .../CostWarningModal/CostWarningModal.test.tsx | 131 ++++++++++++ .../SqlLab/components/CostWarningModal/index.tsx | 166 +++++++++++++++ .../src/SqlLab/components/SqlEditor/index.tsx | 88 +++++++- .../src/SqlLab/reducers/getInitialState.ts | 1 + superset-frontend/src/SqlLab/reducers/sqlLab.js | 45 ++++ superset/commands/sql_lab/check_cost_threshold.py | 230 +++++++++++++++++++++ superset/config.py | 12 ++ superset/sqllab/api.py | 63 ++++++ 9 files changed, 778 insertions(+), 1 deletion(-) diff --git a/superset-frontend/src/SqlLab/actions/sqlLab.js b/superset-frontend/src/SqlLab/actions/sqlLab.js index 82d103ed74..602da34245 100644 --- a/superset-frontend/src/SqlLab/actions/sqlLab.js +++ b/superset-frontend/src/SqlLab/actions/sqlLab.js @@ -97,6 +97,10 @@ export const COST_ESTIMATE_STARTED = 'COST_ESTIMATE_STARTED'; export const COST_ESTIMATE_RETURNED = 'COST_ESTIMATE_RETURNED'; export const COST_ESTIMATE_FAILED = 'COST_ESTIMATE_FAILED'; +export const COST_THRESHOLD_CHECK_STARTED = 'COST_THRESHOLD_CHECK_STARTED'; +export const COST_THRESHOLD_CHECK_RETURNED = 'COST_THRESHOLD_CHECK_RETURNED'; +export const COST_THRESHOLD_CHECK_FAILED = 'COST_THRESHOLD_CHECK_FAILED'; + export const CREATE_DATASOURCE_STARTED = 'CREATE_DATASOURCE_STARTED'; export const CREATE_DATASOURCE_SUCCESS = 'CREATE_DATASOURCE_SUCCESS'; export const CREATE_DATASOURCE_FAILED = 'CREATE_DATASOURCE_FAILED'; @@ -233,6 +237,45 @@ export function estimateQueryCost(queryEditor) { }; } +export function checkCostThreshold(queryEditor) { + return (dispatch, getState) => { + const { dbId, catalog, schema, sql, selectedText, templateParams } = + getUpToDateQuery(getState(), queryEditor); + const requestSql = selectedText || sql; + const postPayload = { + database_id: dbId, + catalog, + 
schema, + sql: requestSql, + template_params: JSON.parse(templateParams || '{}'), + }; + return Promise.all([ + dispatch({ type: COST_THRESHOLD_CHECK_STARTED, query: queryEditor }), + SupersetClient.post({ + endpoint: '/api/v1/sqllab/check_cost_threshold/', + body: JSON.stringify(postPayload), + headers: { 'Content-Type': 'application/json' }, + }) + .then(({ json }) => + dispatch({ type: COST_THRESHOLD_CHECK_RETURNED, query: queryEditor, json }), + ) + .catch(response => + getClientErrorObject(response).then(error => { + const message = + error.error || + error.statusText || + t('Failed at checking cost threshold'); + return dispatch({ + type: COST_THRESHOLD_CHECK_FAILED, + query: queryEditor, + error: message, + }); + }), + ), + ]); + }; +} + export function clearInactiveQueries(interval) { return { type: CLEAR_INACTIVE_QUERIES, interval }; } diff --git a/superset-frontend/src/SqlLab/components/CostWarningModal/CostWarningModal.test.tsx b/superset-frontend/src/SqlLab/components/CostWarningModal/CostWarningModal.test.tsx new file mode 100644 index 0000000000..f8589646c5 --- /dev/null +++ b/superset-frontend/src/SqlLab/components/CostWarningModal/CostWarningModal.test.tsx @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { render, screen, fireEvent } from '@testing-library/react'; +import { ThemeProvider } from '@superset-ui/core'; +import { theme } from 'src/preamble'; +import CostWarningModal from './index'; + +const mockProps = { + visible: true, + onHide: jest.fn(), + onProceed: jest.fn(), + warningMessage: 'This query will scan 10 GB of data, which exceeds the threshold of 5 GB.', + thresholdInfo: { + bytes_threshold: 5 * 1024 ** 3, // 5 GB + estimated_bytes: 10 * 1024 ** 3, // 10 GB + }, +}; + +const renderWithTheme = (ui: React.ReactElement) => + render(<ThemeProvider theme={theme}>{ui}</ThemeProvider>); + +describe('CostWarningModal', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('renders with warning message', () => { + renderWithTheme(<CostWarningModal {...mockProps} />); + + expect(screen.getByText('Query Cost Warning')).toBeInTheDocument(); + expect(screen.getByText(mockProps.warningMessage)).toBeInTheDocument(); + }); + + it('shows threshold details when provided', () => { + renderWithTheme(<CostWarningModal {...mockProps} />); + + expect(screen.getByText('Threshold Details:')).toBeInTheDocument(); + expect(screen.getByText('Data to scan:')).toBeInTheDocument(); + expect(screen.getByText('10.0 GB')).toBeInTheDocument(); + expect(screen.getByText('5.0 GB')).toBeInTheDocument(); + }); + + it('disables proceed button until checkbox is checked', () => { + renderWithTheme(<CostWarningModal {...mockProps} />); + + const proceedButton = screen.getByText('Run Query Anyway'); + const checkbox = screen.getByRole('checkbox'); + + expect(proceedButton).toBeDisabled(); + + fireEvent.click(checkbox); + expect(proceedButton).not.toBeDisabled(); + }); + + it('calls onProceed when proceed button is clicked with checkbox checked', () => { + renderWithTheme(<CostWarningModal {...mockProps} />); + + const checkbox = screen.getByRole('checkbox'); + 
const proceedButton = screen.getByText('Run Query Anyway'); + + fireEvent.click(checkbox); + fireEvent.click(proceedButton); + + expect(mockProps.onProceed).toHaveBeenCalledTimes(1); + }); + + it('calls onHide when cancel button is clicked', () => { + renderWithTheme(<CostWarningModal {...mockProps} />); + + const cancelButton = screen.getByText('Cancel'); + fireEvent.click(cancelButton); + + expect(mockProps.onHide).toHaveBeenCalledTimes(1); + }); + + it('renders without threshold details when not provided', () => { + const propsWithoutThreshold = { + ...mockProps, + thresholdInfo: undefined, + }; + + renderWithTheme(<CostWarningModal {...propsWithoutThreshold} />); + + expect(screen.queryByText('Threshold Details:')).not.toBeInTheDocument(); + }); + + it('shows default message when warningMessage is null', () => { + const propsWithNoMessage = { + ...mockProps, + warningMessage: null, + }; + + renderWithTheme(<CostWarningModal {...propsWithNoMessage} />); + + expect(screen.getByText('This query may be expensive to run.')).toBeInTheDocument(); + }); + + it('handles cost threshold details', () => { + const propsWithCostThreshold = { + ...mockProps, + thresholdInfo: { + cost_threshold: 100, + estimated_cost: 250, + }, + }; + + renderWithTheme(<CostWarningModal {...propsWithCostThreshold} />); + + expect(screen.getByText('Estimated cost:')).toBeInTheDocument(); + expect(screen.getByText('250')).toBeInTheDocument(); + expect(screen.getByText('Cost threshold:')).toBeInTheDocument(); + expect(screen.getByText('100')).toBeInTheDocument(); + }); +}); diff --git a/superset-frontend/src/SqlLab/components/CostWarningModal/index.tsx b/superset-frontend/src/SqlLab/components/CostWarningModal/index.tsx new file mode 100644 index 0000000000..3b1f9876ba --- /dev/null +++ b/superset-frontend/src/SqlLab/components/CostWarningModal/index.tsx @@ -0,0 +1,166 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { useState } from 'react'; +import { styled, t } from '@superset-ui/core'; +import { Button, Modal, Checkbox } from '@superset-ui/core/components'; +import { ModalTitleWithIcon } from 'src/components/ModalTitleWithIcon'; + +const StyledModal = styled(Modal)` + .ant-modal-body { + padding: 24px; + } +`; + +const WarningContent = styled.div` + margin: 16px 0; + font-size: 14px; + line-height: 1.5; +`; + +const DetailsSection = styled.div` + margin: 16px 0; + padding: 12px; + background-color: ${({ theme }) => theme.colors.grayscale.light4}; + border-radius: 4px; + font-size: 12px; +`; + +const CheckboxWrapper = styled.div` + margin: 16px 0; +`; + +interface CostWarningModalProps { + visible: boolean; + onHide: () => void; + onProceed: () => void; + warningMessage: string | null; + thresholdInfo?: { + bytes_threshold?: number; + estimated_bytes?: number; + cost_threshold?: number; + estimated_cost?: number; + }; +} + +export default function CostWarningModal({ + visible, + onHide, + onProceed, + warningMessage, + thresholdInfo, +}: CostWarningModalProps) { + const [proceedAnyway, setProceedAnyway] = useState(false); + + const handleProceed = () => { + if (proceedAnyway) { + onProceed(); + } + }; + + const formatBytes = (bytes: number) => { + if (bytes < 1024) 
return `${bytes} B`; + if (bytes < 1024 ** 2) return `${(bytes / 1024).toFixed(1)} KB`; + if (bytes < 1024 ** 3) return `${(bytes / 1024 ** 2).toFixed(1)} MB`; + if (bytes < 1024 ** 4) return `${(bytes / 1024 ** 3).toFixed(1)} GB`; + if (bytes < 1024 ** 5) return `${(bytes / 1024 ** 4).toFixed(1)} TB`; + return `${(bytes / 1024 ** 5).toFixed(1)} PB`; + }; + + const renderThresholdDetails = () => { + if (!thresholdInfo) return null; + + const details = []; + + if (thresholdInfo.bytes_threshold && thresholdInfo.estimated_bytes) { + details.push( + <div key="bytes"> + <strong>{t('Data to scan:')}</strong> {formatBytes(thresholdInfo.estimated_bytes)} + <br /> + <strong>{t('Threshold:')}</strong> {formatBytes(thresholdInfo.bytes_threshold)} + </div> + ); + } + + if (thresholdInfo.cost_threshold && thresholdInfo.estimated_cost) { + details.push( + <div key="cost"> + <strong>{t('Estimated cost:')}</strong> {thresholdInfo.estimated_cost} + <br /> + <strong>{t('Cost threshold:')}</strong> {thresholdInfo.cost_threshold} + </div> + ); + } + + return details.length > 0 ? ( + <DetailsSection> + <div style={{ marginBottom: '8px' }}> + <strong>{t('Threshold Details:')}</strong> + </div> + {details.map((detail, index) => ( + <div key={index} style={{ marginBottom: index < details.length - 1 ? 
'8px' : '0' }}> + {detail} + </div> + ))} + </DetailsSection> + ) : null; + }; + + return ( + <StyledModal + show={visible} + onHide={onHide} + title={ + <ModalTitleWithIcon + icon="exclamation-triangle" + title={t('Query Cost Warning')} + /> + } + footer={ + <> + <Button onClick={onHide}> + {t('Cancel')} + </Button> + <Button + buttonStyle="primary" + onClick={handleProceed} + disabled={!proceedAnyway} + > + {t('Run Query Anyway')} + </Button> + </> + } + > + <WarningContent> + {warningMessage || t('This query may be expensive to run.')} + </WarningContent> + + {renderThresholdDetails()} + + <CheckboxWrapper> + <Checkbox + checked={proceedAnyway} + onChange={(e) => setProceedAnyway(e.target.checked)} + > + {t('I understand the cost implications and want to proceed anyway')} + </Checkbox> + </CheckboxWrapper> + </StyledModal> + ); +} diff --git a/superset-frontend/src/SqlLab/components/SqlEditor/index.tsx b/superset-frontend/src/SqlLab/components/SqlEditor/index.tsx index b2c77346f9..c49fa0ff04 100644 --- a/superset-frontend/src/SqlLab/components/SqlEditor/index.tsx +++ b/superset-frontend/src/SqlLab/components/SqlEditor/index.tsx @@ -71,6 +71,7 @@ import { addNewQueryEditor, CtasEnum, estimateQueryCost, + checkCostThreshold, persistEditorHeight, postStopQuery, queryEditorSetAutorun, @@ -123,6 +124,7 @@ import SouthPane from '../SouthPane'; import SaveQuery, { QueryPayload } from '../SaveQuery'; import ScheduleQueryButton from '../ScheduleQueryButton'; import EstimateQueryCostButton from '../EstimateQueryCostButton'; +import CostWarningModal from '../CostWarningModal'; import ShareSqlLabQuery from '../ShareSqlLabQuery'; import SqlEditorLeftBar from '../SqlEditorLeftBar'; import AceEditorWrapper from '../AceEditorWrapper'; @@ -270,6 +272,7 @@ const SqlEditor: FC<Props> = ({ hideLeftBar, currentQueryEditorId, hasSqlStatement, + costThresholdData, } = useSelector< SqlLabRootState, { @@ -278,8 +281,9 @@ const SqlEditor: FC<Props> = ({ hideLeftBar?: boolean; 
currentQueryEditorId: QueryEditor['id']; hasSqlStatement: boolean; + costThresholdData?: any; } - >(({ sqlLab: { unsavedQueryEditor, databases, queries, tabHistory } }) => { + >(({ sqlLab: { unsavedQueryEditor, databases, queries, tabHistory, queryCostThresholds } }) => { let { dbId, latestQueryId, hideLeftBar } = queryEditor; if (unsavedQueryEditor?.id === queryEditor.id) { dbId = unsavedQueryEditor.dbId || dbId; @@ -295,6 +299,7 @@ const SqlEditor: FC<Props> = ({ latestQuery: queries[latestQueryId || ''], hideLeftBar, currentQueryEditorId: tabHistory.slice(-1)[0], + costThresholdData: queryCostThresholds[queryEditor.id], }; }, shallowEqual); @@ -317,6 +322,11 @@ const SqlEditor: FC<Props> = ({ ); const [showCreateAsModal, setShowCreateAsModal] = useState(false); const [createAs, setCreateAs] = useState(''); + const [showCostWarningModal, setShowCostWarningModal] = useState(false); + const [costWarningData, setCostWarningData] = useState<{ + warningMessage: string | null; + thresholdInfo?: any; + } | null>(null); const currentSQL = useRef<string>(queryEditor.sql); const showEmptyState = useMemo( () => !database || isEmpty(database), @@ -330,7 +340,69 @@ const SqlEditor: FC<Props> = ({ const isTempId = (value: unknown): boolean => Number.isNaN(Number(value)); + const checkCostThresholdAndRun = useCallback( + (ctasArg = false, ctas_method = CtasEnum.Table) => { + if (!database) { + return; + } + + // Check if cost threshold checking is enabled via feature flag or configuration + // For now, we'll implement the logic directly + dispatch(checkCostThreshold(queryEditor)).then(([_, response]) => { + if (response && response.json) { + const { exceeds_threshold, formatted_warning, threshold_info } = response.json; + + if (exceeds_threshold && formatted_warning) { + // Show warning modal + setCostWarningData({ + warningMessage: formatted_warning, + thresholdInfo: threshold_info, + }); + setShowCostWarningModal(true); + return; + } + } + + // If no threshold exceeded or 
checking failed, proceed with query + dispatch( + runQueryFromSqlEditor( + database, + queryEditor, + defaultQueryLimit, + ctasArg ? ctas : '', + ctasArg, + ctas_method, + ), + ); + dispatch(setActiveSouthPaneTab('Results')); + }).catch(() => { + // If cost checking fails, proceed with query anyway + dispatch( + runQueryFromSqlEditor( + database, + queryEditor, + defaultQueryLimit, + ctasArg ? ctas : '', + ctasArg, + ctas_method, + ), + ); + dispatch(setActiveSouthPaneTab('Results')); + }); + }, + [ctas, database, defaultQueryLimit, dispatch, queryEditor], + ); + const startQuery = useCallback( + (ctasArg = false, ctas_method = CtasEnum.Table) => { + // Use cost threshold checking for regular queries + checkCostThresholdAndRun(ctasArg, ctas_method); + }, + [checkCostThresholdAndRun], + ); + + // Direct query execution without cost checking (for modal "proceed anyway") + const executeQueryDirectly = useCallback( (ctasArg = false, ctas_method = CtasEnum.Table) => { if (!database) { return; @@ -1121,6 +1193,20 @@ const SqlEditor: FC<Props> = ({ <span>{t('Name')}</span> <Input placeholder={createModalPlaceHolder} onChange={ctasChanged} /> </Modal> + <CostWarningModal + visible={showCostWarningModal} + onHide={() => { + setShowCostWarningModal(false); + setCostWarningData(null); + }} + onProceed={() => { + setShowCostWarningModal(false); + setCostWarningData(null); + executeQueryDirectly(); + }} + warningMessage={costWarningData?.warningMessage || null} + thresholdInfo={costWarningData?.thresholdInfo} + /> </StyledSqlEditor> ); }; diff --git a/superset-frontend/src/SqlLab/reducers/getInitialState.ts b/superset-frontend/src/SqlLab/reducers/getInitialState.ts index 361cc5621d..60e43361fa 100644 --- a/superset-frontend/src/SqlLab/reducers/getInitialState.ts +++ b/superset-frontend/src/SqlLab/reducers/getInitialState.ts @@ -264,6 +264,7 @@ export default function getInitialState({ queriesLastUpdate: Date.now(), editorTabLastUpdatedAt, queryCostEstimates: {}, + 
queryCostThresholds: {}, unsavedQueryEditor, lastUpdatedActiveTab, destroyedQueryEditors, diff --git a/superset-frontend/src/SqlLab/reducers/sqlLab.js b/superset-frontend/src/SqlLab/reducers/sqlLab.js index c77f708653..45d827a204 100644 --- a/superset-frontend/src/SqlLab/reducers/sqlLab.js +++ b/superset-frontend/src/SqlLab/reducers/sqlLab.js @@ -315,6 +315,51 @@ export default function sqlLabReducer(state = {}, action) { }, }; }, + [actions.COST_THRESHOLD_CHECK_STARTED]() { + return { + ...state, + queryCostThresholds: { + ...state.queryCostThresholds, + [action.query.id]: { + completed: false, + exceedsThreshold: false, + thresholdInfo: null, + formattedWarning: null, + error: null, + }, + }, + }; + }, + [actions.COST_THRESHOLD_CHECK_RETURNED]() { + return { + ...state, + queryCostThresholds: { + ...state.queryCostThresholds, + [action.query.id]: { + completed: true, + exceedsThreshold: action.json.exceeds_threshold, + thresholdInfo: action.json.threshold_info, + formattedWarning: action.json.formatted_warning, + error: null, + }, + }, + }; + }, + [actions.COST_THRESHOLD_CHECK_FAILED]() { + return { + ...state, + queryCostThresholds: { + ...state.queryCostThresholds, + [action.query.id]: { + completed: false, + exceedsThreshold: false, + thresholdInfo: null, + formattedWarning: null, + error: action.error, + }, + }, + }; + }, [actions.START_QUERY]() { let newState = { ...state }; if (action.query.sqlEditorId) { diff --git a/superset/commands/sql_lab/check_cost_threshold.py b/superset/commands/sql_lab/check_cost_threshold.py new file mode 100644 index 0000000000..b5071964ac --- /dev/null +++ b/superset/commands/sql_lab/check_cost_threshold.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from typing import Any, TypedDict + +from superset import app +from superset.commands.base import BaseCommand +from superset.commands.sql_lab.estimate import QueryEstimationCommand, EstimateQueryCostType + +config = app.config +logger = logging.getLogger(__name__) + + +class CostThresholdResult(TypedDict): + exceeds_threshold: bool + estimated_cost: list[dict[str, Any]] + threshold_info: dict[str, Any] + formatted_warning: str | None + + +class QueryCostThresholdCheckCommand(BaseCommand): + """ + Command to check if a query's estimated cost exceeds configured thresholds. + """ + + _estimation_command: QueryEstimationCommand + + def __init__(self, estimation_params: EstimateQueryCostType) -> None: + self._estimation_command = QueryEstimationCommand(estimation_params) + + def validate(self) -> None: + # Use the estimation command's validation + self._estimation_command.validate() + + def run(self) -> CostThresholdResult: + """ + Check if query cost exceeds thresholds. + + Returns a result indicating whether the query exceeds cost thresholds + and provides information for user warnings. 
+ """ + self.validate() + + # Check if cost checking is enabled + if not config.get("SQLLAB_QUERY_COST_CHECKING_ENABLED", False): + return self._create_empty_result() + + estimated_cost = self._get_estimated_cost() + if not estimated_cost: + return self._create_empty_result() + + thresholds = self._get_engine_thresholds() + if not thresholds: + return CostThresholdResult( + exceeds_threshold=False, + estimated_cost=estimated_cost, + threshold_info={}, + formatted_warning=None, + ) + + return self._check_thresholds(estimated_cost, thresholds) + + def _create_empty_result(self) -> CostThresholdResult: + """Create an empty result when cost checking is disabled or fails.""" + return CostThresholdResult( + exceeds_threshold=False, + estimated_cost=[], + threshold_info={}, + formatted_warning=None, + ) + + def _get_estimated_cost(self) -> list[dict[str, Any]] | None: + """Get cost estimation, returning None if it fails.""" + try: + return self._estimation_command.run() + except Exception as ex: + logger.warning("Cost estimation failed: %s", str(ex)) + return None + + def _get_engine_thresholds(self) -> dict[str, Any]: + """Get thresholds for the current database engine.""" + database = self._estimation_command._database + engine_name = database.db_engine_spec.engine_name + if engine_name is None: + return {} + + engine_name = engine_name.lower() + return config.get("SQLLAB_QUERY_COST_THRESHOLDS", {}).get(engine_name, {}) + + def _check_thresholds( + self, estimated_cost: list[dict[str, Any]], thresholds: dict[str, Any] + ) -> CostThresholdResult: + """Check if estimated cost exceeds configured thresholds.""" + exceeds_threshold = False + warning_messages = [] + threshold_info = {} + + for cost_item in estimated_cost: + if self._check_bytes_threshold(cost_item, thresholds, threshold_info, warning_messages): + exceeds_threshold = True + if self._check_cost_threshold(cost_item, thresholds, threshold_info, warning_messages): + exceeds_threshold = True + + formatted_warning = 
None + if warning_messages: + formatted_warning = ( + " ".join(warning_messages) + " Are you sure you want to continue?" + ) + + return CostThresholdResult( + exceeds_threshold=exceeds_threshold, + estimated_cost=estimated_cost, + threshold_info=threshold_info, + formatted_warning=formatted_warning, + ) + + def _check_bytes_threshold( + self, + cost_item: dict[str, Any], + thresholds: dict[str, Any], + threshold_info: dict[str, Any], + warning_messages: list[str] + ) -> bool: + """Check bytes scanned threshold. Returns True if threshold exceeded.""" + if "bytes_scanned" not in thresholds or "Bytes Scanned" not in cost_item: + return False + + try: + bytes_scanned = self._parse_bytes_from_cost_item(cost_item["Bytes Scanned"]) + threshold_bytes = thresholds["bytes_scanned"] + threshold_info["bytes_threshold"] = threshold_bytes + threshold_info["estimated_bytes"] = bytes_scanned + + if bytes_scanned > threshold_bytes: + warning_messages.append( + f"This query will scan approximately {self._format_bytes(bytes_scanned)} " + f"of data, which exceeds the threshold of {self._format_bytes(threshold_bytes)}." + ) + return True + except (ValueError, KeyError) as ex: + logger.warning("Failed to parse bytes from cost estimation: %s", str(ex)) + + return False + + def _check_cost_threshold( + self, + cost_item: dict[str, Any], + thresholds: dict[str, Any], + threshold_info: dict[str, Any], + warning_messages: list[str] + ) -> bool: + """Check cost threshold. Returns True if threshold exceeded.""" + if "cost_threshold" not in thresholds or "Cost" not in cost_item: + return False + + try: + cost_value = float(cost_item["Cost"]) + threshold_cost = thresholds["cost_threshold"] + threshold_info["cost_threshold"] = threshold_cost + threshold_info["estimated_cost"] = cost_value + + if cost_value > threshold_cost: + warning_messages.append( + f"This query has an estimated cost of {cost_value}, " + f"which exceeds the threshold of {threshold_cost}." 
+ ) + return True + except (ValueError, KeyError) as ex: + logger.warning("Failed to parse cost from cost estimation: %s", str(ex)) + + return False + + def _parse_bytes_from_cost_item(self, bytes_str: str) -> int: + """Parse bytes from formatted string like '5.2 GB' or '1024 MB'.""" + if not isinstance(bytes_str, str): + return int(bytes_str) + + # Remove commas and split + parts = bytes_str.replace(",", "").strip().split() + if len(parts) != 2: + raise ValueError(f"Cannot parse bytes from: {bytes_str}") + + value_str, unit = parts + value = float(value_str) + unit = unit.upper() + + multipliers = { + "B": 1, + "KB": 1024, + "MB": 1024**2, + "GB": 1024**3, + "TB": 1024**4, + "PB": 1024**5, + } + + if unit not in multipliers: + raise ValueError(f"Unknown unit: {unit}") + + return int(value * multipliers[unit]) + + def _format_bytes(self, bytes_count: int) -> str: + """Format bytes into human-readable string.""" + if bytes_count < 1024: + return f"{bytes_count} B" + elif bytes_count < 1024**2: + return f"{bytes_count / 1024:.1f} KB" + elif bytes_count < 1024**3: + return f"{bytes_count / (1024**2):.1f} MB" + elif bytes_count < 1024**4: + return f"{bytes_count / (1024**3):.1f} GB" + elif bytes_count < 1024**5: + return f"{bytes_count / (1024**4):.1f} TB" + else: + return f"{bytes_count / (1024**5):.1f} PB" diff --git a/superset/config.py b/superset/config.py index 273f2232da..73cb9e689e 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1191,6 +1191,18 @@ SQLLAB_ASYNC_TIME_LIMIT_SEC = int(timedelta(hours=6).total_seconds()) # timeout. 
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds()) +# Query cost governance configuration +# Enable automatic cost checking before query execution +SQLLAB_QUERY_COST_CHECKING_ENABLED = False + +# Cost thresholds that trigger warnings before query execution +# This is a dictionary where keys are database engine names and values are threshold configs +# Each threshold config can contain: +# - 'bytes_scanned': maximum bytes that can be scanned without warning +# - 'cost_threshold': monetary cost threshold (engine-specific units) +# Example: {'bigquery': {'bytes_scanned': 5 * 1024**4}, 'presto': {'cost_threshold': 1000}} +SQLLAB_QUERY_COST_THRESHOLDS = {} + # Timeout duration for SQL Lab fetching query results by the resultsKey. # 0 means no timeout. SQLLAB_QUERY_RESULT_TIMEOUT = 0 diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index cdb2436e58..472c4ae0c5 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -25,6 +25,9 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface from marshmallow import ValidationError from superset import app, is_feature_enabled +from superset.commands.sql_lab.check_cost_threshold import ( + QueryCostThresholdCheckCommand, +) from superset.commands.sql_lab.estimate import QueryEstimationCommand from superset.commands.sql_lab.execute import CommandResult, ExecuteSqlCommand from superset.commands.sql_lab.export import SqlResultExportCommand @@ -188,6 +191,66 @@ class SqlLabRestApi(BaseSupersetApi): result = command.run() return self.response(200, result=result) + @expose("/check_cost_threshold/", methods=("POST",)) + @protect() + @statsd_metrics + @requires_json + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".check_cost_threshold", + log_to_statsd=False, + ) + def check_cost_threshold(self) -> Response: + """Check if query cost exceeds configured thresholds. 
+ --- + post: + summary: Check if query cost exceeds thresholds + requestBody: + description: SQL query and params + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/EstimateQueryCostSchema' + responses: + 200: + description: Cost threshold check result + content: + application/json: + schema: + type: object + properties: + exceeds_threshold: + type: boolean + description: Whether query exceeds cost thresholds + estimated_cost: + type: array + description: Detailed cost estimation + threshold_info: + type: object + description: Information about thresholds and estimates + formatted_warning: + type: string + nullable: true + description: Human-readable warning message + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 500: + $ref: '#/components/responses/500' + """ + try: + model = self.estimate_model_schema.load(request.json) + except ValidationError as error: + return self.response_400(message=error.messages) + + command = QueryCostThresholdCheckCommand(model) + result = command.run() + return self.response(200, **result) + @expose("/format_sql/", methods=("POST",)) @statsd_metrics @protect()
