This is an automated email from the ASF dual-hosted git repository. kgabryje pushed a commit to branch what-if in repository https://gitbox.apache.org/repos/asf/superset.git
commit fce4fc039faee41282c8bfa93fcf07d2a4f6c51d Author: Kamil Gabryjelski <[email protected]> AuthorDate: Thu Dec 18 11:35:15 2025 +0100 AI column modifications --- .../dashboard/components/WhatIfDrawer/index.tsx | 257 ++++++++++++++++++--- .../src/dashboard/components/WhatIfDrawer/types.ts | 38 +++ .../dashboard/components/WhatIfDrawer/whatIfApi.ts | 49 ++++ superset-frontend/src/dashboard/types.ts | 2 + superset-frontend/src/dashboard/util/whatIf.ts | 2 + superset/what_if/api.py | 73 ++++++ superset/what_if/commands/suggest_related.py | 220 ++++++++++++++++++ superset/what_if/schemas.py | 87 +++++++ 8 files changed, 701 insertions(+), 27 deletions(-) diff --git a/superset-frontend/src/dashboard/components/WhatIfDrawer/index.tsx b/superset-frontend/src/dashboard/components/WhatIfDrawer/index.tsx index f25ae093c2..34d88300eb 100644 --- a/superset-frontend/src/dashboard/components/WhatIfDrawer/index.tsx +++ b/superset-frontend/src/dashboard/components/WhatIfDrawer/index.tsx @@ -18,9 +18,14 @@ */ import { useCallback, useMemo, useState } from 'react'; import { useDispatch, useSelector } from 'react-redux'; -import { t } from '@superset-ui/core'; +import { t, logging } from '@superset-ui/core'; import { css, styled, Alert, useTheme } from '@apache-superset/core/ui'; -import { Button, Select } from '@superset-ui/core/components'; +import { + Button, + Select, + Checkbox, + Tooltip, +} from '@superset-ui/core/components'; import Slider from '@superset-ui/core/components/Slider'; import { Icons } from '@superset-ui/core/components/Icons'; import { setWhatIfModifications } from 'src/dashboard/actions/dashboardState'; @@ -31,6 +36,8 @@ import { import { getNumericColumnsForDashboard } from 'src/dashboard/util/whatIf'; import { RootState, Slice, WhatIfColumn } from 'src/dashboard/types'; import WhatIfAIInsights from './WhatIfAIInsights'; +import { fetchRelatedColumnSuggestions } from './whatIfApi'; +import { ExtendedWhatIfModification } from './types'; export const WHAT_IF_PANEL_WIDTH = 300; @@ -115,6 +122,69 @@ const SliderContainer = styled.div` const ApplyButton = styled(Button)` width: 100%; + min-height: 32px; +`; + +const CheckboxContainer = styled.div` + display: flex; + align-items: center; + gap: ${({ theme }) => theme.sizeUnit}px; +`; + +const ModificationsSection = styled.div` + display: flex; + flex-direction: column; + gap: ${({ theme }) => theme.sizeUnit * 2}px; +`; + +const ModificationsSectionTitle = styled.div` + font-weight: ${({ theme }) => theme.fontWeightStrong}; + color: ${({ theme }) => theme.colorText}; + font-size: ${({ theme }) => theme.fontSizeSM}px; +`; + +const ModificationCard = styled.div<{ isAISuggested?: boolean }>` + padding: ${({ theme }) => theme.sizeUnit * 2}px; + background-color: ${({ theme, isAISuggested }) => + isAISuggested ? theme.colorInfoBg : theme.colorBgLayout}; + border: 1px solid + ${({ theme, isAISuggested }) => + isAISuggested ? theme.colorInfoBorder : theme.colorBorderSecondary}; + border-radius: ${({ theme }) => theme.borderRadius}px; +`; + +const ModificationHeader = styled.div` + display: flex; + align-items: center; + justify-content: space-between; + gap: ${({ theme }) => theme.sizeUnit}px; +`; + +const ModificationColumn = styled.span` + font-weight: ${({ theme }) => theme.fontWeightStrong}; + color: ${({ theme }) => theme.colorText}; +`; + +const ModificationValue = styled.span<{ isPositive: boolean }>` + font-weight: ${({ theme }) => theme.fontWeightStrong}; + color: ${({ theme, isPositive }) => + isPositive ? theme.colorSuccess : theme.colorError}; +`; + +const AIBadge = styled.span` + font-size: ${({ theme }) => theme.fontSizeXS}px; + padding: 2px 6px; + background-color: ${({ theme }) => theme.colorInfo}; + color: ${({ theme }) => theme.colorWhite}; + border-radius: ${({ theme }) => theme.borderRadius}px; + font-weight: ${({ theme }) => theme.fontWeightStrong}; +`; + +const ModificationReasoning = styled.div` + font-size: ${({ theme }) => theme.fontSizeSM}px; + color: ${({ theme }) => theme.colorTextSecondary}; + margin-top: ${({ theme }) => theme.sizeUnit}px; + font-style: italic; `; interface WhatIfPanelProps { @@ -129,6 +199,11 @@ const WhatIfPanel = ({ onClose, topOffset }: WhatIfPanelProps) => { const [selectedColumn, setSelectedColumn] = useState<string | null>(null); const [sliderValue, setSliderValue] = useState<number>(SLIDER_DEFAULT); const [affectedChartIds, setAffectedChartIds] = useState<number[]>([]); + const [enableCascadingEffects, setEnableCascadingEffects] = useState(false); + const [isLoadingSuggestions, setIsLoadingSuggestions] = useState(false); + const [appliedModifications, setAppliedModifications] = useState< + ExtendedWhatIfModification[] + >([]); const slices = useSelector( (state: RootState) => state.sliceEntities.slices as { [id: number]: Slice }, @@ -166,41 +241,112 @@ const WhatIfPanel = ({ onClose, topOffset }: WhatIfPanelProps) => { setSliderValue(value); }, []); - const handleApply = useCallback(() => { + const dashboardInfo = useSelector((state: RootState) => state.dashboardInfo); + + const handleApply = useCallback(async () => { if (!selectedColumn) return; const multiplier = 1 + sliderValue / 100; - // Get affected chart IDs - const chartIds = columnToChartIds.get(selectedColumn) || []; + // Base user modification + const userModification: ExtendedWhatIfModification = { + column: selectedColumn, + multiplier, + isAISuggested: false, + }; + + let allModifications: ExtendedWhatIfModification[] = [userModification]; + + // If cascading effects enabled, fetch AI suggestions + if (enableCascadingEffects) { + setIsLoadingSuggestions(true); + try { + const suggestions = await fetchRelatedColumnSuggestions({ + selectedColumn, + userMultiplier: multiplier, + availableColumns: numericColumns.map(col => ({ + columnName: col.columnName, + description: col.description, + verboseName: col.verboseName, + datasourceId: col.datasourceId, + })), + dashboardName: dashboardInfo?.dash_edit_perm + ? dashboardInfo?.dashboard_title + : undefined, + }); + + // Add AI suggestions to modifications + const aiModifications: ExtendedWhatIfModification[] = + suggestions.suggestedModifications.map(mod => ({ + column: mod.column, + multiplier: mod.multiplier, + isAISuggested: true, + reasoning: mod.reasoning, + confidence: mod.confidence, + })); + + allModifications = [...allModifications, ...aiModifications]; + } catch (error) { + logging.error('Failed to get AI suggestions:', error); + // Continue with just user modification + } + setIsLoadingSuggestions(false); + } + + setAppliedModifications(allModifications); + + // Collect all affected chart IDs from all modifications + const allAffectedChartIds = new Set<number>(); + allModifications.forEach(mod => { + const chartIds = columnToChartIds.get(mod.column) || []; + chartIds.forEach(id => allAffectedChartIds.add(id)); + }); + const chartIdsArray = Array.from(allAffectedChartIds); // Save affected chart IDs for AI insights - setAffectedChartIds(chartIds); + setAffectedChartIds(chartIdsArray); // Save original chart data before applying what-if modifications - chartIds.forEach(chartId => { + chartIdsArray.forEach(chartId => { dispatch(saveOriginalChartData(chartId)); }); - // Set the what-if modifications in Redux state + // Set the what-if modifications in Redux state (all modifications) dispatch( - setWhatIfModifications([ - { - column: selectedColumn, - multiplier, - }, - ]), + setWhatIfModifications( + allModifications.map(mod => ({ + column: mod.column, + multiplier: mod.multiplier, + filters: mod.filters, + })), + ), ); - // Trigger queries for all charts that use the selected column - chartIds.forEach(chartId => { + // Trigger queries for all affected charts + chartIdsArray.forEach(chartId => { dispatch(triggerQuery(true, chartId)); }); - }, [dispatch, selectedColumn, sliderValue, columnToChartIds]); - - const isApplyDisabled = !selectedColumn || sliderValue === SLIDER_DEFAULT; + }, [ + dispatch, + selectedColumn, + sliderValue, + columnToChartIds, + enableCascadingEffects, + numericColumns, + dashboardInfo, + ]); + + const isApplyDisabled = + !selectedColumn || sliderValue === SLIDER_DEFAULT || isLoadingSuggestions; const isSliderDisabled = !selectedColumn; + // Helper to format percentage change + const formatPercentage = (multiplier: number): string => { + const pct = (multiplier - 1) * 100; + const sign = pct >= 0 ? '+' : ''; + return `${sign}${pct.toFixed(1)}%`; + }; + const sliderMarks = { [SLIDER_MIN]: `${SLIDER_MIN}%`, 0: '0%', @@ -255,22 +401,79 @@ const WhatIfPanel = ({ onClose, topOffset }: WhatIfPanelProps) => { </SliderContainer> </FormSection> + <CheckboxContainer> + <Checkbox + checked={enableCascadingEffects} + onChange={e => setEnableCascadingEffects(e.target.checked)} + > + {t('AI-powered cascading effects')} + </Checkbox> + <Tooltip + title={t( + 'When enabled, AI will analyze column relationships and automatically suggest related columns that should also be modified.', + )} + > + <Icons.InfoCircleOutlined + iconSize="s" + css={css` + color: ${theme.colorTextSecondary}; + cursor: help; + `} + /> + </Tooltip> + </CheckboxContainer> + <ApplyButton buttonStyle="primary" onClick={handleApply} disabled={isApplyDisabled} + loading={isLoadingSuggestions} > <Icons.StarFilled iconSize="s" /> - {t('See what if')} + {isLoadingSuggestions + ? t('Analyzing relationships...') + : t('See what if')} </ApplyButton> - <Alert - type="info" - message={t( - 'Select a column above to simulate changes and preview how it would impact your dashboard in real-time.', - )} - showIcon - /> + {appliedModifications.length === 0 && ( + <Alert + type="info" + message={t( + 'Select a column above to simulate changes and preview how it would impact your dashboard in real-time.', + )} + showIcon + /> + )} + + {appliedModifications.length > 0 && ( + <ModificationsSection> + <ModificationsSectionTitle> + {t('Applied Modifications')} + </ModificationsSectionTitle> + {appliedModifications.map((mod, idx) => ( + <ModificationCard key={idx} isAISuggested={mod.isAISuggested}> + <ModificationHeader> + <ModificationColumn>{mod.column}</ModificationColumn> + <div + css={css` + display: flex; + align-items: center; + gap: ${theme.sizeUnit}px; + `} + > + <ModificationValue isPositive={mod.multiplier >= 1}> + {formatPercentage(mod.multiplier)} + </ModificationValue> + {mod.isAISuggested && <AIBadge>{t('AI')}</AIBadge>} + </div> + </ModificationHeader> + {mod.reasoning && ( + <ModificationReasoning>{mod.reasoning}</ModificationReasoning> + )} + </ModificationCard> + ))} + </ModificationsSection> + )} {affectedChartIds.length > 0 && ( <WhatIfAIInsights affectedChartIds={affectedChartIds} /> diff --git a/superset-frontend/src/dashboard/components/WhatIfDrawer/types.ts b/superset-frontend/src/dashboard/components/WhatIfDrawer/types.ts index 029f336fad..2c1ea2a02d 100644 --- a/superset-frontend/src/dashboard/components/WhatIfDrawer/types.ts +++ b/superset-frontend/src/dashboard/components/WhatIfDrawer/types.ts @@ -71,3 +71,41 @@ export interface WhatIfInterpretResponse { } export type WhatIfAIStatus = 'idle' | 'loading' | 'success' | 'error'; + +// Types for suggest_related endpoint + +export interface AvailableColumn { + columnName: string; + description?: string | null; + verboseName?: string | null; + datasourceId: number; +} + +export interface SuggestedModification { + column: string; + multiplier: number; + reasoning: string; + confidence: 'high' | 'medium' | 'low'; +} + +export interface WhatIfSuggestRelatedRequest { + selectedColumn: string; + userMultiplier: number; + availableColumns: AvailableColumn[]; + dashboardName?: string; +} + +export interface WhatIfSuggestRelatedResponse { + suggestedModifications: SuggestedModification[]; + explanation?: string; +} + +// Extended modification type that tracks whether it came from AI +export interface ExtendedWhatIfModification { + column: string; + multiplier: number; + filters?: WhatIfFilter[]; + isAISuggested?: boolean; + reasoning?: string; + confidence?: 'high' | 'medium' | 'low'; +} diff --git a/superset-frontend/src/dashboard/components/WhatIfDrawer/whatIfApi.ts b/superset-frontend/src/dashboard/components/WhatIfDrawer/whatIfApi.ts index 5b03e55f97..64f51601cb 100644 --- a/superset-frontend/src/dashboard/components/WhatIfDrawer/whatIfApi.ts +++ b/superset-frontend/src/dashboard/components/WhatIfDrawer/whatIfApi.ts @@ -23,6 +23,9 @@ import { WhatIfInterpretResponse, ChartComparison, WhatIfFilter, + WhatIfSuggestRelatedRequest, + WhatIfSuggestRelatedResponse, + SuggestedModification, } from './types'; interface ApiResponse { @@ -84,3 +87,49 @@ export async function fetchWhatIfInterpretation( rawResponse: result.raw_response, }; } + +interface ApiSuggestRelatedResponse { + result: { + suggested_modifications: Array<{ + column: string; + multiplier: number; + reasoning: string; + confidence: string; + }>; + explanation?: string; + }; +} + +export async function fetchRelatedColumnSuggestions( + request: WhatIfSuggestRelatedRequest, +): Promise<WhatIfSuggestRelatedResponse> { + const response = await SupersetClient.post({ + endpoint: '/api/v1/what_if/suggest_related', + jsonPayload: { + selected_column: request.selectedColumn, + user_multiplier: request.userMultiplier, + available_columns: request.availableColumns.map(col => ({ + column_name: col.columnName, + description: col.description, + verbose_name: col.verboseName, + datasource_id: col.datasourceId, + })), + dashboard_name: request.dashboardName, + }, + }); + + const data = response.json as ApiSuggestRelatedResponse; + const { result } = data; + + return { + suggestedModifications: result.suggested_modifications.map( + (mod): SuggestedModification => ({ + column: mod.column, + multiplier: mod.multiplier, + reasoning: mod.reasoning, + confidence: mod.confidence as 'high' | 'medium' | 'low', + }), + ), + explanation: result.explanation, + }; +} diff --git a/superset-frontend/src/dashboard/types.ts b/superset-frontend/src/dashboard/types.ts index 553b26f1a6..48cce50f45 100644 --- a/superset-frontend/src/dashboard/types.ts +++ b/superset-frontend/src/dashboard/types.ts @@ -312,6 +312,8 @@ export interface WhatIfColumn { columnName: string; datasourceId: number; usedByChartIds: number[]; + description?: string | null; + verboseName?: string | null; } export enum MenuKeys { diff --git a/superset-frontend/src/dashboard/util/whatIf.ts b/superset-frontend/src/dashboard/util/whatIf.ts index da3742bc66..cfe590899f 100644 --- a/superset-frontend/src/dashboard/util/whatIf.ts +++ b/superset-frontend/src/dashboard/util/whatIf.ts @@ -154,6 +154,8 @@ export function getNumericColumnsForDashboard( columnName: colName, datasourceId: datasource.id, usedByChartIds: [chartId], + description: colMetadata.description, + verboseName: colMetadata.verbose_name, }); } else { const existing = columnMap.get(key)!; diff --git a/superset/what_if/api.py b/superset/what_if/api.py index 62173022ea..5ed560b0cf 100644 --- a/superset/what_if/api.py +++ b/superset/what_if/api.py @@ -28,10 +28,13 @@ from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP from superset.extensions import event_logger from superset.views.base_api import BaseSupersetApi, statsd_metrics from superset.what_if.commands.interpret import WhatIfInterpretCommand +from superset.what_if.commands.suggest_related import WhatIfSuggestRelatedCommand from superset.what_if.exceptions import OpenRouterAPIError, OpenRouterConfigError from superset.what_if.schemas import ( WhatIfInterpretRequestSchema, WhatIfInterpretResponseSchema, + WhatIfSuggestRelatedRequestSchema, + WhatIfSuggestRelatedResponseSchema, ) logger = logging.getLogger(__name__) @@ -116,3 +119,73 @@ class WhatIfRestApi(BaseSupersetApi): except ValueError as ex: logger.warning("Invalid request: %s", ex) return self.response_400(message=str(ex)) + + @expose("/suggest_related", methods=("POST",)) + @event_logger.log_this + @protect() + @safe + @statsd_metrics + def suggest_related(self) -> Response: + """Get AI suggestions for related column modifications. + --- + post: + summary: Get AI-suggested cascading column modifications + description: >- + Analyzes column relationships and suggests related columns + that should be modified when a user modifies a specific column. + Uses AI to infer causal, mathematical, and domain-specific relationships. + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/WhatIfSuggestRelatedRequestSchema' + responses: + 200: + description: Related column suggestions generated successfully + content: + application/json: + schema: + type: object + properties: + result: + $ref: '#/components/schemas/WhatIfSuggestRelatedResponseSchema' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 500: + $ref: '#/components/responses/500' + 502: + description: Error communicating with AI service + content: + application/json: + schema: + type: object + properties: + message: + type: string + security: + - jwt: [] + """ + try: + request_data = WhatIfSuggestRelatedRequestSchema().load(request.json) + except ValidationError as ex: + logger.warning("Invalid request data: %s", ex.messages) + return self.response_400(message=str(ex.messages)) + + try: + command = WhatIfSuggestRelatedCommand(request_data) + result = command.run() + return self.response( + 200, result=WhatIfSuggestRelatedResponseSchema().dump(result) + ) + except OpenRouterConfigError as ex: + logger.error("OpenRouter configuration error: %s", ex) + return self.response(500, message="AI suggestions are not configured") + except OpenRouterAPIError as ex: + logger.error("OpenRouter API error: %s", ex) + return self.response(502, message=str(ex)) + except ValueError as ex: + logger.warning("Invalid request: %s", ex) + return self.response_400(message=str(ex)) diff --git a/superset/what_if/commands/suggest_related.py b/superset/what_if/commands/suggest_related.py new file mode 100644 index 0000000000..7cd5334a70 --- /dev/null +++ b/superset/what_if/commands/suggest_related.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""What-If Analysis suggest related columns command using OpenRouter.""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +import httpx +from flask import current_app + +from superset.commands.base import BaseCommand +from superset.what_if.exceptions import OpenRouterAPIError, OpenRouterConfigError + +logger = logging.getLogger(__name__) + + +class WhatIfSuggestRelatedCommand(BaseCommand): + """Command to get AI suggestions for related column modifications.""" + + def __init__(self, data: dict[str, Any]) -> None: + self._data = data + + def run(self) -> dict[str, Any]: + self.validate() + return self._get_ai_suggestions() + + def validate(self) -> None: + api_key = current_app.config.get("OPENROUTER_API_KEY") + if not api_key: + raise OpenRouterConfigError("OPENROUTER_API_KEY not configured") + + if not self._data.get("selected_column"): + raise ValueError("selected_column is required") + + if not self._data.get("available_columns"): + raise ValueError("available_columns list is required") + + if self._data.get("user_multiplier") is None: + raise ValueError("user_multiplier is required") + + def _build_prompt(self) -> str: + selected_column = self._data["selected_column"] + user_multiplier = self._data["user_multiplier"] + available_columns = self._data["available_columns"] + dashboard_name = self._data.get("dashboard_name") or "Dashboard" + + pct_change = (user_multiplier - 1) * 100 + sign = "+" if pct_change >= 0 else "" + + # Build column list with descriptions + columns_text = [] + for col in available_columns: + # Skip the selected column - we don't want to suggest modifying it again + if col["column_name"] == selected_column: + continue + + col_desc = f"- **{col['column_name']}**" + if col.get("verbose_name"): + col_desc += f" ({col['verbose_name']})" + if col.get("description"): + col_desc += f": {col['description']}" + columns_text.append(col_desc) + + if not columns_text: + columns_text = ["No other columns available"] + + return f"""You are a business intelligence analyst helping with what-if scenario analysis. + +## Context +A user is working on a "{dashboard_name}" dashboard and wants to simulate the cascading effects of changing a metric. + +## User's Modification +The user is modifying **{selected_column}** by {sign}{pct_change:.1f}% + +## Other Available Columns +These are the other numeric columns available in the dashboard: +{chr(10).join(columns_text)} + +## Your Task +Analyze the relationships between these columns and suggest which OTHER columns should also be modified as a cascading effect of the user's change to {selected_column}. + +Consider: +1. **Causal relationships**: If column A affects column B in real business scenarios +2. **Mathematical relationships**: Derived metrics, ratios, calculated fields +3. **Domain knowledge**: Industry-standard relationships (e.g., increasing customers often increases orders and revenue) + +For each suggested column, provide: +- The appropriate multiplier (proportional, dampened, amplified, or inverse based on the relationship) +- A brief reasoning explaining the relationship (1 sentence) +- Your confidence level (high/medium/low) + +Guidelines: +- Only suggest columns that have a clear logical relationship to {selected_column} +- Be conservative - don't suggest modifications without good reasoning +- The multiplier should be realistic (e.g., if {selected_column} increases 10%, a related column might increase 5-15%, not 100%) +- If no clear relationships exist, return an empty suggestions array + +Respond in JSON format: +{{ + "suggested_modifications": [ + {{ + "column": "column_name", + "multiplier": 1.08, + "reasoning": "Brief explanation of the relationship", + "confidence": "high" + }} + ], + "explanation": "Overall summary of the analysis (1-2 sentences)" +}}""" + + def _get_ai_suggestions(self) -> dict[str, Any]: + api_key = current_app.config.get("OPENROUTER_API_KEY") + model = current_app.config.get("OPENROUTER_MODEL", "x-ai/grok-4.1-fast") + api_base = current_app.config.get( + "OPENROUTER_API_BASE", "https://openrouter.ai/api/v1" + ) + timeout = current_app.config.get("OPENROUTER_TIMEOUT", 30) + + prompt = self._build_prompt() + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "HTTP-Referer": current_app.config.get("WEBDRIVER_BASEURL", ""), + "X-Title": "Apache Superset What-If Analysis", + } + + payload = { + "model": model, + "messages": [ + { + "role": "system", + "content": ( + "You are a business intelligence analyst specializing in " + "data relationships and cascading effects analysis. " + "Respond only with valid JSON." + ), + }, + {"role": "user", "content": prompt}, + ], + "temperature": 0.3, + "max_tokens": 1000, + "response_format": {"type": "json_object"}, + } + + try: + with httpx.Client(timeout=timeout) as client: + response = client.post( + f"{api_base}/chat/completions", + headers=headers, + json=payload, + ) + response.raise_for_status() + + result = response.json() + content = result["choices"][0]["message"]["content"] + + # Parse the JSON response + parsed = json.loads(content) + + # Validate and normalize the response + suggestions = parsed.get("suggested_modifications", []) + validated_suggestions = [] + + for suggestion in suggestions: + # Ensure required fields exist + if all( + k in suggestion + for k in ["column", "multiplier", "reasoning", "confidence"] + ): + # Normalize confidence to lowercase + confidence = suggestion["confidence"].lower() + if confidence not in ("high", "medium", "low"): + confidence = "medium" + + validated_suggestions.append( + { + "column": suggestion["column"], + "multiplier": float(suggestion["multiplier"]), + "reasoning": suggestion["reasoning"], + "confidence": confidence, + } + ) + + return { + "suggested_modifications": validated_suggestions, + "explanation": parsed.get("explanation"), + } + + except httpx.HTTPStatusError as ex: + logger.error("OpenRouter API error: %s", ex.response.status_code) + raise OpenRouterAPIError( + f"OpenRouter API error: {ex.response.status_code}" + ) from ex + except json.JSONDecodeError as ex: + logger.error("Failed to parse AI response: %s", ex) + raise OpenRouterAPIError("Failed to parse AI response") from ex + except httpx.TimeoutException as ex: + logger.error("OpenRouter API timeout") + raise OpenRouterAPIError("AI service timed out") from ex + except Exception as ex: + logger.exception("Unexpected error calling OpenRouter") + raise OpenRouterAPIError(f"Unexpected error: {ex!s}") from ex diff --git a/superset/what_if/schemas.py b/superset/what_if/schemas.py index df0f4ad2b4..87212daede 100644 --- a/superset/what_if/schemas.py +++ b/superset/what_if/schemas.py @@ -161,3 +161,90 @@ class WhatIfInterpretResponseSchema(Schema): required=False, metadata={"description": "Raw AI response (only in debug mode)"}, ) + + +# Schemas for suggest_related endpoint + + +class AvailableColumnSchema(Schema): + """Schema for an available column with metadata.""" + + column_name = fields.String( + required=True, + metadata={"description": "Name of the column"}, + ) + description = fields.String( + required=False, + load_default=None, + metadata={"description": "Column description/documentation"}, + ) + verbose_name = fields.String( + required=False, + load_default=None, + metadata={"description": "Human-readable column name"}, + ) + datasource_id = fields.Integer( + required=True, + metadata={"description": "ID of the datasource containing this column"}, + ) + + +class WhatIfSuggestRelatedRequestSchema(Schema): + """Schema for suggest_related request.""" + + selected_column = fields.String( + required=True, + metadata={"description": "The column the user selected to modify"}, + ) + user_multiplier = fields.Float( + required=True, + metadata={ + "description": "The multiplier the user applied (e.g., 1.1 for +10%)" + }, + ) + available_columns = fields.List( + fields.Nested(AvailableColumnSchema), + required=True, + metadata={"description": "All numeric columns available in the dashboard"}, + ) + dashboard_name = fields.String( + required=False, + load_default=None, + metadata={"description": "Name of the dashboard for context"}, + ) + + +class SuggestedModificationSchema(Schema): + """Schema for a single AI-suggested modification.""" + + column = fields.String( + required=True, + metadata={"description": "Column name to modify"}, + ) + multiplier = fields.Float( + required=True, + metadata={"description": "Suggested multiplier for this column"}, + ) + reasoning = fields.String( + required=True, + metadata={"description": "Brief explanation of why this column is related"}, + ) + confidence = fields.String( + required=True, + metadata={"description": "Confidence level: high, medium, or low"}, + ) + + +class WhatIfSuggestRelatedResponseSchema(Schema): + """Schema for suggest_related response.""" + + suggested_modifications = fields.List( + fields.Nested(SuggestedModificationSchema), + required=True, + metadata={"description": "List of AI-suggested column modifications"}, + ) + explanation = fields.String( + required=False, + load_default=None, + metadata={"description": "Overall explanation of the relationship analysis"}, + )
