This is an automated email from the ASF dual-hosted git repository. arivero pushed a commit to branch table-time-comparison in repository https://gitbox.apache.org/repos/asf/superset.git
commit c3d04b3fa870cd8338edbcd4888bf491e5fac57d Author: Antonio Rivero <[email protected]> AuthorDate: Fri Mar 1 14:41:35 2024 +0100 Table with Time Comparison: - Using one single query with some new properties added to QueryObject so we generate the comparison data instead of two queries - Use joins when generating the comparison query - Add time comparison control to Table chart - Render Time comparison metrics in Table chart - Render header with column name on top of each group of 4 metrics columns - Modify useSticky to consider multiple rows of headers when computing the columns widths - Add tests for new query building function --- .../superset-ui-core/src/query/types/Query.ts | 10 + .../src/query/types/QueryResponse.ts | 1 + .../plugin-chart-table/src/DataTable/DataTable.tsx | 45 ++- .../src/DataTable/hooks/useSticky.tsx | 4 +- .../plugins/plugin-chart-table/src/TableChart.tsx | 42 +++ .../plugins/plugin-chart-table/src/buildQuery.ts | 58 +++- .../plugin-chart-table/src/controlPanel.tsx | 72 ++++ .../plugin-chart-table/src/transformProps.ts | 165 ++++++++- .../plugins/plugin-chart-table/src/types.ts | 2 + .../plugin-chart-table/src/utils/isEqualColumns.ts | 3 +- .../plugins/plugin-chart-table/test/testData.ts | 1 + superset/charts/schemas.py | 20 +- superset/common/query_context_processor.py | 3 + superset/common/query_object.py | 4 + superset/connectors/sqla/models.py | 126 ++++++- superset/constants.py | 1 + tests/unit_tests/connectors/__init__.py | 16 + tests/unit_tests/connectors/test_models.py | 383 +++++++++++++++++++++ tests/unit_tests/queries/query_object_test.py | 1 + 19 files changed, 942 insertions(+), 15 deletions(-) diff --git a/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts b/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts index 718f10514c..db3a090dd6 100644 --- a/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts +++ 
b/superset-frontend/packages/superset-ui-core/src/query/types/Query.ts @@ -77,6 +77,13 @@ export type ResidualQueryObjectData = { [key: string]: unknown; }; +export type QueryObjectInstantTimeComparisonInfo = { + /** The range to use as comparison range */ + range: string; + /** The custom filter value to use if range is Custom */ + filter?: QueryObjectFilterClause; +}; + /** * Query object directly compatible with the new chart data API. * A stricter version of query form data. @@ -149,6 +156,9 @@ export interface QueryObject series_columns?: QueryFormColumn[]; series_limit?: number; series_limit_metric?: Maybe<QueryFormMetric>; + + /** Instant Time Comparison */ + instant_time_comparison_info?: QueryObjectInstantTimeComparisonInfo; } export interface QueryContext { diff --git a/superset-frontend/packages/superset-ui-core/src/query/types/QueryResponse.ts b/superset-frontend/packages/superset-ui-core/src/query/types/QueryResponse.ts index 1705814df1..d910e9a778 100644 --- a/superset-frontend/packages/superset-ui-core/src/query/types/QueryResponse.ts +++ b/superset-frontend/packages/superset-ui-core/src/query/types/QueryResponse.ts @@ -78,6 +78,7 @@ export interface ChartDataResponseResult { | 'timed_out'; from_dttm: number | null; to_dttm: number | null; + instant_time_comparison_range: string | null; } export interface TimeseriesChartDataResponseResult diff --git a/superset-frontend/plugins/plugin-chart-table/src/DataTable/DataTable.tsx b/superset-frontend/plugins/plugin-chart-table/src/DataTable/DataTable.tsx index 6c5123806f..a0af54eb6a 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/DataTable/DataTable.tsx +++ b/superset-frontend/plugins/plugin-chart-table/src/DataTable/DataTable.tsx @@ -67,6 +67,7 @@ export interface DataTableProps<D extends object> extends TableOptions<D> { rowCount: number; wrapperRef?: MutableRefObject<HTMLDivElement>; onColumnOrderChange: () => void; + groupHeaderColumns?: Record<string, number[]>; } export interface 
RenderHTMLCellProps extends HTMLProps<HTMLTableCellElement> { @@ -99,6 +100,7 @@ export default typedMemo(function DataTable<D extends object>({ serverPagination, wrapperRef: userWrapperRef, onColumnOrderChange, + groupHeaderColumns, ...moreUseTableOptions }: DataTableProps<D>): JSX.Element { const tableHooks: PluginHook<D>[] = [ @@ -248,14 +250,55 @@ export default typedMemo(function DataTable<D extends object>({ e.preventDefault(); }; + const renderDynamicHeaders = () => { + // TODO: Make use of ColumnGroup to render the additional headers + const headers: any = []; + let currentColumnIndex = 0; + + Object.entries(groupHeaderColumns || {}).forEach(([key, value], index) => { + // Calculate the number of placeholder columns needed before the current header + const startPosition = value[0]; + const colSpan = value.length; + + // Add placeholder <th> for columns before this header + for (let i = currentColumnIndex; i < startPosition; i += 1) { + headers.push( + <th + key={`placeholder-${i}`} + style={{ borderBottom: 0 }} + aria-label={`Header-${i}`} + />, + ); + } + + // Add the current header <th> + headers.push( + <th key={`header-${key}`} colSpan={colSpan} style={{ borderBottom: 0 }}> + {key} + </th>, + ); + + // Update the current column index + currentColumnIndex = startPosition + colSpan; + }); + + return headers; + }; + const renderTable = () => ( <table {...getTableProps({ className: tableClassName })}> <thead> + {/* Render dynamic headers based on resultMap */} + {groupHeaderColumns ?
<tr>{renderDynamicHeaders()}</tr> : null} {headerGroups.map(headerGroup => { const { key: headerGroupKey, ...headerGroupProps } = headerGroup.getHeaderGroupProps(); return ( - <tr key={headerGroupKey || headerGroup.id} {...headerGroupProps}> + <tr + key={headerGroupKey || headerGroup.id} + {...headerGroupProps} + style={{ borderTop: 0 }} + > {headerGroup.headers.map(column => column.render('Header', { key: column.id, diff --git a/superset-frontend/plugins/plugin-chart-table/src/DataTable/hooks/useSticky.tsx b/superset-frontend/plugins/plugin-chart-table/src/DataTable/hooks/useSticky.tsx index ba3466bb40..1e56987486 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/DataTable/hooks/useSticky.tsx +++ b/superset-frontend/plugins/plugin-chart-table/src/DataTable/hooks/useSticky.tsx @@ -181,7 +181,9 @@ function StickyWrap({ } const fullTableHeight = (bodyThead.parentNode as HTMLTableElement) .clientHeight; - const ths = bodyThead.childNodes[0] + // instead of always using the first tr, we use the last one to support + // multi-level headers assuming the last one is the more detailed one + const ths = bodyThead.childNodes?.[bodyThead.childNodes?.length - 1 || 0] .childNodes as NodeListOf<HTMLTableHeaderCellElement>; const widths = Array.from(ths).map( th => th.getBoundingClientRect()?.width || th.clientWidth, diff --git a/superset-frontend/plugins/plugin-chart-table/src/TableChart.tsx b/superset-frontend/plugins/plugin-chart-table/src/TableChart.tsx index 840020cad8..d4d5de970a 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/TableChart.tsx +++ b/superset-frontend/plugins/plugin-chart-table/src/TableChart.tsx @@ -50,6 +50,7 @@ import { tn, } from '@superset-ui/core'; +import { isEmpty } from 'lodash'; import { DataColumnMeta, TableChartTransformedProps } from './types'; import DataTable, { DataTableProps, @@ -238,6 +239,7 @@ export default function TableChart<D extends DataRecord = DataRecord>( allowRearrangeColumns = false, onContextMenu, 
emitCrossFilters, + enableTimeComparison, } = props; const timestampFormatter = useCallback( value => getTimeFormatterForGranularity(timeGrain)(value), @@ -413,6 +415,37 @@ export default function TableChart<D extends DataRecord = DataRecord>( } : undefined; + const comparisonLabels = [t('Main'), '#', '△', '%']; + + const getHeaderColumns = ( + columnsMeta: DataColumnMeta[], + enableTimeComparison?: boolean, + ) => { + const resultMap: Record<string, number[]> = {}; + + if (!enableTimeComparison) { + return resultMap; + } + + columnsMeta.forEach((element, index) => { + // Check if element's label is one of the comparison labels + if (comparisonLabels.includes(element.label)) { + // Extract the key portion after the space, assuming the format is always "label key" + const keyPortion = element.key.split(' ')[1]; + + // If the key portion is not in the map, initialize it with the current index + if (!resultMap[keyPortion]) { + resultMap[keyPortion] = [index]; + } else { + // Add the index to the existing array + resultMap[keyPortion].push(index); + } + } + }); + + return resultMap; + }; + const getColumnConfigs = useCallback( (column: DataColumnMeta, i: number): ColumnWithLooseAccessor<D> => { const { @@ -596,6 +629,7 @@ export default function TableChart<D extends DataRecord = DataRecord>( style={{ ...sharedStyle, ...style, + borderTop: 0, }} tabIndex={0} onKeyDown={(e: React.KeyboardEvent<HTMLElement>) => { @@ -670,6 +704,11 @@ export default function TableChart<D extends DataRecord = DataRecord>( [columnsMeta, getColumnConfigs], ); + const groupHeaderColumns = useMemo( + () => getHeaderColumns(columnsMeta, enableTimeComparison), + [columnsMeta, enableTimeComparison], + ); + const handleServerPaginationChange = useCallback( (pageNumber: number, pageSize: number) => { updateExternalFormData(setDataMask, pageNumber, pageSize); @@ -734,6 +773,9 @@ export default function TableChart<D extends DataRecord = DataRecord>( selectPageSize={pageSize !== null && SelectPageSize} 
// not in use in Superset, but needed for unit tests sticky={sticky} + groupHeaderColumns={ + !isEmpty(groupHeaderColumns) ? groupHeaderColumns : undefined + } /> </Styles> ); diff --git a/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts b/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts index 69631a5f35..9e93c268a4 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts +++ b/superset-frontend/plugins/plugin-chart-table/src/buildQuery.ts @@ -19,11 +19,17 @@ import { AdhocColumn, buildQueryContext, + buildQueryObject, + ComparisonTimeRangeType, ensureIsArray, + FeatureFlag, + getComparisonInfo, getMetricLabel, + isFeatureEnabled, isPhysicalColumn, QueryMode, QueryObject, + QueryObjectFilterClause, removeDuplicates, } from '@superset-ui/core'; import { PostProcessingRule } from '@superset-ui/core/src/query/types/PostProcessing'; @@ -55,7 +61,12 @@ const buildQuery: BuildQuery<TableChartFormData> = ( percent_metrics: percentMetrics, order_desc: orderDesc = false, extra_form_data, + time_comparison: timeComparison, + enable_time_comparison, } = formData; + const canUseTimeComparison = + enable_time_comparison && + isFeatureEnabled(FeatureFlag.ChartPluginsExperimental); const queryMode = getQueryMode(formData); const sortByMetric = ensureIsArray(formData.timeseries_limit_metric)[0]; const time_grain_sqla = @@ -69,6 +80,34 @@ const buildQuery: BuildQuery<TableChartFormData> = ( }; } + const addComparisonPercentMetrics = (metrics: string[]) => + metrics.reduce((acc, metric) => { + const prevMetric = `prev_${metric}`; + return acc.concat([metric, prevMetric]); + }, [] as string[]); + + const comparisonFormData = getComparisonInfo( + formDataCopy, + timeComparison, + extra_form_data, + ); + + const getFirstTemporalFilter = ( + queryObject?: QueryObject, + ): QueryObjectFilterClause | undefined => { + const { filters = [] } = queryObject || {}; + const timeFilterIndex: number = + filters?.findIndex( + filter => 'op' in filter 
&& filter.op === 'TEMPORAL_RANGE', + ) ?? -1; + + const timeFilter: QueryObjectFilterClause | undefined = + timeFilterIndex !== -1 && filters ? filters[timeFilterIndex] : undefined; + return timeFilter; + }; + const comparisonQueryObject = buildQueryObject(comparisonFormData); + const firstTemporalFilter = getFirstTemporalFilter(comparisonQueryObject); + return buildQueryContext(formDataCopy, baseQueryObject => { let { metrics, orderby = [], columns = [] } = baseQueryObject; let postProcessing: PostProcessingRule[] = []; @@ -85,8 +124,11 @@ const buildQuery: BuildQuery<TableChartFormData> = ( } // add postprocessing for percent metrics only when in aggregation mode if (percentMetrics && percentMetrics.length > 0) { + const percentMetricsLabelsWithTimeComparison = canUseTimeComparison + ? addComparisonPercentMetrics(percentMetrics.map(getMetricLabel)) + : percentMetrics.map(getMetricLabel); const percentMetricLabels = removeDuplicates( - percentMetrics.map(getMetricLabel), + percentMetricsLabelsWithTimeComparison, ); metrics = removeDuplicates( metrics.concat(percentMetrics), @@ -139,6 +181,20 @@ const buildQuery: BuildQuery<TableChartFormData> = ( ...moreProps, }; + // Customize the query for time comparison + if (canUseTimeComparison) { + queryObject = { + ...queryObject, + instant_time_comparison_info: { + range: timeComparison, + filter: + timeComparison === ComparisonTimeRangeType.Custom + ? 
firstTemporalFilter + : undefined, + }, + }; + } + if ( formData.server_pagination && options?.extras?.cachedChanges?.[formData.slice_id] && diff --git a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx index ad39b504cb..c7710c81df 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx +++ b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx @@ -20,14 +20,18 @@ import React from 'react'; import { ChartDataResponseResult, + ComparisonTimeRangeType, ensureIsArray, + FeatureFlag, GenericDataType, isAdhocColumn, + isFeatureEnabled, isPhysicalColumn, QueryFormColumn, QueryMode, smartDateFormatter, t, + validateTimeComparisonRangeValues, } from '@superset-ui/core'; import { ColumnOption, @@ -257,6 +261,74 @@ const config: ControlPanelConfig = { }, ], ['adhoc_filters'], + [ + { + name: 'enable_time_comparison', + config: { + type: 'CheckboxControl', + label: t('Enable Time Comparison'), + description: t('Enable time comparison (experimental feature)'), + default: false, + visibility: () => + isFeatureEnabled(FeatureFlag.ChartPluginsExperimental), + }, + }, + ], + [ + { + name: 'time_comparison', + config: { + type: 'SelectControl', + label: t('Range for Comparison'), + default: 'r', + choices: [ + ['r', 'Inherit range from time filters'], + ['y', 'Year'], + ['m', 'Month'], + ['w', 'Week'], + ['c', 'Custom'], + ], + rerender: ['adhoc_custom'], + description: t( + 'Set the time range that will be used for the comparison metrics. ' + + 'For example, "Year" will compare to the same dates one year earlier. 
' + + 'Use "Inherit range from time filters" to shift the comparison time range' + + 'by the same length as your time range and use "Custom" to set a custom comparison range.', + ), + visibility: ({ controls }) => + Boolean(controls?.enable_time_comparison?.value) && + isFeatureEnabled(FeatureFlag.ChartPluginsExperimental), + }, + }, + ], + [ + { + name: `adhoc_custom`, + config: { + ...sharedControls.adhoc_filters, + label: t('Filters for Comparison'), + description: + 'This only applies when selecting the Range for Comparison Type: Custom', + visibility: ({ controls }) => + Boolean(controls?.enable_time_comparison?.value) && + controls?.time_comparison?.value === + ComparisonTimeRangeType.Custom, + mapStateToProps: ( + state: ControlPanelState, + controlState: ControlState, + ) => ({ + ...(sharedControls.adhoc_filters.mapStateToProps?.( + state, + controlState, + ) || {}), + externalValidationErrors: validateTimeComparisonRangeValues( + state.controls?.time_comparison?.value, + controlState.value, + ), + }), + }, + }, + ], [ { name: 'timeseries_limit_metric', diff --git a/superset-frontend/plugins/plugin-chart-table/src/transformProps.ts b/superset-frontend/plugins/plugin-chart-table/src/transformProps.ts index 0a2a3449c6..e36684baff 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/transformProps.ts +++ b/superset-frontend/plugins/plugin-chart-table/src/transformProps.ts @@ -21,14 +21,17 @@ import { CurrencyFormatter, DataRecord, extractTimegrain, + FeatureFlag, GenericDataType, getMetricLabel, getNumberFormatter, getTimeFormatter, getTimeFormatterForGranularity, + isFeatureEnabled, NumberFormats, QueryMode, smartDateFormatter, + t, TimeFormats, TimeFormatter, } from '@superset-ui/core'; @@ -48,6 +51,8 @@ import { const { PERCENT_3_POINT } = NumberFormats; const { DATABASE_DATETIME } = TimeFormats; +const COMPARISON_PREFIX = 'prev_'; + function isNumeric(key: string, data: DataRecord[] = []) { return data.every( x => x[key] === null || x[key] === 
undefined || typeof x[key] === 'number', @@ -81,6 +86,88 @@ const processDataRecords = memoizeOne(function processDataRecords( return data; }); +const calculateDifferences = ( + originalValue: number, + comparisonValue: number, +) => { + const valueDifference = originalValue - comparisonValue; + let percentDifferenceNum; + if (!originalValue && !comparisonValue) { + percentDifferenceNum = 0; + } else if (!originalValue || !comparisonValue) { + percentDifferenceNum = originalValue ? 1 : -1; + } else { + percentDifferenceNum = + (originalValue - comparisonValue) / Math.abs(comparisonValue); + } + return { valueDifference, percentDifferenceNum }; +}; + +const processComparisonTotals = (totals: DataRecord | undefined) => { + if (!totals) { + return totals; + } + const transformedTotals: DataRecord = {}; + Object.keys(totals).forEach(key => { + if (totals[key] !== undefined && !key.includes(COMPARISON_PREFIX)) { + transformedTotals[`Main ${key}`] = totals[key]; + transformedTotals[`# ${key}`] = totals[`${COMPARISON_PREFIX}${key}`]; + const { valueDifference, percentDifferenceNum } = calculateDifferences( + totals[key] as number, + totals[`${COMPARISON_PREFIX}${key}`] as number, + ); + transformedTotals[`△ ${key}`] = valueDifference; + transformedTotals[`% ${key}`] = percentDifferenceNum; + } + }); + return transformedTotals; +}; + +const processComparisonDataRecords = memoizeOne( + function processComparisonDataRecords( + originalData: DataRecord[] | undefined, + originalColumns: DataColumnMeta[], + ) { + // Transform data + return originalData?.map(originalItem => { + const transformedItem: DataRecord = {}; + originalColumns.forEach(origCol => { + if ( + (origCol.isMetric || origCol.isPercentMetric) && + !origCol.key.includes(COMPARISON_PREFIX) && + origCol.isNumeric + ) { + const originalValue = originalItem[origCol.key] || 0; + const comparisonValue = origCol.isMetric + ? 
originalItem?.[`${COMPARISON_PREFIX}${origCol.key}`] || 0 + : originalItem[`%${COMPARISON_PREFIX}${origCol.key.slice(1)}`] || 0; + const { valueDifference, percentDifferenceNum } = + calculateDifferences( + originalValue as number, + comparisonValue as number, + ); + + transformedItem[`Main ${origCol.key}`] = originalValue; + transformedItem[`# ${origCol.key}`] = comparisonValue; + transformedItem[`△ ${origCol.key}`] = valueDifference; + transformedItem[`% ${origCol.key}`] = percentDifferenceNum; + } + }); + + Object.keys(originalItem).forEach(key => { + const isMetricOrPercentMetric = originalColumns.some( + col => col.key === key && (col.isMetric || col.isPercentMetric), + ); + if (!isMetricOrPercentMetric) { + transformedItem[key] = originalItem[key]; + } + }); + + return transformedItem; + }); + }, +); + const processColumns = memoizeOne(function processColumns( props: TableChartProps, ) { @@ -186,6 +273,55 @@ const processColumns = memoizeOne(function processColumns( ]; }, isEqualColumns); +const processComparisonColumns = ( + columns: DataColumnMeta[], + props: TableChartProps, +) => + columns + .map(col => { + const { + datasource: { columnFormats }, + rawFormData: { column_config: columnConfig = {} }, + } = props; + const config = columnConfig[col.key] || {}; + const savedFormat = columnFormats?.[col.key]; + const numberFormat = config.d3NumberFormat || savedFormat; + if (col.isNumeric && !col.key.includes(COMPARISON_PREFIX)) { + return [ + { + ...col, + label: t('Main'), + key: `${t('Main')} ${col.key}`, + }, + { + ...col, + label: `#`, + key: `# ${col.key}`, + }, + { + ...col, + label: `△`, + key: `△ ${col.key}`, + }, + { + ...col, + formatter: getNumberFormatter(numberFormat || PERCENT_3_POINT), + label: `%`, + key: `% ${col.key}`, + }, + ]; + } + if ( + !col.isMetric && + !col.isPercentMetric && + !col.key.includes(COMPARISON_PREFIX) + ) { + return [col]; + } + return []; + }) + .flat(); + /** * Automatically set page size based on number of cells. 
*/ @@ -238,23 +374,35 @@ const transformProps = ( show_totals: showTotals, conditional_formatting: conditionalFormatting, allow_rearrange_columns: allowRearrangeColumns, + enable_time_comparison: enableTimeComparison = false, } = formData; + const canUseTimeComparison = + enableTimeComparison && + isFeatureEnabled(FeatureFlag.ChartPluginsExperimental); const timeGrain = extractTimegrain(formData); const [metrics, percentMetrics, columns] = processColumns(chartProps); + let comparisonColumns: DataColumnMeta[] = []; + if (canUseTimeComparison) { + comparisonColumns = processComparisonColumns(columns, chartProps); + } let baseQuery; let countQuery; let totalQuery; let rowCount; + const queriesDataWithoutComparisonQueries = queriesData.filter( + ({ instant_time_comparison_range }) => !instant_time_comparison_range, + ); if (serverPagination) { - [baseQuery, countQuery, totalQuery] = queriesData; + [baseQuery, countQuery, totalQuery] = queriesDataWithoutComparisonQueries; rowCount = (countQuery?.data?.[0]?.rowcount as number) ?? 0; } else { - [baseQuery, totalQuery] = queriesData; + [baseQuery, totalQuery] = queriesDataWithoutComparisonQueries; rowCount = baseQuery?.rowcount ?? 0; } const data = processDataRecords(baseQuery?.data, columns); + const comparisonData = processComparisonDataRecords(baseQuery?.data, columns); const totals = showTotals && queryMode === QueryMode.Aggregate ? totalQuery?.data[0] @@ -262,13 +410,19 @@ const transformProps = ( const columnColorFormatters = getColorFormatters(conditionalFormatting, data) ?? defaultColorFormatters; + const comparisonTotals = processComparisonTotals(totals); + + const passedData = canUseTimeComparison ? comparisonData || [] : data; + const passedTotals = canUseTimeComparison ? comparisonTotals : totals; + const passedColumns = canUseTimeComparison ? 
comparisonColumns : columns; + return { height, width, isRawRecords: queryMode === QueryMode.Raw, - data, - totals, - columns, + data: passedData, + totals: passedTotals, + columns: passedColumns, serverPagination, metrics, percentMetrics, @@ -292,6 +446,7 @@ const transformProps = ( timeGrain, allowRearrangeColumns, onContextMenu, + enableTimeComparison: canUseTimeComparison, }; }; diff --git a/superset-frontend/plugins/plugin-chart-table/src/types.ts b/superset-frontend/plugins/plugin-chart-table/src/types.ts index 02bae809fe..1806eddb1a 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/types.ts +++ b/superset-frontend/plugins/plugin-chart-table/src/types.ts @@ -91,6 +91,7 @@ export type TableChartFormData = QueryFormData & { time_grain_sqla?: TimeGranularity; column_config?: Record<string, TableColumnConfig>; allow_rearrange_columns?: boolean; + enable_time_comparison?: boolean; }; export interface TableChartProps extends ChartProps { @@ -135,6 +136,7 @@ export interface TableChartTransformedProps<D extends DataRecord = DataRecord> { clientY: number, filters?: ContextMenuFilters, ) => void; + enableTimeComparison?: boolean; } export default {}; diff --git a/superset-frontend/plugins/plugin-chart-table/src/utils/isEqualColumns.ts b/superset-frontend/plugins/plugin-chart-table/src/utils/isEqualColumns.ts index 28731c73c2..8153ea856a 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/utils/isEqualColumns.ts +++ b/superset-frontend/plugins/plugin-chart-table/src/utils/isEqualColumns.ts @@ -41,6 +41,7 @@ export default function isEqualColumns( JSON.stringify(a.formData.extraFormData || null) === JSON.stringify(b.formData.extraFormData || null) && JSON.stringify(a.rawFormData.column_config || null) === - JSON.stringify(b.rawFormData.column_config || null) + JSON.stringify(b.rawFormData.column_config || null) && + a.formData.enableTimeComparison === b.formData.enableTimeComparison ); } diff --git 
a/superset-frontend/plugins/plugin-chart-table/test/testData.ts b/superset-frontend/plugins/plugin-chart-table/test/testData.ts index 24abc3381e..af2fbe5a65 100644 --- a/superset-frontend/plugins/plugin-chart-table/test/testData.ts +++ b/superset-frontend/plugins/plugin-chart-table/test/testData.ts @@ -84,6 +84,7 @@ const basicQueryResult: ChartDataResponseResult = { status: 'success', from_dttm: null, to_dttm: null, + instant_time_comparison_range: null, }; /** diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 611f7af597..34731af571 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -26,6 +26,7 @@ from marshmallow.validate import Length, Range from superset import app from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType +from superset.constants import InstantTimeComparison from superset.db_engine_specs.base import builtin_time_grains from superset.tags.models import TagType from superset.utils import pandas_postprocessing, schema as utils @@ -948,6 +949,14 @@ class ChartDataFilterSchema(Schema): ) +class InstantTimeComparisonInfoSchema(Schema): + range = fields.String( + metadata={"description": "Type of time comparison to be used"}, + validate=validate.OneOf(choices=[ran.value for ran in InstantTimeComparison]), + ) + filter = fields.Nested(ChartDataFilterSchema, allow_none=True) + + class ChartDataExtrasSchema(Schema): relative_start = fields.String( metadata={ @@ -994,7 +1003,8 @@ class ChartDataExtrasSchema(Schema): metadata={ "description": "This is only set using the new time comparison controls " "that is made available in some plugins behind the experimental " - "feature flag." + "feature flag. If passed as extra, the time range will be changed inside this" + " query object." 
}, allow_none=True, ) @@ -1350,6 +1360,14 @@ class ChartDataQueryObjectSchema(Schema): fields.String(), allow_none=True, ) + instant_time_comparison_info = fields.Nested( + InstantTimeComparisonInfoSchema, + metadata={ + "description": "Extra parameters to use instant time comparison" + " with JOINs using a single query" + }, + allow_none=True, + ) class ChartDataQueryContextSchema(Schema): diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index d8b5bea4bb..77f84989b1 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -197,6 +197,9 @@ class QueryContextProcessor: "from_dttm": query_obj.from_dttm, "to_dttm": query_obj.to_dttm, "label_map": label_map, + "instant_time_comparison_range": query_obj.extras.get( + "instant_time_comparison_range" + ), } def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None: diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 5109c465e0..77f3a08ce8 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -107,6 +107,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes time_shift: str | None time_range: str | None to_dttm: datetime | None + instant_time_comparison_info: dict[str, Any] | None def __init__( # pylint: disable=too-many-locals self, @@ -132,6 +133,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes series_limit_metric: Metric | None = None, time_range: str | None = None, time_shift: str | None = None, + instant_time_comparison_info: dict[str, Any] | None = None, **kwargs: Any, ): self._set_annotation_layers(annotation_layers) @@ -161,6 +163,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes self.time_offsets = kwargs.get("time_offsets", []) self.inner_from_dttm = kwargs.get("inner_from_dttm") self.inner_to_dttm = kwargs.get("inner_to_dttm") + self.instant_time_comparison_info = 
instant_time_comparison_info self._rename_deprecated_fields(kwargs) self._move_deprecated_extra_fields(kwargs) @@ -335,6 +338,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes "series_limit_metric": self.series_limit_metric, "to_dttm": self.to_dttm, "time_shift": self.time_shift, + "instant_time_comparison_info": self.instant_time_comparison_info, } return query_object_dict diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 089b9c2f28..b6e14ce62e 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -18,6 +18,7 @@ from __future__ import annotations import builtins +import copy import dataclasses import json import logging @@ -81,7 +82,7 @@ from superset.connectors.sqla.utils import ( get_physical_table_metadata, get_virtual_table_metadata, ) -from superset.constants import EMPTY_STRING, NULL_STRING +from superset.constants import EMPTY_STRING, InstantTimeComparison, NULL_STRING from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression from superset.exceptions import ( ColumnNotFoundException, @@ -105,6 +106,7 @@ from superset.models.helpers import ( ImportExportMixin, QueryResult, QueryStringExtended, + SqlaQuery, validate_adhoc_subquery, ) from superset.models.slice import Slice @@ -120,7 +122,7 @@ from superset.superset_typing import ( ) from superset.utils import core as utils from superset.utils.backports import StrEnum -from superset.utils.core import GenericDataType, MediumText +from superset.utils.core import FilterOperator, GenericDataType, MediumText config = app.config metadata = Model.metadata # pylint: disable=no-member @@ -1413,24 +1415,138 @@ class SqlaTable( def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: return get_template_processor(table=self, database=self.database, **kwargs) + def extract_column_names(self, final_selected_columns: Any) -> list[str]: + column_names = [] + for selected_col in 
final_selected_columns: + # The key attribute usually holds the name or alias of the column + column_name = selected_col.key if hasattr(selected_col, "key") else None + # If the column has a name attribute, use it as a fallback + if not column_name and hasattr(selected_col, "name"): + column_name = selected_col.name + # For labeled elements, the name is stored in the 'name' attribute + if hasattr(selected_col, "name"): + column_name = selected_col.name + # Append the extracted name to the list + if column_name: + column_names.append(column_name) + return column_names + + def process_time_compare_join( # pylint: disable=too-many-locals + self, + query_obj: QueryObjectDict, + sqlaq: SqlaQuery, + mutate: bool, + instant_time_comparison_info: dict[str, Any], + ) -> tuple[str, list[str]]: + query_obj_clone = copy.copy(query_obj) + final_query_sql = "" + query_obj_clone["row_limit"] = None + query_obj_clone["row_offset"] = None + instant_time_comparison_range = instant_time_comparison_info.get("range") + if instant_time_comparison_range == InstantTimeComparison.CUSTOM: + custom_filter = instant_time_comparison_info.get("filter", {}) + temporal_filters = [ + filter["col"] + for filter in query_obj_clone.get("filter", {}) + if filter.get("op", None) == FilterOperator.TEMPORAL_RANGE + ] + non_temporal_filters = [ + filter["col"] + for filter in query_obj_clone.get("filter", {}) + if filter.get("op", None) != FilterOperator.TEMPORAL_RANGE + ] + if len(temporal_filters) > 0: + # Edit the first temporal filter to include the custom filter + temporal_filters[0] = custom_filter + + new_filters = temporal_filters + non_temporal_filters + query_obj_clone["filter"] = new_filters + if instant_time_comparison_range != InstantTimeComparison.CUSTOM: + query_obj_clone["extras"] = { + **query_obj_clone.get("extras", {}), + "instant_time_comparison_range": instant_time_comparison_range, + } + sqlaq_2 = self.get_sqla_query(**query_obj_clone) + join_columns = query_obj_clone.get("columns")
or [] + sqla_query_a = sqlaq.sqla_query + sqla_query_b = sqlaq_2.sqla_query + sqla_query_b_subquery = sqla_query_b.subquery() + query_a_cte = sqla_query_a.cte("query_a_results") + column_names_a = [column.key for column in sqla_query_a.c] + exclude_columns_b = set(query_obj_clone.get("columns") or []) + selected_columns_a = [query_a_cte.c[col].label(col) for col in column_names_a] + # Renamed columns from Query B (with "prev_" prefix) + selected_columns_b = [ + sqla_query_b_subquery.c[col].label(f"prev_{col}") + for col in sqla_query_b_subquery.c.keys() + if col not in exclude_columns_b + ] + # Combine selected columns from both queries + final_selected_columns = selected_columns_a + selected_columns_b + if join_columns and not query_obj_clone.get("is_rowcount"): + # Proceed with JOIN operation as before since join_columns is not empty + join_conditions = [ + sqla_query_b_subquery.c[col] == query_a_cte.c[col] + for col in join_columns + if col in sqla_query_b_subquery.c and col in query_a_cte.c + ] + final_query = sa.select(*final_selected_columns).select_from( + sqla_query_b_subquery.join(query_a_cte, sa.and_(*join_conditions)) + ) + else: + final_query = sa.select(*final_selected_columns).select_from( + sqla_query_b_subquery.join( + query_a_cte, sa.literal(True) == sa.literal(True) + ) + ) + final_query_sql = self.database.compile_sqla_query(final_query) + final_query_sql = self._apply_cte(final_query_sql, sqlaq.cte) + final_query_sql = sqlparse.format(final_query_sql, reindent=True) + if mutate: + final_query_sql = self.mutate_query_from_config(final_query_sql) + + labels_expected = self.extract_column_names(final_selected_columns) + + return final_query_sql, labels_expected + def get_query_str_extended( self, query_obj: QueryObjectDict, mutate: bool = True, ) -> QueryStringExtended: - sqlaq = self.get_sqla_query(**query_obj) + # So we don't mutate the original query_obj + query_obj_clone = copy.copy(query_obj) + instant_time_comparison_info = 
query_obj.get("instant_time_comparison_info") + query_obj_clone.pop("instant_time_comparison_info", None) + sqlaq = self.get_sqla_query(**query_obj_clone) sql = self.database.compile_sqla_query(sqlaq.sqla_query) sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) + if mutate: sql = self.mutate_query_from_config(sql) + + if ( + is_feature_enabled("CHART_PLUGINS_EXPERIMENTAL") + and instant_time_comparison_info + ): + ( + final_query_sql, + labels_expected, + ) = self.process_time_compare_join( + query_obj_clone, sqlaq, mutate, instant_time_comparison_info + ) + else: + final_query_sql = sql + labels_expected = sqlaq.labels_expected + return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, applied_filter_columns=sqlaq.applied_filter_columns, rejected_filter_columns=sqlaq.rejected_filter_columns, - labels_expected=sqlaq.labels_expected, + labels_expected=labels_expected, prequeries=sqlaq.prequeries, - sql=sql, + sql=final_query_sql if final_query_sql else sql, ) def get_query_str(self, query_obj: QueryObjectDict) -> str: diff --git a/superset/constants.py b/superset/constants.py index bf4e7717d5..9af8870e2d 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -48,6 +48,7 @@ class InstantTimeComparison(StrEnum): YEAR = "y" MONTH = "m" WEEK = "w" + CUSTOM = "c" class RouteMethod: # pylint: disable=too-few-public-methods diff --git a/tests/unit_tests/connectors/__init__.py b/tests/unit_tests/connectors/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/unit_tests/connectors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/connectors/test_models.py b/tests/unit_tests/connectors/test_models.py new file mode 100644 index 0000000000..cf179c9dfa --- /dev/null +++ b/tests/unit_tests/connectors/test_models.py @@ -0,0 +1,383 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import datetime + +from sqlalchemy.orm.session import Session + +from superset import db +from tests.unit_tests.conftest import with_feature_flags + + +class TestInstantTimeComparisonQueryGeneration: + @staticmethod + def base_setup(session: Session): + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.models.core import Database + + engine = db.session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + table = SqlaTable( + table_name="my_table", + schema="my_schema", + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + ) + + # Common columns + columns = [ + {"column_name": "ds", "type": "DATETIME"}, + {"column_name": "gender", "type": "VARCHAR(255)"}, + {"column_name": "name", "type": "VARCHAR(255)"}, + {"column_name": "state", "type": "VARCHAR(255)"}, + ] + + # Add columns to the table + for col in columns: + TableColumn(column_name=col["column_name"], type=col["type"], table=table) + + # Common metrics + metrics = [ + {"metric_name": "count", "expression": "count(*)"}, + {"metric_name": "sum_sum", "expression": "SUM"}, + ] + + # Add metrics to the table + for metric in metrics: + SqlMetric( + metric_name=metric["metric_name"], + expression=metric["expression"], + table=table, + ) + + db.session.add(table) + db.session.flush() + + return table + + @staticmethod + def generate_base_query_obj(): + return { + "apply_fetch_values_predicate": False, + "columns": ["name"], + "extras": {"having": "", "where": ""}, + "filter": [ + {"op": "TEMPORAL_RANGE", "val": "1984-01-01 : 2024-02-14", "col": "ds"} + ], + "from_dttm": datetime.datetime(1984, 1, 1, 0, 0), + "granularity": None, + "inner_from_dttm": None, + "inner_to_dttm": None, + "is_rowcount": False, + "is_timeseries": False, + "order_desc": True, + "orderby": [("SUM(num_boys)", False)], + "row_limit": 10, + "row_offset": 0, + "series_columns": [], + "series_limit": 0, + "series_limit_metric": None, + "to_dttm": 
datetime.datetime(2024, 2, 14, 0, 0), + "time_shift": None, + "metrics": [ + { + "aggregate": "SUM", + "column": { + "column_name": "num_boys", + "type": "BIGINT", + "filterable": True, + "groupby": True, + "id": 334, + "is_certified": False, + "is_dttm": False, + "type_generic": 0, + }, + "datasourceWarning": False, + "expressionType": "SIMPLE", + "hasCustomLabel": False, + "label": "SUM(num_boys)", + "optionName": "metric_gzp6eq9g1lc_d8o0mj0mhq4", + "sqlExpression": None, + }, + { + "aggregate": "SUM", + "column": { + "column_name": "num_girls", + "type": "BIGINT", + "filterable": True, + "groupby": True, # Note: This will need adjustment in some cases + "id": 335, + "is_certified": False, + "is_dttm": False, + "type_generic": 0, + }, + "datasourceWarning": False, + "expressionType": "SIMPLE", + "hasCustomLabel": False, + "label": "SUM(num_girls)", + "optionName": "metric_5gyhtmyfw1t_d42py86jpco", + "sqlExpression": None, + }, + ], + "instant_time_comparison_info": { + "range": "y", + }, + } + + @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True) + def test_creates_time_comparison_query(session: Session): + table = TestInstantTimeComparisonQueryGeneration.base_setup(session) + query_obj = TestInstantTimeComparisonQueryGeneration.generate_base_query_obj() + str = table.get_query_str_extended(query_obj) + expected_str = """ + WITH query_a_results AS + (SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1984-01-01 00:00:00' + AND ds < '2024-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC + LIMIT 10 + OFFSET 0) + SELECT query_a_results.name AS name, + query_a_results."SUM(num_boys)" AS "SUM(num_boys)", + query_a_results."SUM(num_girls)" AS "SUM(num_girls)", + anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)", + anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)" + FROM + (SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM 
my_schema.my_table + WHERE ds >= '1983-01-01 00:00:00' + AND ds < '2023-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC) AS anon_1 + JOIN query_a_results ON anon_1.name = query_a_results.name + """ + simplified_query1 = " ".join(str.sql.split()).lower() + simplified_query2 = " ".join(expected_str.split()).lower() + assert table.id == 1 + assert simplified_query1 == simplified_query2 + + @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True) + def test_creates_time_comparison_query_no_columns(session: Session): + table = TestInstantTimeComparisonQueryGeneration.base_setup(session) + query_obj = TestInstantTimeComparisonQueryGeneration.generate_base_query_obj() + query_obj["columns"] = [] + query_obj["metrics"][0]["column"]["groupby"] = False + query_obj["metrics"][1]["column"]["groupby"] = False + + str = table.get_query_str_extended(query_obj) + expected_str = """ + WITH query_a_results AS + (SELECT sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1984-01-01 00:00:00' + AND ds < '2024-02-14 00:00:00' + ORDER BY "SUM(num_boys)" DESC + LIMIT 10 + OFFSET 0) + SELECT query_a_results."SUM(num_boys)" AS "SUM(num_boys)", + query_a_results."SUM(num_girls)" AS "SUM(num_girls)", + anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)", + anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)" + FROM + (SELECT sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1983-01-01 00:00:00' + AND ds < '2023-02-14 00:00:00' + ORDER BY "SUM(num_boys)" DESC) AS anon_1 + JOIN query_a_results ON 1 = 1 + """ + simplified_query1 = " ".join(str.sql.split()).lower() + simplified_query2 = " ".join(expected_str.split()).lower() + assert table.id == 1 + assert simplified_query1 == simplified_query2 + + @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True) + def test_creates_time_comparison_rowcount_query(session: Session): + table = 
TestInstantTimeComparisonQueryGeneration.base_setup(session) + query_obj = TestInstantTimeComparisonQueryGeneration.generate_base_query_obj() + query_obj["is_rowcount"] = True + str = table.get_query_str_extended(query_obj) + expected_str = """ + WITH query_a_results AS + (SELECT COUNT(*) AS rowcount + FROM + (SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1984-01-01 00:00:00' + AND ds < '2024-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC + LIMIT 10 + OFFSET 0) AS rowcount_qry) + SELECT query_a_results.rowcount AS rowcount, + anon_1.rowcount AS prev_rowcount + FROM + (SELECT COUNT(*) AS rowcount + FROM + (SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1983-01-01 00:00:00' + AND ds < '2023-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC) AS rowcount_qry) AS anon_1 + JOIN query_a_results ON 1 = 1 + """ + simplified_query1 = " ".join(str.sql.split()).lower() + simplified_query2 = " ".join(expected_str.split()).lower() + assert table.id == 1 + assert simplified_query1 == simplified_query2 + + @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True) + def test_creates_query_without_time_comparison(session: Session): + table = TestInstantTimeComparisonQueryGeneration.base_setup(session) + query_obj = TestInstantTimeComparisonQueryGeneration.generate_base_query_obj() + query_obj["instant_time_comparison_info"] = None + str = table.get_query_str_extended(query_obj) + expected_str = """ + SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1984-01-01 00:00:00' + AND ds < '2024-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC + LIMIT 10 + OFFSET 0 + """ + simplified_query1 = " ".join(str.sql.split()).lower() + simplified_query2 = " ".join(expected_str.split()).lower() + assert 
table.id == 1 + assert simplified_query1 == simplified_query2 + + @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True) + def test_creates_time_comparison_query_custom_filters(session: Session): + table = TestInstantTimeComparisonQueryGeneration.base_setup(session) + query_obj = TestInstantTimeComparisonQueryGeneration.generate_base_query_obj() + query_obj["instant_time_comparison_info"] = { + "range": "c", + "filter": { + "op": "TEMPORAL_RANGE", + "val": "1900-01-01 : 1950-02-14", + "col": "ds", + }, + } + str = table.get_query_str_extended(query_obj) + expected_str = """ + WITH query_a_results AS + (SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1984-01-01 00:00:00' + AND ds < '2024-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC + LIMIT 10 + OFFSET 0) + SELECT query_a_results.name AS name, + query_a_results."SUM(num_boys)" AS "SUM(num_boys)", + query_a_results."SUM(num_girls)" AS "SUM(num_girls)", + anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)", + anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)" + FROM + (SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1900-01-01 00:00:00' + AND ds < '1950-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC) AS anon_1 + JOIN query_a_results ON anon_1.name = query_a_results.name + """ + simplified_query1 = " ".join(str.sql.split()).lower() + simplified_query2 = " ".join(expected_str.split()).lower() + assert table.id == 1 + assert simplified_query1 == simplified_query2 + + @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True) + def test_creates_time_comparison_query_paginated(session: Session): + table = TestInstantTimeComparisonQueryGeneration.base_setup(session) + query_obj = TestInstantTimeComparisonQueryGeneration.generate_base_query_obj() + query_obj["row_offset"] = 20 + str = table.get_query_str_extended(query_obj) + 
expected_str = """ + WITH query_a_results AS + (SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1984-01-01 00:00:00' + AND ds < '2024-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC + LIMIT 10 + OFFSET 20) + SELECT query_a_results.name AS name, + query_a_results."SUM(num_boys)" AS "SUM(num_boys)", + query_a_results."SUM(num_girls)" AS "SUM(num_girls)", + anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)", + anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)" + FROM + (SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1983-01-01 00:00:00' + AND ds < '2023-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC) AS anon_1 + JOIN query_a_results ON anon_1.name = query_a_results.name + """ + simplified_query1 = " ".join(str.sql.split()).lower() + simplified_query2 = " ".join(expected_str.split()).lower() + assert table.id == 1 + assert simplified_query1 == simplified_query2 + + @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=False) + def test_ignore_if_ff_off(session: Session): + table = TestInstantTimeComparisonQueryGeneration.base_setup(session) + query_obj = TestInstantTimeComparisonQueryGeneration.generate_base_query_obj() + str = table.get_query_str_extended(query_obj) + expected_str = """ + SELECT name AS name, + sum(num_boys) AS "SUM(num_boys)", + sum(num_girls) AS "SUM(num_girls)" + FROM my_schema.my_table + WHERE ds >= '1984-01-01 00:00:00' + AND ds < '2024-02-14 00:00:00' + GROUP BY name + ORDER BY "SUM(num_boys)" DESC + LIMIT 10 + OFFSET 0 + """ + simplified_query1 = " ".join(str.sql.split()).lower() + simplified_query2 = " ".join(expected_str.split()).lower() + assert table.id == 1 + assert simplified_query1 == simplified_query2 diff --git a/tests/unit_tests/queries/query_object_test.py b/tests/unit_tests/queries/query_object_test.py index 81a654653f..f90ab8255d 100644 --- 
a/tests/unit_tests/queries/query_object_test.py +++ b/tests/unit_tests/queries/query_object_test.py @@ -47,6 +47,7 @@ def test_default_query_object_to_dict(): "granularity": None, "inner_from_dttm": None, "inner_to_dttm": None, + "instant_time_comparison_info": None, "is_rowcount": False, "is_timeseries": False, "metrics": None,
