This is an automated email from the ASF dual-hosted git repository. kgabryje pushed a commit to branch what-if in repository https://gitbox.apache.org/repos/asf/superset.git
commit 4a1471aef5c1a5c9f7480f624d75cfa936abb549 Author: Kamil Gabryjelski <[email protected]> AuthorDate: Thu Dec 18 10:28:51 2025 +0100 AI insights --- .../components/WhatIfDrawer/WhatIfAIInsights.tsx | 272 ++++++++++++++++++ .../dashboard/components/WhatIfDrawer/index.tsx | 15 +- .../src/dashboard/components/WhatIfDrawer/types.ts | 73 +++++ .../components/WhatIfDrawer/useChartComparison.ts | 314 +++++++++++++++++++++ .../dashboard/components/WhatIfDrawer/whatIfApi.ts | 86 ++++++ superset-frontend/src/dashboard/types.ts | 18 ++ superset/config.py | 12 + superset/initialization/__init__.py | 2 + superset/what_if/__init__.py | 17 ++ superset/what_if/api.py | 118 ++++++++ superset/what_if/commands/__init__.py | 17 ++ superset/what_if/commands/interpret.py | 212 ++++++++++++++ superset/what_if/exceptions.py | 37 +++ superset/what_if/schemas.py | 163 +++++++++++ 14 files changed, 1353 insertions(+), 3 deletions(-) diff --git a/superset-frontend/src/dashboard/components/WhatIfDrawer/WhatIfAIInsights.tsx b/superset-frontend/src/dashboard/components/WhatIfDrawer/WhatIfAIInsights.tsx new file mode 100644 index 0000000000..8c276b7643 --- /dev/null +++ b/superset-frontend/src/dashboard/components/WhatIfDrawer/WhatIfAIInsights.tsx @@ -0,0 +1,272 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { useCallback, useEffect, useRef, useState } from 'react'; +import { useSelector } from 'react-redux'; +import { t } from '@superset-ui/core'; +import { styled, Alert } from '@apache-superset/core/ui'; +import { Icons } from '@superset-ui/core/components/Icons'; +import { Skeleton } from '@superset-ui/core/components/'; +import { RootState, WhatIfModification } from 'src/dashboard/types'; +import { fetchWhatIfInterpretation } from './whatIfApi'; +import { useChartComparison, useAllChartsLoaded } from './useChartComparison'; +import { + WhatIfAIStatus, + WhatIfInsight, + WhatIfInterpretResponse, +} from './types'; + +/** + * Create a stable key from modifications for comparison. + * This allows us to detect when modifications have meaningfully changed. + */ +function getModificationsKey(modifications: WhatIfModification[]): string { + return modifications + .map(m => `${m.column}:${m.multiplier}`) + .sort() + .join('|'); +} + +const InsightsContainer = styled.div` + display: flex; + flex-direction: column; + gap: ${({ theme }) => theme.sizeUnit * 3}px; + margin-top: ${({ theme }) => theme.sizeUnit * 4}px; + padding-top: ${({ theme }) => theme.sizeUnit * 4}px; + border-top: 1px solid ${({ theme }) => theme.colorBorderSecondary}; +`; + +const InsightsHeader = styled.div` + display: flex; + align-items: center; + gap: ${({ theme }) => theme.sizeUnit * 2}px; + font-weight: ${({ theme }) => theme.fontWeightStrong}; + color: ${({ theme }) => theme.colorText}; +`; + +const InsightCard = styled.div<{ insightType: string }>` + padding: ${({ theme }) => theme.sizeUnit * 3}px; + background-color: ${({ theme, insightType }) => { + switch (insightType) { + case 'observation': + return theme.colorInfoBg; + case 'implication': + return theme.colorWarningBg; + case 'recommendation': + return theme.colorSuccessBg; + default: + return theme.colorBgElevated; + } + }}; + border-radius: ${({ theme }) => theme.borderRadius}px; + border-left: 3px solid + ${({ theme, insightType }) => { + switch (insightType) { + case 'observation': + return theme.colorInfo; + case 'implication': + return theme.colorWarning; + case 'recommendation': + return theme.colorSuccess; + default: + return theme.colorBorder; + } + }}; +`; + +const InsightTitle = styled.div` + font-weight: ${({ theme }) => theme.fontWeightStrong}; + margin-bottom: ${({ theme }) => theme.sizeUnit}px; +`; + +const InsightDescription = styled.div` + color: ${({ theme }) => theme.colorTextSecondary}; + font-size: ${({ theme }) => theme.fontSizeSM}px; + line-height: 1.5; +`; + +const Summary = styled.div` + font-size: ${({ theme }) => theme.fontSize}px; + line-height: 1.6; + color: ${({ theme }) => theme.colorText}; + padding: ${({ theme }) => theme.sizeUnit * 3}px; + background-color: ${({ theme }) => theme.colorBgElevated}; + border-radius: ${({ theme }) => theme.borderRadius}px; +`; + +interface WhatIfAIInsightsProps { + affectedChartIds: number[]; +} + +const WhatIfAIInsights = ({ affectedChartIds }: WhatIfAIInsightsProps) => { + const [status, setStatus] = useState<WhatIfAIStatus>('idle'); + const [response, setResponse] = useState<WhatIfInterpretResponse | null>( + null, + ); + const [error, setError] = useState<string | null>(null); + + const whatIfModifications = useSelector<RootState, WhatIfModification[]>( + state => state.dashboardState.whatIfModifications ?? [], + ); + + const dashboardTitle = useSelector<RootState, string>( + // @ts-ignore + state => state.dashboardInfo?.dashboard_title || 'Dashboard', + ); + + const chartComparisons = useChartComparison(affectedChartIds); + const allChartsLoaded = useAllChartsLoaded(affectedChartIds); + + // Track modification changes to reset status when user adjusts the slider + const modificationsKey = getModificationsKey(whatIfModifications); + const prevModificationsKeyRef = useRef<string>(modificationsKey); + + // Debug logging + console.log('[WhatIfAIInsights] State:', { + affectedChartIds, + allChartsLoaded, + chartComparisonsLength: chartComparisons.length, + whatIfModificationsLength: whatIfModifications.length, + status, + modificationsKey, + willTriggerFetch: + whatIfModifications.length > 0 && + chartComparisons.length > 0 && + allChartsLoaded && + status === 'idle', + }); + + // Reset status when modifications change (user adjusts the slider) + useEffect(() => { + if ( + modificationsKey !== prevModificationsKeyRef.current && + whatIfModifications.length > 0 + ) { + console.log( + '[WhatIfAIInsights] Modifications changed, resetting status to idle', + ); + // eslint-disable-next-line react-hooks/set-state-in-effect -- Intentional: resetting state when modifications change + setStatus('idle'); + setResponse(null); + prevModificationsKeyRef.current = modificationsKey; + } + }, [modificationsKey, whatIfModifications.length]); + + const fetchInsights = useCallback(async () => { + if (whatIfModifications.length === 0 || chartComparisons.length === 0) { + return; + } + + setStatus('loading'); + setError(null); + + try { + const result = await fetchWhatIfInterpretation({ + modifications: whatIfModifications, + charts: chartComparisons, + dashboardName: dashboardTitle, + }); + setResponse(result); + setStatus('success'); + } catch (err) { + setError( + err instanceof Error + ? err.message + : t('Failed to generate AI insights'), + ); + setStatus('error'); + } + }, [whatIfModifications, chartComparisons, dashboardTitle]); + + // Automatically fetch insights when all affected charts have finished loading. + // We wait for allChartsLoaded to prevent race conditions where we'd send + // stale data before charts have re-queried with the what-if modifications. + // The setState call here is intentional - we're synchronizing with Redux state changes. + useEffect(() => { + if ( + whatIfModifications.length > 0 && + chartComparisons.length > 0 && + allChartsLoaded && + status === 'idle' + ) { + // eslint-disable-next-line react-hooks/set-state-in-effect -- Intentional: triggering async fetch based on Redux state + fetchInsights(); + } + }, [ + whatIfModifications, + chartComparisons, + allChartsLoaded, + status, + fetchInsights, + ]); + + // Reset state when modifications are cleared. + // The setState calls here are intentional - we're resetting local state when Redux state changes. + useEffect(() => { + if (whatIfModifications.length === 0) { + // eslint-disable-next-line react-hooks/set-state-in-effect -- Intentional: resetting state when Redux modifications cleared + setStatus('idle'); + setResponse(null); + setError(null); + } + }, [whatIfModifications]); + + if (whatIfModifications.length === 0) { + return null; + } + + return ( + <InsightsContainer data-test="what-if-ai-insights"> + <InsightsHeader> + <Icons.BulbOutlined iconSize="m" /> + {t('AI Insights')} + </InsightsHeader> + + {status === 'loading' && <Skeleton active paragraph={{ rows: 3 }} />} + + {status === 'error' && ( + <Alert + type="error" + message={t('Failed to generate insights')} + description={error} + showIcon + /> + )} + + {status === 'success' && response && ( + <> + <Summary>{response.summary}</Summary> + + {response.insights.map((insight: WhatIfInsight, index: number) => ( + <InsightCard key={index} insightType={insight.type}> + <InsightTitle>{insight.title}</InsightTitle> + <InsightDescription>{insight.description}</InsightDescription> + </InsightCard> + ))} + </> + )} + + {status === 'idle' && !allChartsLoaded && ( + <Skeleton active paragraph={{ rows: 2 }} /> + )} + </InsightsContainer> + ); +}; + +export default WhatIfAIInsights; diff --git a/superset-frontend/src/dashboard/components/WhatIfDrawer/index.tsx b/superset-frontend/src/dashboard/components/WhatIfDrawer/index.tsx index 68a6163a68..f25ae093c2 100644 --- a/superset-frontend/src/dashboard/components/WhatIfDrawer/index.tsx +++ b/superset-frontend/src/dashboard/components/WhatIfDrawer/index.tsx @@ -30,6 +30,7 @@ import { } from 'src/components/Chart/chartAction'; import { getNumericColumnsForDashboard } from 'src/dashboard/util/whatIf'; import { RootState, Slice, WhatIfColumn } from 'src/dashboard/types'; +import WhatIfAIInsights from './WhatIfAIInsights'; export const WHAT_IF_PANEL_WIDTH = 300; @@ -127,6 +128,7 @@ const WhatIfPanel = ({ onClose, topOffset }: WhatIfPanelProps) => { const [selectedColumn, setSelectedColumn] = useState<string | null>(null); const [sliderValue, setSliderValue] = useState<number>(SLIDER_DEFAULT); + const [affectedChartIds, setAffectedChartIds] = useState<number[]>([]); const slices = useSelector( (state: RootState) => state.sliceEntities.slices as { [id: number]: Slice }, @@ -170,10 +172,13 @@ const WhatIfPanel = ({ onClose, topOffset }: WhatIfPanelProps) => { const multiplier = 1 + sliderValue / 100; // Get affected chart IDs - const affectedChartIds = columnToChartIds.get(selectedColumn) || []; + const chartIds = columnToChartIds.get(selectedColumn) || []; + + // Save affected chart IDs for AI insights + setAffectedChartIds(chartIds); // Save original chart data before applying what-if modifications - affectedChartIds.forEach(chartId => { + chartIds.forEach(chartId => { dispatch(saveOriginalChartData(chartId)); }); @@ -188,7 +193,7 @@ const WhatIfPanel = ({ onClose, topOffset }: WhatIfPanelProps) => { ); // Trigger queries for all charts that use the selected column - affectedChartIds.forEach(chartId => { + chartIds.forEach(chartId => { dispatch(triggerQuery(true, chartId)); }); }, [dispatch, selectedColumn, sliderValue, columnToChartIds]); @@ -266,6 +271,10 @@ const WhatIfPanel = ({ onClose, topOffset }: WhatIfPanelProps) => { )} showIcon /> + + {affectedChartIds.length > 0 && ( + <WhatIfAIInsights affectedChartIds={affectedChartIds} /> + )} </PanelContent> </PanelContainer> ); diff --git a/superset-frontend/src/dashboard/components/WhatIfDrawer/types.ts b/superset-frontend/src/dashboard/components/WhatIfDrawer/types.ts new file mode 100644 index 0000000000..029f336fad --- /dev/null +++ b/superset-frontend/src/dashboard/components/WhatIfDrawer/types.ts @@ -0,0 +1,73 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export interface ChartMetricComparison { + metricName: string; + originalValue: number; + modifiedValue: number; + percentageChange: number; +} + +export interface ChartComparison { + chartId: number; + chartName: string; + chartType: string; + metrics: ChartMetricComparison[]; +} + +export type WhatIfFilterOperator = + | '==' + | '!=' + | '>' + | '<' + | '>=' + | '<=' + | 'IN' + | 'NOT IN' + | 'TEMPORAL_RANGE'; + +export interface WhatIfFilter { + col: string; + op: WhatIfFilterOperator; + val: string | number | boolean | Array<string | number>; +} + +export interface WhatIfInterpretRequest { + modifications: Array<{ + column: string; + multiplier: number; + filters?: WhatIfFilter[]; + }>; + charts: ChartComparison[]; + dashboardName?: string; +} + +export interface WhatIfInsight { + title: string; + description: string; + type: 'observation' | 'implication' | 'recommendation'; +} + +export interface WhatIfInterpretResponse { + summary: string; + insights: WhatIfInsight[]; + rawResponse?: string; +} + +export type WhatIfAIStatus = 'idle' | 'loading' | 'success' | 'error'; diff --git a/superset-frontend/src/dashboard/components/WhatIfDrawer/useChartComparison.ts b/superset-frontend/src/dashboard/components/WhatIfDrawer/useChartComparison.ts new file mode 100644 index 0000000000..f3c2d035a1 --- /dev/null +++ b/superset-frontend/src/dashboard/components/WhatIfDrawer/useChartComparison.ts @@ -0,0 +1,314 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { useCallback, useMemo } from 'react'; +import { shallowEqual, useSelector } from 'react-redux'; +import { QueryData } from '@superset-ui/core'; +import { + ActiveTabs, + DashboardLayout, + RootState, + Slice, +} from 'src/dashboard/types'; +import { ChartComparison, ChartMetricComparison } from './types'; +import { CHART_TYPE, TAB_TYPE } from 'src/dashboard/util/componentTypes'; + +interface ChartStateWithOriginal { + chartStatus?: string; + queriesResponse?: QueryData[] | null; + originalQueriesResponse?: QueryData[] | null; +} + +interface QueryResponse { + data?: Array<Record<string, unknown>>; + colnames?: string[]; +} + +function extractMetricValue( + data: Array<Record<string, unknown>> | undefined, + metricName: string, +): number | null { + if (!data || data.length === 0) return null; + + // Sum all values for the metric across rows + let total = 0; + let found = false; + + for (const row of data) { + if (metricName in row) { + const value = row[metricName]; + if (typeof value === 'number' && !Number.isNaN(value)) { + total += value; + found = true; + } + } + } + + return found ? total : null; +} + +function isNumericColumn( + data: Array<Record<string, unknown>> | undefined, + colName: string, +): boolean { + if (!data || data.length === 0) return false; + + for (const row of data) { + if (colName in row) { + const value = row[colName]; + return typeof value === 'number'; + } + } + return false; +} + +/** + * Hook to get a function that checks if a chart is in an active tab. + * A chart is considered visible if: + * 1. It has no tab parents (not inside any tab) + * 2. All of its tab parents are in the active tabs list + */ +export function useIsChartInActiveTab() { + const dashboardLayout = useSelector<RootState, DashboardLayout>( + state => state.dashboardLayout?.present, + ); + const activeTabs = useSelector<RootState, ActiveTabs>( + state => state.dashboardState?.activeTabs, + ); + + const layoutChartItems = useMemo( + () => + Object.values(dashboardLayout || {}).filter( + item => item.type === CHART_TYPE, + ), + [dashboardLayout], + ); + + return useCallback( + (chartId: number): boolean => { + const chartLayoutItem = layoutChartItems.find( + layoutItem => layoutItem.meta?.chartId === chartId, + ); + const tabParents = chartLayoutItem?.parents?.filter( + (parent: string) => dashboardLayout[parent]?.type === TAB_TYPE, + ); + + // Chart is visible if it has no tab parents or all tab parents are active + return ( + !tabParents || + tabParents.length === 0 || + tabParents.every(tab => activeTabs?.includes(tab)) + ); + }, + [dashboardLayout, layoutChartItems, activeTabs], + ); +} + +/** + * Filter chart IDs to only include those in active tabs. + */ +export function useChartsInActiveTabs(chartIds: number[]): number[] { + const isChartInActiveTab = useIsChartInActiveTab(); + + return useMemo(() => { + const visibleCharts = chartIds.filter(isChartInActiveTab); + console.log('[useChartsInActiveTabs] Visible charts:', visibleCharts); + return visibleCharts; + }, [chartIds, isChartInActiveTab]); +} + +interface ChartComparisonData { + chartStatus?: string; + originalData?: Array<Record<string, unknown>>; + modifiedData?: Array<Record<string, unknown>>; + colnames?: string[]; +} + +/** + * Selector that extracts only the comparison-relevant data for specific chart IDs. + * This avoids re-renders when unrelated chart data changes. + */ +function useChartComparisonData( + chartIds: number[], +): Record<number, ChartComparisonData> { + return useSelector((state: RootState) => { + const result: Record<number, ChartComparisonData> = {}; + for (const chartId of chartIds) { + const chartState = state.charts[chartId] as + | ChartStateWithOriginal + | undefined; + if (chartState) { + const originalResponse = chartState.originalQueriesResponse?.[0] as + | QueryResponse + | undefined; + const modifiedResponse = chartState.queriesResponse?.[0] as + | QueryResponse + | undefined; + result[chartId] = { + chartStatus: chartState.chartStatus, + originalData: originalResponse?.data, + modifiedData: modifiedResponse?.data, + colnames: modifiedResponse?.colnames, + }; + } + } + return result; + }, shallowEqual); +} + +/** + * Selector that extracts chart display names and viz types for specific chart IDs. + * Uses sliceNameOverride from dashboard layout if available, otherwise falls back to slice_name. + */ +function useChartDisplayData( + chartIds: number[], +): Record<number, { displayName: string; viz_type: string }> { + return useSelector((state: RootState) => { + const slices = state.sliceEntities.slices as { [id: number]: Slice }; + const dashboardLayout = state.dashboardLayout?.present; + const result: Record<number, { displayName: string; viz_type: string }> = + {}; + + // Build a map of chartId -> sliceNameOverride from dashboard layout + const nameOverrides: Record<number, string | undefined> = {}; + if (dashboardLayout) { + for (const item of Object.values(dashboardLayout)) { + if (item.type === CHART_TYPE && item.meta?.chartId) { + nameOverrides[item.meta.chartId] = item.meta.sliceNameOverride; + } + } + } + + for (const chartId of chartIds) { + const slice = slices[chartId]; + if (slice) { + result[chartId] = { + displayName: nameOverrides[chartId] || slice.slice_name, + viz_type: slice.viz_type, + }; + } + } + return result; + }, shallowEqual); +} + +export function useChartComparison( + affectedChartIds: number[], +): ChartComparison[] { + const visibleChartIds = useChartsInActiveTabs(affectedChartIds); + const chartData = useChartComparisonData(visibleChartIds); + const chartDisplayData = useChartDisplayData(visibleChartIds); + + return useMemo(() => { + const comparisons: ChartComparison[] = []; + + console.log( + '[useChartComparison] Processing visible charts:', + visibleChartIds, + ); + + for (const chartId of visibleChartIds) { + const chartState = chartData[chartId]; + const displayData = chartDisplayData[chartId]; + + if (!chartState || !displayData) continue; + + const originalData = chartState.originalData; + const modifiedData = chartState.modifiedData; + + if (!originalData || !modifiedData) continue; + + // Get column names from the response + const colnames = chartState.colnames || []; + const metrics: ChartMetricComparison[] = []; + + for (const metricName of colnames) { + // Only include numeric columns + if (!isNumericColumn(modifiedData, metricName)) continue; + + const originalValue = extractMetricValue(originalData, metricName); + const modifiedValue = extractMetricValue(modifiedData, metricName); + + if ( + originalValue !== null && + modifiedValue !== null && + originalValue !== 0 + ) { + const percentageChange = + ((modifiedValue - originalValue) / Math.abs(originalValue)) * 100; + + metrics.push({ + metricName, + originalValue, + modifiedValue, + percentageChange, + }); + } + } + + if (metrics.length > 0) { + comparisons.push({ + chartId, + chartName: displayData.displayName, + chartType: displayData.viz_type, + metrics, + }); + } + } + + return comparisons; + }, [chartData, chartDisplayData, visibleChartIds]); +} + +/** + * Selector that extracts only loading statuses for specific chart IDs. + */ +function useChartLoadingStatuses( + chartIds: number[], +): Record<number, string | undefined> { + return useSelector((state: RootState) => { + const result: Record<number, string | undefined> = {}; + for (const chartId of chartIds) { + const chartState = state.charts[chartId] as + | ChartStateWithOriginal + | undefined; + result[chartId] = chartState?.chartStatus; + } + return result; + }, shallowEqual); +} + +/** + * Check if all affected charts (in active tabs) have finished loading. + * Returns true if no visible chart is currently in 'loading' status. + */ +export function useAllChartsLoaded(chartIds: number[]): boolean { + const visibleChartIds = useChartsInActiveTabs(chartIds); + const chartStatuses = useChartLoadingStatuses(visibleChartIds); + + return useMemo(() => { + const statuses = visibleChartIds.map(id => ({ + id, + status: chartStatuses[id], + })); + console.log('[useAllChartsLoaded] Chart statuses:', statuses); + + return visibleChartIds.every(id => chartStatuses[id] !== 'loading'); + }, [chartStatuses, visibleChartIds]); +} diff --git a/superset-frontend/src/dashboard/components/WhatIfDrawer/whatIfApi.ts b/superset-frontend/src/dashboard/components/WhatIfDrawer/whatIfApi.ts new file mode 100644 index 0000000000..5b03e55f97 --- /dev/null +++ b/superset-frontend/src/dashboard/components/WhatIfDrawer/whatIfApi.ts @@ -0,0 +1,86 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { SupersetClient } from '@superset-ui/core'; +import { + WhatIfInterpretRequest, + WhatIfInterpretResponse, + ChartComparison, + WhatIfFilter, +} from './types'; + +interface ApiResponse { + result: { + summary: string; + insights: Array<{ + title: string; + description: string; + type: string; + }>; + raw_response?: string; + }; +} + +export async function fetchWhatIfInterpretation( + request: WhatIfInterpretRequest, +): Promise<WhatIfInterpretResponse> { + const response = await SupersetClient.post({ + endpoint: '/api/v1/what_if/interpret', + jsonPayload: { + modifications: request.modifications.map(mod => ({ + column: mod.column, + multiplier: mod.multiplier, + ...(mod.filters && mod.filters.length > 0 + ? { + filters: mod.filters.map((f: WhatIfFilter) => ({ + col: f.col, + op: f.op, + val: f.val, + })), + } + : {}), + })), + charts: request.charts.map((chart: ChartComparison) => ({ + chart_id: chart.chartId, + chart_name: chart.chartName, + chart_type: chart.chartType, + metrics: chart.metrics.map(m => ({ + metric_name: m.metricName, + original_value: m.originalValue, + modified_value: m.modifiedValue, + percentage_change: m.percentageChange, + })), + })), + dashboard_name: request.dashboardName, + }, + }); + + const data = response.json as ApiResponse; + const { result } = data; + + return { + summary: result.summary, + insights: result.insights.map(insight => ({ + title: insight.title, + description: insight.description, + type: insight.type as 'observation' | 'implication' | 'recommendation', + })), + rawResponse: result.raw_response, + }; +} diff --git a/superset-frontend/src/dashboard/types.ts b/superset-frontend/src/dashboard/types.ts index 71a534f3eb..553b26f1a6 100644 --- a/superset-frontend/src/dashboard/types.ts +++ b/superset-frontend/src/dashboard/types.ts @@ -285,9 +285,27 @@ export type Slice = { /** * What-If Analysis types */ +export type WhatIfFilterOperator = + | '==' + | '!=' + | '>' + | '<' + | '>=' + | '<=' + | 'IN' + | 'NOT IN' + | 'TEMPORAL_RANGE'; + +export interface WhatIfFilter { + col: string; + op: WhatIfFilterOperator; + val: string | number | boolean | Array<string | number>; +} + export interface WhatIfModification { column: string; multiplier: number; + filters?: WhatIfFilter[]; } export interface WhatIfColumn { diff --git a/superset/config.py b/superset/config.py index a33294ed65..a73b80008b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1145,6 +1145,18 @@ QUERY_LOGGER = None # Set this API key to enable Mapbox visualizations MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "") +# --------------------------------------------------- +# What-If AI Interpretation Configuration +# --------------------------------------------------- +# API key for OpenRouter (required for AI interpretation of what-if analysis) +OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY") +# Model to use for interpretation (default: x-ai/grok-4.1-fast) +OPENROUTER_MODEL = "x-ai/grok-4.1-fast" +# API base URL for OpenRouter +OPENROUTER_API_BASE = "https://openrouter.ai/api/v1" +# Request timeout in seconds +OPENROUTER_TIMEOUT = 30 + # Maximum number of rows returned for any analytical database query SQL_MAX_ROW = 100000 diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 4f4fd361a4..8d29cd471f 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -223,6 +223,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods from superset.views.user_registrations import UserRegistrationsView from superset.views.users.api import CurrentUserRestApi, UserRestApi from superset.views.users_list import UsersListView + from superset.what_if.api import WhatIfRestApi set_app_error_handlers(self.superset_app) self.register_request_handlers() @@ -266,6 +267,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods appbuilder.add_api(RLSRestApi) appbuilder.add_api(SavedQueryRestApi) appbuilder.add_api(TagRestApi) + appbuilder.add_api(WhatIfRestApi) appbuilder.add_api(SqlLabRestApi) appbuilder.add_api(SqlLabPermalinkRestApi) appbuilder.add_api(LogRestApi) diff --git a/superset/what_if/__init__.py b/superset/what_if/__init__.py new file mode 100644 index 0000000000..8ee09d2dcc --- /dev/null +++ b/superset/what_if/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""What-If Analysis module for AI-powered interpretation of scenario analysis.""" diff --git a/superset/what_if/api.py b/superset/what_if/api.py new file mode 100644 index 0000000000..62173022ea --- /dev/null +++ b/superset/what_if/api.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""What-If Analysis REST API.""" + +from __future__ import annotations + +import logging + +from flask import request, Response +from flask_appbuilder.api import expose, protect, safe +from marshmallow import ValidationError + +from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP +from superset.extensions import event_logger +from superset.views.base_api import BaseSupersetApi, statsd_metrics +from superset.what_if.commands.interpret import WhatIfInterpretCommand +from superset.what_if.exceptions import OpenRouterAPIError, OpenRouterConfigError +from superset.what_if.schemas import ( + WhatIfInterpretRequestSchema, + WhatIfInterpretResponseSchema, +) + +logger = logging.getLogger(__name__) + + +class WhatIfRestApi(BaseSupersetApi): + """REST API for What-If Analysis features.""" + + resource_name = "what_if" + allow_browser_login = True + openapi_spec_tag = "What-If Analysis" + + # Use Dashboard permissions since what-if is a dashboard feature + class_permission_name = "Dashboard" + method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP + + @expose("/interpret", methods=("POST",)) + @event_logger.log_this + @protect() + @safe + @statsd_metrics + def interpret(self) -> Response: + """Generate AI interpretation of what-if analysis results. + --- + post: + summary: Generate AI interpretation of what-if changes + description: >- + Sends what-if modification data to an LLM for business interpretation. + Returns a summary and actionable insights. + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/WhatIfInterpretRequestSchema' + responses: + 200: + description: AI interpretation generated successfully + content: + application/json: + schema: + type: object + properties: + result: + $ref: '#/components/schemas/WhatIfInterpretResponseSchema' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 500: + $ref: '#/components/responses/500' + 502: + description: Error communicating with AI service + content: + application/json: + schema: + type: object + properties: + message: + type: string + security: + - jwt: [] + """ + try: + request_data = WhatIfInterpretRequestSchema().load(request.json) + except ValidationError as ex: + logger.warning("Invalid request data: %s", ex.messages) + return self.response_400(message=str(ex.messages)) + + try: + command = WhatIfInterpretCommand(request_data) + result = command.run() + return self.response( + 200, result=WhatIfInterpretResponseSchema().dump(result) + ) + except OpenRouterConfigError as ex: + logger.error("OpenRouter configuration error: %s", ex) + return self.response(500, message="AI interpretation is not configured") + except OpenRouterAPIError as ex: + logger.error("OpenRouter API error: %s", ex) + return self.response(502, message=str(ex)) + except ValueError as ex: + logger.warning("Invalid request: %s", ex) + return self.response_400(message=str(ex)) diff --git a/superset/what_if/commands/__init__.py b/superset/what_if/commands/__init__.py new file mode 100644 index 0000000000..8135fd2ad3 --- /dev/null +++ b/superset/what_if/commands/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""What-If Analysis commands.""" diff --git a/superset/what_if/commands/interpret.py b/superset/what_if/commands/interpret.py new file mode 100644 index 0000000000..1611c83d28 --- /dev/null +++ b/superset/what_if/commands/interpret.py @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""What-If Analysis interpretation command using OpenRouter.""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +import httpx +from flask import current_app + +from superset.commands.base import BaseCommand +from superset.what_if.exceptions import OpenRouterAPIError, OpenRouterConfigError + +logger = logging.getLogger(__name__) + + +class WhatIfInterpretCommand(BaseCommand): + """Command to get AI interpretation of what-if analysis results.""" + + def __init__(self, data: dict[str, Any]) -> None: + self._data = data + + def run(self) -> dict[str, Any]: + self.validate() + return self._get_ai_interpretation() + + def validate(self) -> None: + api_key = current_app.config.get("OPENROUTER_API_KEY") + if not api_key: + raise OpenRouterConfigError("OPENROUTER_API_KEY not configured") + + if not self._data.get("modifications"): + raise ValueError("At least one modification is required") + + if not self._data.get("charts"): + raise ValueError("At least one chart comparison is required") + + def _format_filter(self, flt: dict[str, Any]) -> str: + """Format a single filter for display in the prompt.""" + col = flt.get("col", "") + op = flt.get("op", "") + val = flt.get("val", "") + + # Format the value based on type + if isinstance(val, list): + val_str = ", ".join(str(v) for v in val) + return f"{col} {op} [{val_str}]" + if isinstance(val, str) and op == "TEMPORAL_RANGE": + return f"{col} in time range '{val}'" + return f"{col} {op} {val}" + + def _build_prompt(self) -> str: + modifications = self._data["modifications"] + charts = self._data["charts"] + dashboard_name = self._data.get("dashboard_name") or "Dashboard" + + # Build modification description + mod_descriptions = [] + for mod in modifications: + pct_change = (mod["multiplier"] - 1) * 100 + sign = "+" if pct_change >= 0 else "" + base_desc = f"- {mod['column']}: {sign}{pct_change:.1f}%" + + # Add filter conditions if present + filters = mod.get("filters") or [] + if filters: + filter_strs = [self._format_filter(f) for f in filters] + filter_desc = " AND ".join(filter_strs) + base_desc += f" (only where {filter_desc})" + + mod_descriptions.append(base_desc) + + modifications_text = "\n".join(mod_descriptions) + + # Build chart impact summary + chart_summaries = [] + for chart in charts: + metrics_text = [] + for metric in chart["metrics"]: + sign = "+" if metric["percentage_change"] >= 0 else "" + metrics_text.append( + f" - {metric['metric_name']}: " + f"{metric['original_value']:,.2f} -> {metric['modified_value']:,.2f} " + f"({sign}{metric['percentage_change']:.1f}%)" + ) + chart_summaries.append( + f"**{chart['chart_name']}** ({chart['chart_type']}):\n" + + "\n".join(metrics_text) + ) + + charts_text = "\n\n".join(chart_summaries) + + return f"""You are a business intelligence analyst. A user is performing a what-if analysis on their "{dashboard_name}" dashboard. + +## Scenario +The user modified the following column(s): +{modifications_text} + +## Impact on Charts +{charts_text} + +## Your Task +Analyze this what-if scenario and provide: + +1. **Summary**: A 1-2 sentence executive summary of the overall impact. + +2. **Key Observations**: 2-3 specific observations about how the changes affected different metrics. + +3. **Business Implications**: What does this mean for the business? Consider: + - Revenue/cost implications + - Operational efficiency + - Risk factors + +4. **Recommendations**: 1-2 actionable recommendations based on this analysis. + +Please be concise, specific, and focus on business value. Use the actual numbers from the data. + +Respond in JSON format: +{{ + "summary": "...", + "insights": [ + {{"title": "...", "description": "...", "type": "observation"}}, + {{"title": "...", "description": "...", "type": "implication"}}, + {{"title": "...", "description": "...", "type": "recommendation"}} + ] +}}""" + + def _get_ai_interpretation(self) -> dict[str, Any]: + api_key = current_app.config.get("OPENROUTER_API_KEY") + model = current_app.config.get("OPENROUTER_MODEL", "x-ai/grok-4.1-fast") + api_base = current_app.config.get( + "OPENROUTER_API_BASE", "https://openrouter.ai/api/v1" + ) + timeout = current_app.config.get("OPENROUTER_TIMEOUT", 30) + + prompt = self._build_prompt() + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "HTTP-Referer": current_app.config.get("WEBDRIVER_BASEURL", ""), + "X-Title": "Apache Superset What-If Analysis", + } + + payload = { + "model": model, + "messages": [ + { + "role": "system", + "content": ( + "You are a business intelligence analyst. " + "Respond only with valid JSON." + ), + }, + {"role": "user", "content": prompt}, + ], + "temperature": 0.3, + "max_tokens": 1000, + "response_format": {"type": "json_object"}, + } + + try: + with httpx.Client(timeout=timeout) as client: + response = client.post( + f"{api_base}/chat/completions", + headers=headers, + json=payload, + ) + response.raise_for_status() + + result = response.json() + content = result["choices"][0]["message"]["content"] + + # Parse the JSON response + parsed = json.loads(content) + return { + "summary": parsed.get("summary", ""), + "insights": parsed.get("insights", []), + "raw_response": content if current_app.debug else None, + } + + except httpx.HTTPStatusError as ex: + logger.error("OpenRouter API error: %s", ex.response.status_code) + raise OpenRouterAPIError( + f"OpenRouter API error: {ex.response.status_code}" + ) from ex + except json.JSONDecodeError as ex: + logger.error("Failed to parse AI response: %s", ex) + raise OpenRouterAPIError("Failed to parse AI response") from ex + except httpx.TimeoutException as ex: + logger.error("OpenRouter API timeout") + raise OpenRouterAPIError("AI service timed out") from ex + except Exception as ex: + logger.exception("Unexpected error calling OpenRouter") + raise OpenRouterAPIError(f"Unexpected error: {ex!s}") from ex diff --git a/superset/what_if/exceptions.py b/superset/what_if/exceptions.py new file mode 100644 index 0000000000..666423f29c --- /dev/null +++ b/superset/what_if/exceptions.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""What-If Analysis exceptions.""" + +from superset.exceptions import SupersetException + + +class WhatIfException(SupersetException): + """Base exception for What-If Analysis errors.""" + + +class OpenRouterConfigError(WhatIfException): + """Raised when OpenRouter API is not configured.""" + + status = 500 + message = "OpenRouter API is not configured" + + +class OpenRouterAPIError(WhatIfException): + """Raised when there is an error communicating with OpenRouter API.""" + + status = 502 + message = "Error communicating with OpenRouter API" diff --git a/superset/what_if/schemas.py b/superset/what_if/schemas.py new file mode 100644 index 0000000000..df0f4ad2b4 --- /dev/null +++ b/superset/what_if/schemas.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""What-If Analysis schemas for request/response validation.""" + +from marshmallow import fields, Schema + + +class ChartMetricComparisonSchema(Schema): + """Schema for a single metric comparison within a chart.""" + + metric_name = fields.String( + required=True, + metadata={"description": "Name of the metric being compared"}, + ) + original_value = fields.Float( + required=True, + metadata={"description": "Original metric value before modification"}, + ) + modified_value = fields.Float( + required=True, + metadata={"description": "Modified metric value after what-if applied"}, + ) + percentage_change = fields.Float( + required=True, + metadata={"description": "Percentage change from original to modified"}, + ) + + +class ChartComparisonSchema(Schema): + """Schema for chart-level comparison data.""" + + chart_id = fields.Integer( + required=True, + metadata={"description": "Unique identifier for the chart"}, + ) + chart_name = fields.String( + required=True, + metadata={"description": "Display name of the chart"}, + ) + chart_type = fields.String( + required=True, + metadata={"description": "Visualization type (e.g., bar, line, pie)"}, + ) + metrics = fields.List( + fields.Nested(ChartMetricComparisonSchema), + required=True, + metadata={"description": "List of metric comparisons for this chart"}, + ) + + +class WhatIfFilterSchema(Schema): + """Schema for a what-if filter condition.""" + + col = fields.String( + required=True, + metadata={"description": "Column name to filter on"}, + ) + op = fields.String( + required=True, + metadata={ + "description": "Filter operator: ==, !=, >, <, >=, <=, IN, NOT IN, TEMPORAL_RANGE" + }, + ) + val = fields.Raw( + required=True, + metadata={ + "description": "Filter value (string, number, or array for IN/NOT IN operators)" + }, + ) + + +class ModificationSchema(Schema): + """Schema for a single what-if modification.""" + + column = fields.String( + required=True, + metadata={"description": "Column name being modified"}, + ) + multiplier = fields.Float( + required=True, + metadata={ + "description": "Multiplier applied to the column (e.g., 1.1 for +10%)" + }, + ) + filters = fields.List( + fields.Nested(WhatIfFilterSchema), + required=False, + load_default=None, + metadata={ + "description": "Optional filters to apply modification conditionally" + }, + ) + + +class WhatIfInterpretRequestSchema(Schema): + """Schema for what-if interpretation request.""" + + modifications = fields.List( + fields.Nested(ModificationSchema), + required=True, + metadata={"description": "List of column modifications applied"}, + ) + charts = fields.List( + fields.Nested(ChartComparisonSchema), + required=True, + metadata={"description": "List of charts with comparison data"}, + ) + dashboard_name = fields.String( + required=False, + load_default=None, + metadata={"description": "Name of the dashboard for context"}, + ) + + +class InsightSchema(Schema): + """Schema for a single AI-generated insight.""" + + title = fields.String( + required=True, + metadata={"description": "Short title summarizing the insight"}, + ) + description = fields.String( + required=True, + metadata={"description": "Detailed description of the insight"}, + ) + type = fields.String( + required=True, + metadata={ + "description": "Type of insight: observation, implication, or recommendation" + }, + ) + + +class WhatIfInterpretResponseSchema(Schema): + """Schema for what-if interpretation response.""" + + summary = fields.String( + required=True, + metadata={"description": "Executive summary of the what-if analysis"}, + ) + insights = fields.List( + fields.Nested(InsightSchema), + required=True, + metadata={"description": "List of AI-generated insights"}, + ) + raw_response = fields.String( + required=False, + metadata={"description": "Raw AI response (only in debug mode)"}, + )
