This is an automated email from the ASF dual-hosted git repository.
yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 375c03e feat(advanced analysis): support MultiIndex column in post
processing stage (#19116)
375c03e is described below
commit 375c03e08407570bcf417acf5f3d25b28843329c
Author: Yongjie Zhao <[email protected]>
AuthorDate: Wed Mar 23 13:46:28 2022 +0800
feat(advanced analysis): support MultiIndex column in post processing stage
(#19116)
---
.../src/operators/boxplotOperator.ts | 10 +-
.../src/operators/contributionOperator.ts | 25 ++-
...{contributionOperator.ts => flattenOperator.ts} | 20 +-
.../src/operators/index.ts | 1 +
.../src/operators/pivotOperator.ts | 15 +-
.../src/operators/prophetOperator.ts | 7 +-
.../src/operators/resampleOperator.ts | 23 +--
.../src/operators/rollingWindowOperator.ts | 28 +--
.../src/operators/sortOperator.ts | 7 +-
.../src/operators/timeCompareOperator.ts | 41 ++--
.../src/operators/timeComparePivotOperator.ts | 66 +++----
.../test/utils/operators/flattenOperator.test.ts | 59 ++++++
.../test/utils/operators/pivotOperator.test.ts | 44 +----
.../test/utils/operators/resampleOperator.test.ts | 81 +-------
.../utils/operators/rollingWindowOperator.test.ts | 45 +----
.../utils/operators/timeCompareOperator.test.ts | 129 +++----------
...or.test.ts => timeComparePivotOperator.test.ts} | 147 ++++++---------
.../src/query/types/PostProcessing.ts | 71 +++++--
.../BigNumber/BigNumberWithTrendline/buildQuery.ts | 24 +--
.../src/MixedTimeseries/buildQuery.ts | 4 +-
.../src/Timeseries/buildQuery.ts | 63 ++++---
superset/charts/schemas.py | 1 +
superset/common/query_context_processor.py | 12 +-
superset/common/query_object.py | 12 +-
superset/exceptions.py | 4 +
superset/utils/pandas_postprocessing/__init__.py | 2 +
superset/utils/pandas_postprocessing/aggregate.py | 2 +-
superset/utils/pandas_postprocessing/boxplot.py | 4 +-
superset/utils/pandas_postprocessing/compare.py | 31 +--
.../utils/pandas_postprocessing/contribution.py | 6 +-
superset/utils/pandas_postprocessing/cum.py | 31 +--
superset/utils/pandas_postprocessing/diff.py | 2 +-
superset/utils/pandas_postprocessing/flatten.py | 81 ++++++++
superset/utils/pandas_postprocessing/geography.py | 8 +-
superset/utils/pandas_postprocessing/pivot.py | 8 +-
superset/utils/pandas_postprocessing/prophet.py | 16 +-
superset/utils/pandas_postprocessing/resample.py | 41 ++--
superset/utils/pandas_postprocessing/rolling.py | 34 +---
superset/utils/pandas_postprocessing/select.py | 2 +-
superset/utils/pandas_postprocessing/sort.py | 2 +-
superset/utils/pandas_postprocessing/utils.py | 40 ++--
tests/common/query_context_generator.py | 9 +-
tests/integration_tests/query_context_tests.py | 12 +-
.../pandas_postprocessing/test_boxplot.py | 10 +-
.../pandas_postprocessing/test_compare.py | 209 +++++++++++++++++++--
.../pandas_postprocessing/test_contribution.py | 6 +-
tests/unit_tests/pandas_postprocessing/test_cum.py | 115 +++++++++---
.../unit_tests/pandas_postprocessing/test_diff.py | 4 +-
.../pandas_postprocessing/test_flatten.py | 64 +++++++
.../unit_tests/pandas_postprocessing/test_pivot.py | 8 +-
.../pandas_postprocessing/test_prophet.py | 14 +-
.../pandas_postprocessing/test_resample.py | 170 +++++++++++------
.../pandas_postprocessing/test_rolling.py | 157 ++++++++++++----
.../pandas_postprocessing/test_select.py | 6 +-
.../unit_tests/pandas_postprocessing/test_sort.py | 4 +-
55 files changed, 1164 insertions(+), 873 deletions(-)
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/boxplotOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/boxplotOperator.ts
index 9b90c12..8a4b4f0 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/boxplotOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/boxplotOperator.ts
@@ -21,16 +21,16 @@ import {
getColumnLabel,
getMetricLabel,
PostProcessingBoxplot,
+ BoxPlotQueryObjectWhiskerType,
} from '@superset-ui/core';
import { PostProcessingFactory } from './types';
-type BoxPlotQueryObjectWhiskerType =
- PostProcessingBoxplot['options']['whisker_type'];
const PERCENTILE_REGEX = /(\d+)\/(\d+) percentiles/;
-export const boxplotOperator: PostProcessingFactory<
- PostProcessingBoxplot | undefined
-> = (formData, queryObject) => {
+export const boxplotOperator: PostProcessingFactory<PostProcessingBoxplot> = (
+ formData,
+ queryObject,
+) => {
const { groupby, whiskerOptions } = formData;
if (whiskerOptions) {
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/contributionOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/contributionOperator.ts
index 793ca87..484117c 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/contributionOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/contributionOperator.ts
@@ -19,16 +19,15 @@
import { PostProcessingContribution } from '@superset-ui/core';
import { PostProcessingFactory } from './types';
-export const contributionOperator: PostProcessingFactory<
- PostProcessingContribution | undefined
-> = (formData, queryObject) => {
- if (formData.contributionMode) {
- return {
- operation: 'contribution',
- options: {
- orientation: formData.contributionMode,
- },
- };
- }
- return undefined;
-};
+export const contributionOperator:
PostProcessingFactory<PostProcessingContribution> =
+ (formData, queryObject) => {
+ if (formData.contributionMode) {
+ return {
+ operation: 'contribution',
+ options: {
+ orientation: formData.contributionMode,
+ },
+ };
+ }
+ return undefined;
+ };
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/contributionOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/flattenOperator.ts
similarity index 69%
copy from
superset-frontend/packages/superset-ui-chart-controls/src/operators/contributionOperator.ts
copy to
superset-frontend/packages/superset-ui-chart-controls/src/operators/flattenOperator.ts
index 793ca87..1348f4b 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/contributionOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/flattenOperator.ts
@@ -1,3 +1,4 @@
+/* eslint-disable camelcase */
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
@@ -16,19 +17,10 @@
 * specific language governing permissions and limitations
* under the License.
*/
-import { PostProcessingContribution } from '@superset-ui/core';
+import { PostProcessingFlatten } from '@superset-ui/core';
import { PostProcessingFactory } from './types';
-export const contributionOperator: PostProcessingFactory<
- PostProcessingContribution | undefined
-> = (formData, queryObject) => {
- if (formData.contributionMode) {
- return {
- operation: 'contribution',
- options: {
- orientation: formData.contributionMode,
- },
- };
- }
- return undefined;
-};
+export const flattenOperator: PostProcessingFactory<PostProcessingFlatten> = (
+ formData,
+ queryObject,
+) => ({ operation: 'flatten' });
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/index.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/index.ts
index 95aeb21..28e7e70 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/index.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/index.ts
@@ -26,4 +26,5 @@ export { resampleOperator } from './resampleOperator';
export { contributionOperator } from './contributionOperator';
export { prophetOperator } from './prophetOperator';
export { boxplotOperator } from './boxplotOperator';
+export { flattenOperator } from './flattenOperator';
export * from './utils';
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/pivotOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/pivotOperator.ts
index e1e1fde..a5bf20d 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/pivotOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/pivotOperator.ts
@@ -24,19 +24,14 @@ import {
PostProcessingPivot,
} from '@superset-ui/core';
import { PostProcessingFactory } from './types';
-import { isValidTimeCompare } from './utils';
-import { timeComparePivotOperator } from './timeComparePivotOperator';
-export const pivotOperator: PostProcessingFactory<
- PostProcessingPivot | undefined
-> = (formData, queryObject) => {
+export const pivotOperator: PostProcessingFactory<PostProcessingPivot> = (
+ formData,
+ queryObject,
+) => {
const metricLabels = ensureIsArray(queryObject.metrics).map(getMetricLabel);
const { x_axis: xAxis } = formData;
if ((xAxis || queryObject.is_timeseries) && metricLabels.length) {
- if (isValidTimeCompare(formData, queryObject)) {
- return timeComparePivotOperator(formData, queryObject);
- }
-
return {
operation: 'pivot',
options: {
@@ -48,6 +43,8 @@ export const pivotOperator: PostProcessingFactory<
metricLabels.map(metric => [metric, { operator: 'mean' }]),
),
drop_missing_columns: false,
+ flatten_columns: false,
+ reset_index: false,
},
};
}
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/prophetOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/prophetOperator.ts
index 640cb8b..297d84e 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/prophetOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/prophetOperator.ts
@@ -19,9 +19,10 @@
import { DTTM_ALIAS, PostProcessingProphet } from '@superset-ui/core';
import { PostProcessingFactory } from './types';
-export const prophetOperator: PostProcessingFactory<
- PostProcessingProphet | undefined
-> = (formData, queryObject) => {
+export const prophetOperator: PostProcessingFactory<PostProcessingProphet> = (
+ formData,
+ queryObject,
+) => {
if (formData.forecastEnabled) {
return {
operation: 'prophet',
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/resampleOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/resampleOperator.ts
index d639e19..2306ea3 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/resampleOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/resampleOperator.ts
@@ -17,36 +17,23 @@
 * specific language governing permissions and limitations
* under the License.
*/
-import {
- DTTM_ALIAS,
- ensureIsArray,
- isPhysicalColumn,
- PostProcessingResample,
-} from '@superset-ui/core';
+import { PostProcessingResample } from '@superset-ui/core';
import { PostProcessingFactory } from './types';
-export const resampleOperator: PostProcessingFactory<
- PostProcessingResample | undefined
-> = (formData, queryObject) => {
+export const resampleOperator: PostProcessingFactory<PostProcessingResample> =
(
+ formData,
+ queryObject,
+) => {
const resampleZeroFill = formData.resample_method === 'zerofill';
const resampleMethod = resampleZeroFill ? 'asfreq' :
formData.resample_method;
const resampleRule = formData.resample_rule;
if (resampleMethod && resampleRule) {
- const groupby_columns = ensureIsArray(queryObject.columns).map(column => {
- if (isPhysicalColumn(column)) {
- return column;
- }
- return column.label;
- });
-
return {
operation: 'resample',
options: {
method: resampleMethod,
rule: resampleRule,
fill_value: resampleZeroFill ? 0 : null,
- time_column: formData.x_axis || DTTM_ALIAS,
- groupby_columns,
},
};
}
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/rollingWindowOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/rollingWindowOperator.ts
index d4c04ec..563b3e0 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/rollingWindowOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/rollingWindowOperator.ts
@@ -18,39 +18,25 @@
* under the License.
*/
import {
- ComparisionType,
ensureIsArray,
ensureIsInt,
PostProcessingCum,
PostProcessingRolling,
RollingType,
} from '@superset-ui/core';
-import {
- getMetricOffsetsMap,
- isValidTimeCompare,
- TIME_COMPARISON_SEPARATOR,
-} from './utils';
+import { getMetricOffsetsMap, isValidTimeCompare } from './utils';
import { PostProcessingFactory } from './types';
export const rollingWindowOperator: PostProcessingFactory<
- PostProcessingRolling | PostProcessingCum | undefined
+ PostProcessingRolling | PostProcessingCum
> = (formData, queryObject) => {
let columns: (string | undefined)[];
if (isValidTimeCompare(formData, queryObject)) {
const metricsMap = getMetricOffsetsMap(formData, queryObject);
- const comparisonType = formData.comparison_type;
- if (comparisonType === ComparisionType.Values) {
- // time compare type: actual values
- columns = [
- ...Array.from(metricsMap.values()),
- ...Array.from(metricsMap.keys()),
- ];
- } else {
- // time compare type: difference / percentage / ratio
- columns = Array.from(metricsMap.entries()).map(([offset, metric]) =>
- [comparisonType, metric, offset].join(TIME_COMPARISON_SEPARATOR),
- );
- }
+ columns = [
+ ...Array.from(metricsMap.values()),
+ ...Array.from(metricsMap.keys()),
+ ];
} else {
columns = ensureIsArray(queryObject.metrics).map(metric => {
if (typeof metric === 'string') {
@@ -67,7 +53,6 @@ export const rollingWindowOperator: PostProcessingFactory<
options: {
operator: 'sum',
columns: columnsMap,
- is_pivot_df: true,
},
};
}
@@ -84,7 +69,6 @@ export const rollingWindowOperator: PostProcessingFactory<
window: ensureIsInt(formData.rolling_periods, 1),
min_periods: ensureIsInt(formData.min_periods, 0),
columns: columnsMap,
- is_pivot_df: true,
},
};
}
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/sortOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/sortOperator.ts
index 9443bb7..277d2df 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/sortOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/sortOperator.ts
@@ -20,9 +20,10 @@
import { DTTM_ALIAS, PostProcessingSort, RollingType } from
'@superset-ui/core';
import { PostProcessingFactory } from './types';
-export const sortOperator: PostProcessingFactory<
- PostProcessingSort | undefined
-> = (formData, queryObject) => {
+export const sortOperator: PostProcessingFactory<PostProcessingSort> = (
+ formData,
+ queryObject,
+) => {
const { x_axis: xAxis } = formData;
if (
(xAxis || queryObject.is_timeseries) &&
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/timeCompareOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/timeCompareOperator.ts
index 55d8a82..ec62384 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/timeCompareOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/timeCompareOperator.ts
@@ -21,26 +21,25 @@ import { ComparisionType, PostProcessingCompare } from
'@superset-ui/core';
import { getMetricOffsetsMap, isValidTimeCompare } from './utils';
import { PostProcessingFactory } from './types';
-export const timeCompareOperator: PostProcessingFactory<
- PostProcessingCompare | undefined
-> = (formData, queryObject) => {
- const comparisonType = formData.comparison_type;
- const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
+export const timeCompareOperator: PostProcessingFactory<PostProcessingCompare>
=
+ (formData, queryObject) => {
+ const comparisonType = formData.comparison_type;
+ const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
- if (
- isValidTimeCompare(formData, queryObject) &&
- comparisonType !== ComparisionType.Values
- ) {
- return {
- operation: 'compare',
- options: {
- source_columns: Array.from(metricOffsetMap.values()),
- compare_columns: Array.from(metricOffsetMap.keys()),
- compare_type: comparisonType,
- drop_original_columns: true,
- },
- };
- }
+ if (
+ isValidTimeCompare(formData, queryObject) &&
+ comparisonType !== ComparisionType.Values
+ ) {
+ return {
+ operation: 'compare',
+ options: {
+ source_columns: Array.from(metricOffsetMap.values()),
+ compare_columns: Array.from(metricOffsetMap.keys()),
+ compare_type: comparisonType,
+ drop_original_columns: true,
+ },
+ };
+ }
- return undefined;
-};
+ return undefined;
+ };
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/timeComparePivotOperator.ts
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/timeComparePivotOperator.ts
index 9e16d29..44a1825 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/src/operators/timeComparePivotOperator.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/src/operators/timeComparePivotOperator.ts
@@ -18,54 +18,40 @@
* under the License.
*/
import {
- ComparisionType,
DTTM_ALIAS,
ensureIsArray,
getColumnLabel,
NumpyFunction,
PostProcessingPivot,
} from '@superset-ui/core';
-import {
- getMetricOffsetsMap,
- isValidTimeCompare,
- TIME_COMPARISON_SEPARATOR,
-} from './utils';
+import { getMetricOffsetsMap, isValidTimeCompare } from './utils';
import { PostProcessingFactory } from './types';
-export const timeComparePivotOperator: PostProcessingFactory<
- PostProcessingPivot | undefined
-> = (formData, queryObject) => {
- const comparisonType = formData.comparison_type;
- const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
+export const timeComparePivotOperator:
PostProcessingFactory<PostProcessingPivot> =
+ (formData, queryObject) => {
+ const metricOffsetMap = getMetricOffsetsMap(formData, queryObject);
- if (isValidTimeCompare(formData, queryObject)) {
- const valuesAgg = Object.fromEntries(
- [...metricOffsetMap.values(), ...metricOffsetMap.keys()].map(metric => [
- metric,
- // use the 'mean' aggregates to avoid drop NaN
- { operator: 'mean' as NumpyFunction },
- ]),
- );
- const changeAgg = Object.fromEntries(
- [...metricOffsetMap.entries()]
- .map(([offset, metric]) =>
- [comparisonType, metric, offset].join(TIME_COMPARISON_SEPARATOR),
- )
- // use the 'mean' aggregates to avoid drop NaN
- .map(metric => [metric, { operator: 'mean' as NumpyFunction }]),
- );
+ if (isValidTimeCompare(formData, queryObject)) {
+ const aggregates = Object.fromEntries(
+ [...metricOffsetMap.values(), ...metricOffsetMap.keys()].map(metric =>
[
+ metric,
+ // use the 'mean' aggregates to avoid drop NaN
+ { operator: 'mean' as NumpyFunction },
+ ]),
+ );
- return {
- operation: 'pivot',
- options: {
- index: [formData.x_axis || DTTM_ALIAS],
- columns: ensureIsArray(queryObject.columns).map(getColumnLabel),
- aggregates:
- comparisonType === ComparisionType.Values ? valuesAgg : changeAgg,
- drop_missing_columns: false,
- },
- };
- }
+ return {
+ operation: 'pivot',
+ options: {
+ index: [formData.x_axis || DTTM_ALIAS],
+ columns: ensureIsArray(queryObject.columns).map(getColumnLabel),
+ drop_missing_columns: false,
+ flatten_columns: false,
+ reset_index: false,
+ aggregates,
+ },
+ };
+ }
- return undefined;
-};
+ return undefined;
+ };
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/flattenOperator.test.ts
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/flattenOperator.test.ts
new file mode 100644
index 0000000..94a9b00
--- /dev/null
+++
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/flattenOperator.test.ts
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import { QueryObject, SqlaFormData } from '@superset-ui/core';
+import { flattenOperator } from '@superset-ui/chart-controls';
+
+const formData: SqlaFormData = {
+ metrics: [
+ 'count(*)',
+ { label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
+ ],
+ time_range: '2015 : 2016',
+ granularity: 'month',
+ datasource: 'foo',
+ viz_type: 'table',
+};
+const queryObject: QueryObject = {
+ metrics: [
+ 'count(*)',
+ { label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
+ ],
+ time_range: '2015 : 2016',
+ granularity: 'month',
+ post_processing: [
+ {
+ operation: 'pivot',
+ options: {
+ index: ['__timestamp'],
+ columns: ['nation'],
+ aggregates: {
+ 'count(*)': {
+ operator: 'sum',
+ },
+ },
+ },
+ },
+ ],
+};
+
+test('should do flattenOperator', () => {
+ expect(flattenOperator(formData, queryObject)).toEqual({
+ operation: 'flatten',
+ });
+});
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/pivotOperator.test.ts
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/pivotOperator.test.ts
index 5054049..b75385a 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/pivotOperator.test.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/pivotOperator.test.ts
@@ -80,6 +80,8 @@ test('pivot by __timestamp without groupby', () => {
'sum(val)': { operator: 'mean' },
},
drop_missing_columns: false,
+ flatten_columns: false,
+ reset_index: false,
},
});
});
@@ -101,6 +103,8 @@ test('pivot by __timestamp with groupby', () => {
'sum(val)': { operator: 'mean' },
},
drop_missing_columns: false,
+ flatten_columns: false,
+ reset_index: false,
},
});
});
@@ -127,44 +131,8 @@ test('pivot by x_axis with groupby', () => {
'sum(val)': { operator: 'mean' },
},
drop_missing_columns: false,
- },
- });
-});
-
-test('timecompare in formdata', () => {
- expect(
- pivotOperator(
- {
- ...formData,
- comparison_type: 'values',
- time_compare: ['1 year ago', '1 year later'],
- },
- {
- ...queryObject,
- columns: ['foo', 'bar'],
- is_timeseries: true,
- },
- ),
- ).toEqual({
- operation: 'pivot',
- options: {
- aggregates: {
- 'count(*)': { operator: 'mean' },
- 'count(*)__1 year ago': { operator: 'mean' },
- 'count(*)__1 year later': { operator: 'mean' },
- 'sum(val)': {
- operator: 'mean',
- },
- 'sum(val)__1 year ago': {
- operator: 'mean',
- },
- 'sum(val)__1 year later': {
- operator: 'mean',
- },
- },
- drop_missing_columns: false,
- columns: ['foo', 'bar'],
- index: ['__timestamp'],
+ flatten_columns: false,
+ reset_index: false,
},
});
});
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/resampleOperator.test.ts
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/resampleOperator.test.ts
index a562dbb..271e63b 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/resampleOperator.test.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/resampleOperator.test.ts
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-import { AdhocColumn, QueryObject, SqlaFormData } from '@superset-ui/core';
+import { QueryObject, SqlaFormData } from '@superset-ui/core';
import { resampleOperator } from '@superset-ui/chart-controls';
const formData: SqlaFormData = {
@@ -74,8 +74,6 @@ test('should do resample on implicit time column', () => {
method: 'ffill',
rule: '1D',
fill_value: null,
- time_column: '__timestamp',
- groupby_columns: [],
},
});
});
@@ -95,10 +93,8 @@ test('should do resample on x-axis', () => {
operation: 'resample',
options: {
fill_value: null,
- groupby_columns: [],
method: 'ffill',
rule: '1D',
- time_column: 'ds',
},
});
});
@@ -115,81 +111,6 @@ test('should do zerofill resample', () => {
method: 'asfreq',
rule: '1D',
fill_value: 0,
- time_column: '__timestamp',
- groupby_columns: [],
- },
- });
-});
-
-test('should append physical column to resample', () => {
- expect(
- resampleOperator(
- { ...formData, resample_method: 'zerofill', resample_rule: '1D' },
- { ...queryObject, columns: ['column1', 'column2'] },
- ),
- ).toEqual({
- operation: 'resample',
- options: {
- method: 'asfreq',
- rule: '1D',
- fill_value: 0,
- time_column: '__timestamp',
- groupby_columns: ['column1', 'column2'],
- },
- });
-});
-
-test('should append label of adhoc column and physical column to resample', ()
=> {
- expect(
- resampleOperator(
- { ...formData, resample_method: 'zerofill', resample_rule: '1D' },
- {
- ...queryObject,
- columns: [
- {
- hasCustomLabel: true,
- label: 'concat_a_b',
- expressionType: 'SQL',
- sqlExpression: "'a' + 'b'",
- } as AdhocColumn,
- 'column2',
- ],
- },
- ),
- ).toEqual({
- operation: 'resample',
- options: {
- method: 'asfreq',
- rule: '1D',
- fill_value: 0,
- time_column: '__timestamp',
- groupby_columns: ['concat_a_b', 'column2'],
- },
- });
-});
-
-test('should append `undefined` if adhoc non-existing label', () => {
- expect(
- resampleOperator(
- { ...formData, resample_method: 'zerofill', resample_rule: '1D' },
- {
- ...queryObject,
- columns: [
- {
- sqlExpression: "'a' + 'b'",
- } as AdhocColumn,
- 'column2',
- ],
- },
- ),
- ).toEqual({
- operation: 'resample',
- options: {
- method: 'asfreq',
- rule: '1D',
- fill_value: 0,
- time_column: '__timestamp',
- groupby_columns: [undefined, 'column2'],
},
});
});
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/rollingWindowOperator.test.ts
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/rollingWindowOperator.test.ts
index 82e786a..eec2bb7 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/rollingWindowOperator.test.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/rollingWindowOperator.test.ts
@@ -79,7 +79,6 @@ test('rolling_type: cumsum', () => {
'count(*)': 'count(*)',
'sum(val)': 'sum(val)',
},
- is_pivot_df: true,
},
});
});
@@ -102,42 +101,13 @@ test('rolling_type: sum/mean/std', () => {
'count(*)': 'count(*)',
'sum(val)': 'sum(val)',
},
- is_pivot_df: true,
},
});
});
});
-test('rolling window and "actual values" in the time compare', () => {
- expect(
- rollingWindowOperator(
- {
- ...formData,
- rolling_type: 'cumsum',
- comparison_type: 'values',
- time_compare: ['1 year ago', '1 year later'],
- },
- queryObject,
- ),
- ).toEqual({
- operation: 'cum',
- options: {
- operator: 'sum',
- columns: {
- 'count(*)': 'count(*)',
- 'count(*)__1 year ago': 'count(*)__1 year ago',
- 'count(*)__1 year later': 'count(*)__1 year later',
- 'sum(val)': 'sum(val)',
- 'sum(val)__1 year ago': 'sum(val)__1 year ago',
- 'sum(val)__1 year later': 'sum(val)__1 year later',
- },
- is_pivot_df: true,
- },
- });
-});
-
-test('rolling window and "difference / percentage / ratio" in the time
compare', () => {
- const comparisionTypes = ['difference', 'percentage', 'ratio'];
+test('should append compared metrics when sets time compare type', () => {
+ const comparisionTypes = ['values', 'difference', 'percentage', 'ratio'];
comparisionTypes.forEach(cType => {
expect(
rollingWindowOperator(
@@ -154,12 +124,13 @@ test('rolling window and "difference / percentage /
ratio" in the time compare',
options: {
operator: 'sum',
columns: {
- [`${cType}__count(*)__count(*)__1 year ago`]:
`${cType}__count(*)__count(*)__1 year ago`,
- [`${cType}__count(*)__count(*)__1 year later`]:
`${cType}__count(*)__count(*)__1 year later`,
- [`${cType}__sum(val)__sum(val)__1 year ago`]:
`${cType}__sum(val)__sum(val)__1 year ago`,
- [`${cType}__sum(val)__sum(val)__1 year later`]:
`${cType}__sum(val)__sum(val)__1 year later`,
+ 'count(*)': 'count(*)',
+ 'count(*)__1 year ago': 'count(*)__1 year ago',
+ 'count(*)__1 year later': 'count(*)__1 year later',
+ 'sum(val)': 'sum(val)',
+ 'sum(val)__1 year ago': 'sum(val)__1 year ago',
+ 'sum(val)__1 year later': 'sum(val)__1 year later',
},
- is_pivot_df: true,
},
});
});
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/timeCompareOperator.test.ts
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/timeCompareOperator.test.ts
index c2fcb75..197ccee 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/timeCompareOperator.test.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/timeCompareOperator.test.ts
@@ -17,17 +17,23 @@
* under the License.
*/
import { QueryObject, SqlaFormData } from '@superset-ui/core';
-import { timeCompareOperator, timeComparePivotOperator } from '../../../src';
+import { timeCompareOperator } from '../../../src';
const formData: SqlaFormData = {
- metrics: ['count(*)'],
+ metrics: [
+ 'count(*)',
+ { label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
+ ],
time_range: '2015 : 2016',
granularity: 'month',
datasource: 'foo',
viz_type: 'table',
};
const queryObject: QueryObject = {
- metrics: ['count(*)'],
+ metrics: [
+ 'count(*)',
+ { label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
+ ],
time_range: '2015 : 2016',
granularity: 'month',
post_processing: [
@@ -40,21 +46,26 @@ const queryObject: QueryObject = {
'count(*)': {
operator: 'mean',
},
+ 'sum(val)': {
+ operator: 'mean',
+ },
},
drop_missing_columns: false,
+ flatten_columns: false,
+ reset_index: false,
},
},
{
operation: 'aggregation',
options: {
groupby: ['col1'],
- aggregates: 'count',
+ aggregates: {},
},
},
],
};
-test('time compare: skip transformation', () => {
+test('should skip CompareOperator', () => {
expect(timeCompareOperator(formData, queryObject)).toEqual(undefined);
expect(
timeCompareOperator({ ...formData, time_compare: [] }, queryObject),
@@ -80,7 +91,7 @@ test('time compare: skip transformation', () => {
).toEqual(undefined);
});
-test('time compare: difference/percentage/ratio', () => {
+test('should generate difference/percentage/ratio CompareOperator', () => {
const comparisionTypes = ['difference', 'percentage', 'ratio'];
comparisionTypes.forEach(cType => {
expect(
@@ -95,108 +106,16 @@ test('time compare: difference/percentage/ratio', () => {
).toEqual({
operation: 'compare',
options: {
- source_columns: ['count(*)', 'count(*)'],
- compare_columns: ['count(*)__1 year ago', 'count(*)__1 year later'],
+ source_columns: ['count(*)', 'count(*)', 'sum(val)', 'sum(val)'],
+ compare_columns: [
+ 'count(*)__1 year ago',
+ 'count(*)__1 year later',
+ 'sum(val)__1 year ago',
+ 'sum(val)__1 year later',
+ ],
compare_type: cType,
drop_original_columns: true,
},
});
});
});
-
-test('time compare pivot: skip transformation', () => {
- expect(timeComparePivotOperator(formData, queryObject)).toEqual(undefined);
- expect(
- timeComparePivotOperator({ ...formData, time_compare: [] }, queryObject),
- ).toEqual(undefined);
- expect(
- timeComparePivotOperator(
- { ...formData, comparison_type: null },
- queryObject,
- ),
- ).toEqual(undefined);
- expect(
- timeCompareOperator(
- { ...formData, comparison_type: 'foobar' },
- queryObject,
- ),
- ).toEqual(undefined);
-});
-
-test('time compare pivot: values', () => {
- expect(
- timeComparePivotOperator(
- {
- ...formData,
- comparison_type: 'values',
- time_compare: ['1 year ago', '1 year later'],
- },
- queryObject,
- ),
- ).toEqual({
- operation: 'pivot',
- options: {
- aggregates: {
- 'count(*)': { operator: 'mean' },
- 'count(*)__1 year ago': { operator: 'mean' },
- 'count(*)__1 year later': { operator: 'mean' },
- },
- drop_missing_columns: false,
- columns: [],
- index: ['__timestamp'],
- },
- });
-});
-
-test('time compare pivot: difference/percentage/ratio', () => {
- const comparisionTypes = ['difference', 'percentage', 'ratio'];
- comparisionTypes.forEach(cType => {
- expect(
- timeComparePivotOperator(
- {
- ...formData,
- comparison_type: cType,
- time_compare: ['1 year ago', '1 year later'],
- },
- queryObject,
- ),
- ).toEqual({
- operation: 'pivot',
- options: {
- aggregates: {
- [`${cType}__count(*)__count(*)__1 year ago`]: { operator: 'mean' },
- [`${cType}__count(*)__count(*)__1 year later`]: { operator: 'mean' },
- },
- drop_missing_columns: false,
- columns: [],
- index: ['__timestamp'],
- },
- });
- });
-});
-
-test('time compare pivot on x-axis', () => {
- expect(
- timeComparePivotOperator(
- {
- ...formData,
- comparison_type: 'values',
- time_compare: ['1 year ago', '1 year later'],
- x_axis: 'ds',
- },
- queryObject,
- ),
- ).toEqual({
- operation: 'pivot',
- options: {
- aggregates: {
- 'count(*)': { operator: 'mean' },
- 'count(*)__1 year ago': { operator: 'mean' },
- 'count(*)__1 year later': { operator: 'mean' },
- },
- drop_missing_columns: false,
- columns: [],
- index: ['ds'],
- },
- });
-});
diff --git
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/pivotOperator.test.ts
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/timeComparePivotOperator.test.ts
similarity index 55%
copy from
superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/pivotOperator.test.ts
copy to
superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/timeComparePivotOperator.test.ts
index 5054049..fcf8ea6 100644
---
a/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/pivotOperator.test.ts
+++
b/superset-frontend/packages/superset-ui-chart-controls/test/utils/operators/timeComparePivotOperator.test.ts
@@ -17,7 +17,7 @@
* under the License.
*/
import { QueryObject, SqlaFormData } from '@superset-ui/core';
-import { pivotOperator } from '../../../src';
+import { timeCompareOperator, timeComparePivotOperator } from '../../../src';
const formData: SqlaFormData = {
metrics: [
@@ -34,116 +34,81 @@ const queryObject: QueryObject = {
'count(*)',
{ label: 'sum(val)', expressionType: 'SQL', sqlExpression: 'sum(val)' },
],
+ columns: ['foo', 'bar'],
time_range: '2015 : 2016',
granularity: 'month',
- post_processing: [
- {
- operation: 'pivot',
- options: {
- index: ['__timestamp'],
- columns: ['nation'],
- aggregates: {
- 'count(*)': {
- operator: 'mean',
- },
- },
- drop_missing_columns: false,
- },
- },
- ],
+ post_processing: [],
};
-test('skip pivot', () => {
- expect(pivotOperator(formData, queryObject)).toEqual(undefined);
+test('should skip pivot', () => {
+ expect(timeComparePivotOperator(formData, queryObject)).toEqual(undefined);
expect(
- pivotOperator(formData, { ...queryObject, is_timeseries: false }),
+ timeComparePivotOperator({ ...formData, time_compare: [] }, queryObject),
).toEqual(undefined);
expect(
- pivotOperator(formData, {
- ...queryObject,
- is_timeseries: true,
- metrics: [],
- }),
+ timeComparePivotOperator(
+ { ...formData, comparison_type: null },
+ queryObject,
+ ),
).toEqual(undefined);
-});
-
-test('pivot by __timestamp without groupby', () => {
expect(
- pivotOperator(formData, { ...queryObject, is_timeseries: true }),
- ).toEqual({
- operation: 'pivot',
- options: {
- index: ['__timestamp'],
- columns: [],
- aggregates: {
- 'count(*)': { operator: 'mean' },
- 'sum(val)': { operator: 'mean' },
- },
- drop_missing_columns: false,
- },
- });
-});
-
-test('pivot by __timestamp with groupby', () => {
- expect(
- pivotOperator(formData, {
- ...queryObject,
- columns: ['foo', 'bar'],
- is_timeseries: true,
- }),
- ).toEqual({
- operation: 'pivot',
- options: {
- index: ['__timestamp'],
- columns: ['foo', 'bar'],
- aggregates: {
- 'count(*)': { operator: 'mean' },
- 'sum(val)': { operator: 'mean' },
- },
- drop_missing_columns: false,
- },
- });
+ timeCompareOperator(
+ { ...formData, comparison_type: 'foobar' },
+ queryObject,
+ ),
+ ).toEqual(undefined);
});
-test('pivot by x_axis with groupby', () => {
- expect(
- pivotOperator(
- {
- ...formData,
- x_axis: 'baz',
- },
- {
- ...queryObject,
+test('should pivot on any type of timeCompare', () => {
+ const anyTimeCompareTypes = ['values', 'difference', 'percentage', 'ratio'];
+ anyTimeCompareTypes.forEach(cType => {
+ expect(
+ timeComparePivotOperator(
+ {
+ ...formData,
+ comparison_type: cType,
+ time_compare: ['1 year ago', '1 year later'],
+ },
+ {
+ ...queryObject,
+ is_timeseries: true,
+ },
+ ),
+ ).toEqual({
+ operation: 'pivot',
+ options: {
+ aggregates: {
+ 'count(*)': { operator: 'mean' },
+ 'count(*)__1 year ago': { operator: 'mean' },
+ 'count(*)__1 year later': { operator: 'mean' },
+ 'sum(val)': { operator: 'mean' },
+ 'sum(val)__1 year ago': {
+ operator: 'mean',
+ },
+ 'sum(val)__1 year later': {
+ operator: 'mean',
+ },
+ },
+ drop_missing_columns: false,
+ flatten_columns: false,
+ reset_index: false,
columns: ['foo', 'bar'],
+ index: ['__timestamp'],
},
- ),
- ).toEqual({
- operation: 'pivot',
- options: {
- index: ['baz'],
- columns: ['foo', 'bar'],
- aggregates: {
- 'count(*)': { operator: 'mean' },
- 'sum(val)': { operator: 'mean' },
- },
- drop_missing_columns: false,
- },
+ });
});
});
-test('timecompare in formdata', () => {
+test('should pivot on x-axis', () => {
expect(
- pivotOperator(
+ timeComparePivotOperator(
{
...formData,
comparison_type: 'values',
time_compare: ['1 year ago', '1 year later'],
+ x_axis: 'ds',
},
- {
- ...queryObject,
- columns: ['foo', 'bar'],
- is_timeseries: true,
- },
+ queryObject,
),
).toEqual({
operation: 'pivot',
@@ -164,7 +129,9 @@ test('timecompare in formdata', () => {
},
drop_missing_columns: false,
columns: ['foo', 'bar'],
- index: ['__timestamp'],
+ index: ['ds'],
+ flatten_columns: false,
+ reset_index: false,
},
});
});
diff --git
a/superset-frontend/packages/superset-ui-core/src/query/types/PostProcessing.ts
b/superset-frontend/packages/superset-ui-core/src/query/types/PostProcessing.ts
index cf2baf5..7e5ce85 100644
---
a/superset-frontend/packages/superset-ui-core/src/query/types/PostProcessing.ts
+++
b/superset-frontend/packages/superset-ui-core/src/query/types/PostProcessing.ts
@@ -64,25 +64,34 @@ export interface Aggregates {
};
}
-export interface PostProcessingAggregation {
+export type DefaultPostProcessing = undefined;
+
+interface _PostProcessingAggregation {
operation: 'aggregation';
options: {
groupby: string[];
aggregates: Aggregates;
};
}
+export type PostProcessingAggregation =
+ | _PostProcessingAggregation
+ | DefaultPostProcessing;
-export interface PostProcessingBoxplot {
+export type BoxPlotQueryObjectWhiskerType = 'tukey' | 'min/max' | 'percentile';
+interface _PostProcessingBoxplot {
operation: 'boxplot';
options: {
groupby: string[];
metrics: string[];
- whisker_type: 'tukey' | 'min/max' | 'percentile';
+ whisker_type: BoxPlotQueryObjectWhiskerType;
percentiles?: [number, number];
};
}
+export type PostProcessingBoxplot =
+ | _PostProcessingBoxplot
+ | DefaultPostProcessing;
-export interface PostProcessingContribution {
+interface _PostProcessingContribution {
operation: 'contribution';
options?: {
orientation?: 'row' | 'column';
@@ -90,8 +99,11 @@ export interface PostProcessingContribution {
rename_columns?: string[];
};
}
+export type PostProcessingContribution =
+ | _PostProcessingContribution
+ | DefaultPostProcessing;
-export interface PostProcessingPivot {
+interface _PostProcessingPivot {
operation: 'pivot';
options: {
aggregates: Aggregates;
@@ -107,8 +119,9 @@ export interface PostProcessingPivot {
reset_index?: boolean;
};
}
+export type PostProcessingPivot = _PostProcessingPivot | DefaultPostProcessing;
-export interface PostProcessingProphet {
+interface _PostProcessingProphet {
operation: 'prophet';
options: {
time_grain: TimeGranularity;
@@ -119,8 +132,11 @@ export interface PostProcessingProphet {
daily_seasonality?: boolean | number;
};
}
+export type PostProcessingProphet =
+ | _PostProcessingProphet
+ | DefaultPostProcessing;
-export interface PostProcessingDiff {
+interface _PostProcessingDiff {
operation: 'diff';
options: {
columns: string[];
@@ -128,28 +144,31 @@ export interface PostProcessingDiff {
axis: PandasAxis;
};
}
+export type PostProcessingDiff = _PostProcessingDiff | DefaultPostProcessing;
-export interface PostProcessingRolling {
+interface _PostProcessingRolling {
operation: 'rolling';
options: {
rolling_type: RollingType;
window: number;
min_periods: number;
columns: string[];
- is_pivot_df?: boolean;
};
}
+export type PostProcessingRolling =
+ | _PostProcessingRolling
+ | DefaultPostProcessing;
-export interface PostProcessingCum {
+interface _PostProcessingCum {
operation: 'cum';
options: {
columns: string[];
operator: NumpyFunction;
- is_pivot_df?: boolean;
};
}
+export type PostProcessingCum = _PostProcessingCum | DefaultPostProcessing;
-export interface PostProcessingCompare {
+export interface _PostProcessingCompare {
operation: 'compare';
options: {
source_columns: string[];
@@ -158,26 +177,39 @@ export interface PostProcessingCompare {
drop_original_columns: boolean;
};
}
+export type PostProcessingCompare =
+ | _PostProcessingCompare
+ | DefaultPostProcessing;
-export interface PostProcessingSort {
+interface _PostProcessingSort {
operation: 'sort';
options: {
columns: Record<string, boolean>;
};
}
+export type PostProcessingSort = _PostProcessingSort | DefaultPostProcessing;
-export interface PostProcessingResample {
+interface _PostProcessingResample {
operation: 'resample';
options: {
method: string;
rule: string;
fill_value?: number | null;
- time_column: string;
- // If AdhocColumn doesn't have a label, it will be undefined.
- // todo: we have to give an explicit label for AdhocColumn.
- groupby_columns?: Array<string | undefined>;
};
}
+export type PostProcessingResample =
+ | _PostProcessingResample
+ | DefaultPostProcessing;
+
+interface _PostProcessingFlatten {
+ operation: 'flatten';
+ options?: {
+ reset_index?: boolean;
+ };
+}
+export type PostProcessingFlatten =
+ | _PostProcessingFlatten
+ | DefaultPostProcessing;
/**
* Parameters for chart data postprocessing.
@@ -194,7 +226,8 @@ export type PostProcessingRule =
| PostProcessingCum
| PostProcessingCompare
| PostProcessingSort
- | PostProcessingResample;
+ | PostProcessingResample
+ | PostProcessingFlatten;
export function isPostProcessingAggregation(
rule?: PostProcessingRule,
diff --git
a/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.ts
b/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.ts
index be35734..d55cf46 100644
---
a/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.ts
+++
b/superset-frontend/plugins/plugin-chart-echarts/src/BigNumber/BigNumberWithTrendline/buildQuery.ts
@@ -18,11 +18,14 @@
*/
import {
buildQueryContext,
- DTTM_ALIAS,
PostProcessingResample,
QueryFormData,
} from '@superset-ui/core';
-import { rollingWindowOperator } from '@superset-ui/chart-controls';
+import {
+ flattenOperator,
+ rollingWindowOperator,
+ sortOperator,
+} from '@superset-ui/chart-controls';
const TIME_GRAIN_MAP: Record<string, string> = {
PT1S: 'S',
@@ -47,12 +50,10 @@ const TIME_GRAIN_MAP: Record<string, string> = {
export default function buildQuery(formData: QueryFormData) {
return buildQueryContext(formData, baseQueryObject => {
+ // todo: move into full advanced analysis section here
const rollingProc = rollingWindowOperator(formData, baseQueryObject);
- if (rollingProc) {
- rollingProc.options = { ...rollingProc.options, is_pivot_df: false };
- }
const { time_grain_sqla } = formData;
- let resampleProc: PostProcessingResample | undefined;
+ let resampleProc: PostProcessingResample;
if (rollingProc && time_grain_sqla) {
const rule = TIME_GRAIN_MAP[time_grain_sqla];
if (rule) {
@@ -62,7 +63,6 @@ export default function buildQuery(formData: QueryFormData) {
method: 'asfreq',
rule,
fill_value: null,
- time_column: DTTM_ALIAS,
},
};
}
@@ -72,16 +72,10 @@ export default function buildQuery(formData: QueryFormData)
{
...baseQueryObject,
is_timeseries: true,
post_processing: [
- {
- operation: 'sort',
- options: {
- columns: {
- [DTTM_ALIAS]: true,
- },
- },
- },
+ sortOperator(formData, baseQueryObject),
resampleProc,
rollingProc,
+ flattenOperator(formData, baseQueryObject),
],
},
];
diff --git
a/superset-frontend/plugins/plugin-chart-echarts/src/MixedTimeseries/buildQuery.ts
b/superset-frontend/plugins/plugin-chart-echarts/src/MixedTimeseries/buildQuery.ts
index bc6fb6a..b85feb1 100644
---
a/superset-frontend/plugins/plugin-chart-echarts/src/MixedTimeseries/buildQuery.ts
+++
b/superset-frontend/plugins/plugin-chart-echarts/src/MixedTimeseries/buildQuery.ts
@@ -22,7 +22,7 @@ import {
QueryObject,
normalizeOrderBy,
} from '@superset-ui/core';
-import { pivotOperator } from '@superset-ui/chart-controls';
+import { flattenOperator, pivotOperator } from '@superset-ui/chart-controls';
export default function buildQuery(formData: QueryFormData) {
const {
@@ -66,6 +66,7 @@ export default function buildQuery(formData: QueryFormData) {
is_timeseries: true,
post_processing: [
pivotOperator(formData1, { ...baseQueryObject, is_timeseries: true }),
+ flattenOperator(formData1, { ...baseQueryObject, is_timeseries: true
}),
],
} as QueryObject;
return [normalizeOrderBy(queryObjectA)];
@@ -77,6 +78,7 @@ export default function buildQuery(formData: QueryFormData) {
is_timeseries: true,
post_processing: [
pivotOperator(formData2, { ...baseQueryObject, is_timeseries: true }),
+ flattenOperator(formData2, { ...baseQueryObject, is_timeseries: true
}),
],
} as QueryObject;
return [normalizeOrderBy(queryObjectB)];
diff --git
a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/buildQuery.ts
b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/buildQuery.ts
index 1571c1c..c4cdaa9 100644
---
a/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/buildQuery.ts
+++
b/superset-frontend/plugins/plugin-chart-echarts/src/Timeseries/buildQuery.ts
@@ -22,42 +22,54 @@ import {
ensureIsArray,
QueryFormData,
normalizeOrderBy,
- RollingType,
PostProcessingPivot,
} from '@superset-ui/core';
import {
rollingWindowOperator,
timeCompareOperator,
isValidTimeCompare,
- sortOperator,
pivotOperator,
resampleOperator,
contributionOperator,
prophetOperator,
+ timeComparePivotOperator,
+ flattenOperator,
} from '@superset-ui/chart-controls';
export default function buildQuery(formData: QueryFormData) {
const { x_axis, groupby } = formData;
const is_timeseries = x_axis === DTTM_ALIAS || !x_axis;
return buildQueryContext(formData, baseQueryObject => {
- const pivotOperatorInRuntime: PostProcessingPivot | undefined =
- pivotOperator(formData, {
- ...baseQueryObject,
- index: x_axis,
- is_timeseries,
- });
- if (
- pivotOperatorInRuntime &&
- Object.values(RollingType).includes(formData.rolling_type)
- ) {
- pivotOperatorInRuntime.options = {
- ...pivotOperatorInRuntime.options,
- ...{
- flatten_columns: false,
- reset_index: false,
- },
- };
- }
+ /* the `pivotOperatorInRuntime` determines how to pivot the dataframe
returned from the raw query.
+ 1. If it's a time-compared query, a pivoted dataframe that appends
time-compared metrics will be returned. For instance:
+
+ MAX(value) MAX(value)__1 year ago MIN(value)
MIN(value)__1 year ago
+ city LA LA LA
LA
+ __timestamp
+ 2015-01-01 568.0 671.0 5.0
6.0
+ 2015-02-01 407.0 649.0 4.0
3.0
+ 2015-03-01 318.0 465.0 0.0
3.0
+
+ 2. If it's a normal query, a pivoted dataframe will be returned.
+
+ MAX(value) MIN(value)
+ city LA LA
+ __timestamp
+ 2015-01-01 568.0 5.0
+ 2015-02-01 407.0 4.0
+ 2015-03-01 318.0 0.0
+
+ */
+ const pivotOperatorInRuntime: PostProcessingPivot = isValidTimeCompare(
+ formData,
+ baseQueryObject,
+ )
+ ? timeComparePivotOperator(formData, baseQueryObject)
+ : pivotOperator(formData, {
+ ...baseQueryObject,
+ index: x_axis,
+ is_timeseries,
+ });
return [
{
@@ -70,13 +82,16 @@ export default function buildQuery(formData: QueryFormData)
{
time_offsets: isValidTimeCompare(formData, baseQueryObject)
? formData.time_compare
: [],
+ /* Note that:
+ 1. The resample, rolling, cum, and timeCompare operators should come after
pivot.
+ 2. The flattenOperator converts a MultiIndex DataFrame into a flat DataFrame.
+ */
post_processing: [
- resampleOperator(formData, baseQueryObject),
- timeCompareOperator(formData, baseQueryObject),
- sortOperator(formData, { ...baseQueryObject, is_timeseries: true }),
- // in order to be able to rolling in multiple series, must do pivot
before rollingOperator
pivotOperatorInRuntime,
rollingWindowOperator(formData, baseQueryObject),
+ timeCompareOperator(formData, baseQueryObject),
+ resampleOperator(formData, baseQueryObject),
+ flattenOperator(formData, baseQueryObject),
contributionOperator(formData, baseQueryObject),
prophetOperator(formData, baseQueryObject),
],
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 02e80ab..27ea1659 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -767,6 +767,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
"diff",
"compare",
"resample",
+ "flatten",
)
),
example="aggregate",
diff --git a/superset/common/query_context_processor.py
b/superset/common/query_context_processor.py
index 5772aec..7954f86 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -36,7 +36,11 @@ from superset.common.utils import dataframe_utils as df_utils
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.connectors.base.models import BaseDatasource
from superset.constants import CacheRegion
-from superset.exceptions import QueryObjectValidationError, SupersetException
+from superset.exceptions import (
+ InvalidPostProcessingError,
+ QueryObjectValidationError,
+ SupersetException,
+)
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.utils import csv
@@ -196,7 +200,11 @@ class QueryContextProcessor:
query += ";\n\n".join(queries)
query += ";\n\n"
- df = query_object.exec_post_processing(df)
+ # Re-raising QueryObjectValidationError
+ try:
+ df = query_object.exec_post_processing(df)
+ except InvalidPostProcessingError as ex:
+ raise QueryObjectValidationError from ex
result.df = df
result.query = query
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 139dc27..b2e6fe1 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name
from __future__ import annotations
+import json
import logging
from datetime import datetime, timedelta
from pprint import pformat
@@ -27,6 +28,7 @@ from pandas import DataFrame
from superset.common.chart_data import ChartDataResultType
from superset.exceptions import (
+ InvalidPostProcessingError,
QueryClauseValidationException,
QueryObjectValidationError,
)
@@ -337,6 +339,10 @@ class QueryObject: # pylint:
disable=too-many-instance-attributes
}
return query_object_dict
+ def __repr__(self) -> str:
+ # we use `print` or `logging` to output QueryObject
+ return json.dumps(self.to_dict(), sort_keys=True, default=str,)
+
def cache_key(self, **extra: Any) -> str:
"""
The cache key is made out of the key/values from to_dict(), plus any
@@ -398,15 +404,15 @@ class QueryObject: # pylint:
disable=too-many-instance-attributes
:raises QueryObjectValidationError: If the post processing operation
is incorrect
"""
- logger.debug("post_processing: %s", pformat(self.post_processing))
+ logger.debug("post_processing: \n %s", pformat(self.post_processing))
for post_process in self.post_processing:
operation = post_process.get("operation")
if not operation:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("`operation` property of post processing object
undefined")
)
if not hasattr(pandas_postprocessing, operation):
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_(
"Unsupported post processing operation: %(operation)s",
type=operation,
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 6ed3a0e..6b25904 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -190,6 +190,10 @@ class QueryObjectValidationError(SupersetException):
status = 400
+class InvalidPostProcessingError(SupersetException):
+ status = 400
+
+
class CacheLoadError(SupersetException):
status = 404
diff --git a/superset/utils/pandas_postprocessing/__init__.py
b/superset/utils/pandas_postprocessing/__init__.py
index 976e629..3d180bc 100644
--- a/superset/utils/pandas_postprocessing/__init__.py
+++ b/superset/utils/pandas_postprocessing/__init__.py
@@ -20,6 +20,7 @@ from superset.utils.pandas_postprocessing.compare import
compare
from superset.utils.pandas_postprocessing.contribution import contribution
from superset.utils.pandas_postprocessing.cum import cum
from superset.utils.pandas_postprocessing.diff import diff
+from superset.utils.pandas_postprocessing.flatten import flatten
from superset.utils.pandas_postprocessing.geography import (
geodetic_parse,
geohash_decode,
@@ -49,5 +50,6 @@ __all__ = [
"rolling",
"select",
"sort",
+ "flatten",
"_flatten_column_after_pivot",
]
diff --git a/superset/utils/pandas_postprocessing/aggregate.py
b/superset/utils/pandas_postprocessing/aggregate.py
index 2d6d396..a863d26 100644
--- a/superset/utils/pandas_postprocessing/aggregate.py
+++ b/superset/utils/pandas_postprocessing/aggregate.py
@@ -35,7 +35,7 @@ def aggregate(
:param groupby: columns to aggregate
:param aggregates: A mapping from metric column to the function used to
aggregate values.
- :raises QueryObjectValidationError: If the request in incorrect
+ :raises InvalidPostProcessingError: If the request is incorrect
"""
aggregates = aggregates or {}
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
diff --git a/superset/utils/pandas_postprocessing/boxplot.py
b/superset/utils/pandas_postprocessing/boxplot.py
index 9887507..4436af9 100644
--- a/superset/utils/pandas_postprocessing/boxplot.py
+++ b/superset/utils/pandas_postprocessing/boxplot.py
@@ -20,7 +20,7 @@ import numpy as np
from flask_babel import gettext as _
from pandas import DataFrame, Series
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingBoxplotWhiskerType
from superset.utils.pandas_postprocessing.aggregate import aggregate
@@ -84,7 +84,7 @@ def boxplot(
or not isinstance(percentiles[1], (int, float))
or percentiles[0] >= percentiles[1]
):
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_(
"percentiles must be a list or tuple with two numeric
values, "
"of which the first is lower than the second value"
diff --git a/superset/utils/pandas_postprocessing/compare.py
b/superset/utils/pandas_postprocessing/compare.py
index 67f275e..18a66ce 100644
--- a/superset/utils/pandas_postprocessing/compare.py
+++ b/superset/utils/pandas_postprocessing/compare.py
@@ -21,7 +21,7 @@ from flask_babel import gettext as _
from pandas import DataFrame
from superset.constants import PandasPostprocessingCompare
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import TIME_COMPARISION
from superset.utils.pandas_postprocessing.utils import validate_column_args
@@ -31,7 +31,7 @@ def compare( # pylint: disable=too-many-arguments
df: DataFrame,
source_columns: List[str],
compare_columns: List[str],
- compare_type: Optional[PandasPostprocessingCompare],
+ compare_type: PandasPostprocessingCompare,
drop_original_columns: Optional[bool] = False,
precision: Optional[int] = 4,
) -> DataFrame:
@@ -46,31 +46,38 @@ def compare( # pylint: disable=too-many-arguments
compare columns.
:param precision: Round a change rate to a variable number of decimal
places.
:return: DataFrame with compared columns.
- :raises QueryObjectValidationError: If the request in incorrect.
+ :raises InvalidPostProcessingError: If the request is incorrect.
"""
if len(source_columns) != len(compare_columns):
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("`compare_columns` must have the same length as
`source_columns`.")
)
if compare_type not in tuple(PandasPostprocessingCompare):
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("`compare_type` must be `difference`, `percentage` or `ratio`")
)
if len(source_columns) == 0:
return df
for s_col, c_col in zip(source_columns, compare_columns):
+ s_df = df.loc[:, [s_col]]
+ s_df.rename(columns={s_col: "__intermediate"}, inplace=True)
+ c_df = df.loc[:, [c_col]]
+ c_df.rename(columns={c_col: "__intermediate"}, inplace=True)
if compare_type == PandasPostprocessingCompare.DIFF:
- diff_series = df[s_col] - df[c_col]
+ diff_df = c_df - s_df
elif compare_type == PandasPostprocessingCompare.PCT:
- diff_series = (
- ((df[s_col] - df[c_col]) /
df[c_col]).astype(float).round(precision)
- )
+ #
https://en.wikipedia.org/wiki/Relative_change_and_difference#Percentage_change
+ diff_df = ((c_df - s_df) / s_df).astype(float).round(precision)
else:
# compare_type == "ratio"
- diff_series = (df[s_col] /
df[c_col]).astype(float).round(precision)
- diff_df = diff_series.to_frame(
- name=TIME_COMPARISION.join([compare_type, s_col, c_col])
+ diff_df = (c_df / s_df).astype(float).round(precision)
+
+ diff_df.rename(
+ columns={
+ "__intermediate": TIME_COMPARISION.join([compare_type, s_col,
c_col])
+ },
+ inplace=True,
)
df = pd.concat([df, diff_df], axis=1)
diff --git a/superset/utils/pandas_postprocessing/contribution.py
b/superset/utils/pandas_postprocessing/contribution.py
index 7097ea3..2bfc6f4 100644
--- a/superset/utils/pandas_postprocessing/contribution.py
+++ b/superset/utils/pandas_postprocessing/contribution.py
@@ -20,7 +20,7 @@ from typing import List, Optional
from flask_babel import gettext as _
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing.utils import validate_column_args
@@ -55,7 +55,7 @@ def contribution(
numeric_columns = numeric_df.columns.tolist()
for col in columns:
if col not in numeric_columns:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_(
'Column "%(column)s" is not numeric or does not '
"exists in the query results.",
@@ -65,7 +65,7 @@ def contribution(
columns = columns or numeric_df.columns
rename_columns = rename_columns or columns
if len(rename_columns) != len(columns):
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("`rename_columns` must have the same length as `columns`.")
)
# limit to selected columns
diff --git a/superset/utils/pandas_postprocessing/cum.py
b/superset/utils/pandas_postprocessing/cum.py
index c142b36..d2bd576 100644
--- a/superset/utils/pandas_postprocessing/cum.py
+++ b/superset/utils/pandas_postprocessing/cum.py
@@ -14,27 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional
+from typing import Dict
from flask_babel import gettext as _
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_append_columns,
- _flatten_column_after_pivot,
ALLOWLIST_CUMULATIVE_FUNCTIONS,
validate_column_args,
)
@validate_column_args("columns")
-def cum(
- df: DataFrame,
- operator: str,
- columns: Optional[Dict[str, str]] = None,
- is_pivot_df: bool = False,
-) -> DataFrame:
+def cum(df: DataFrame, operator: str, columns: Dict[str, str],) -> DataFrame:
"""
Calculate cumulative sum/product/min/max for select columns.
@@ -45,29 +39,16 @@ def cum(
`y2` based on cumulative values calculated from `y`, leaving the
original
column `y` unchanged.
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
- :param is_pivot_df: Dataframe is pivoted or not
:return: DataFrame with cumulated columns
"""
columns = columns or {}
- if is_pivot_df:
- df_cum = df
- else:
- df_cum = df[columns.keys()]
+ df_cum = df.loc[:, columns.keys()]
operation = "cum" + operator
if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr(
df_cum, operation
):
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
- if is_pivot_df:
- df_cum = getattr(df_cum, operation)()
- agg_in_pivot_df =
df.columns.get_level_values(0).drop_duplicates().to_list()
- agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
- df_cum.columns = [
- _flatten_column_after_pivot(col, agg) for col in df_cum.columns
- ]
- df_cum.reset_index(level=0, inplace=True)
- else:
- df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
+ df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
return df_cum
diff --git a/superset/utils/pandas_postprocessing/diff.py
b/superset/utils/pandas_postprocessing/diff.py
index fd7c83b..0cead2d 100644
--- a/superset/utils/pandas_postprocessing/diff.py
+++ b/superset/utils/pandas_postprocessing/diff.py
@@ -44,7 +44,7 @@ def diff(
:param periods: periods to shift for calculating difference.
:param axis: 0 for row, 1 for column. default 0.
:return: DataFrame with diffed columns
- :raises QueryObjectValidationError: If the request in incorrect
+ :raises InvalidPostProcessingError: If the request is incorrect
"""
df_diff = df[columns.keys()]
df_diff = df_diff.diff(periods=periods, axis=axis)
diff --git a/superset/utils/pandas_postprocessing/flatten.py
b/superset/utils/pandas_postprocessing/flatten.py
new file mode 100644
index 0000000..a348801
--- /dev/null
+++ b/superset/utils/pandas_postprocessing/flatten.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pandas as pd
+
+from superset.utils.pandas_postprocessing.utils import (
+ _is_multi_index_on_columns,
+ FLAT_COLUMN_SEPARATOR,
+)
+
+
+def flatten(df: pd.DataFrame, reset_index: bool = True,) -> pd.DataFrame:
+ """
+ Convert N-dimensional DataFrame to a flat DataFrame
+
+ :param df: N-dimensional DataFrame.
+ :param reset_index: Convert index to column when df.index isn't RangeIndex
+ :return: a flat DataFrame
+
+ Examples
+ -----------
+
+ Convert DatetimeIndex into columns.
+
+ >>> index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03",])
+ >>> index.name = "__timestamp"
+ >>> df = pd.DataFrame(index=index, data={"metric": [1, 2, 3]})
+ >>> df
+ metric
+ __timestamp
+ 2021-01-01 1
+ 2021-01-02 2
+ 2021-01-03 3
+ >>> df = flatten(df)
+ >>> df
+ __timestamp metric
+ 0 2021-01-01 1
+ 1 2021-01-02 2
+ 2 2021-01-03 3
+
+ Convert DatetimeIndex and MultiIndex into columns
+
+ >>> iterables = [["foo", "bar"], ["one", "two"]]
+ >>> columns = pd.MultiIndex.from_product(iterables, names=["level1",
"level2"])
+ >>> df = pd.DataFrame(index=index, columns=columns, data=1)
+ >>> df
+ level1 foo bar
+ level2 one two one two
+ __timestamp
+ 2021-01-01 1 1 1 1
+ 2021-01-02 1 1 1 1
+ 2021-01-03 1 1 1 1
+ >>> flatten(df)
+ __timestamp foo, one foo, two bar, one bar, two
+ 0 2021-01-01 1 1 1 1
+ 1 2021-01-02 1 1 1 1
+ 2 2021-01-03 1 1 1 1
+ """
+ if _is_multi_index_on_columns(df):
+ # every cell should be converted to string
+ df.columns = [
+ FLAT_COLUMN_SEPARATOR.join([str(cell) for cell in series])
+ for series in df.columns.to_flat_index()
+ ]
+
+ if reset_index and not isinstance(df.index, pd.RangeIndex):
+ df = df.reset_index(level=0)
+ return df
diff --git a/superset/utils/pandas_postprocessing/geography.py
b/superset/utils/pandas_postprocessing/geography.py
index a1aae59..8ea75d2 100644
--- a/superset/utils/pandas_postprocessing/geography.py
+++ b/superset/utils/pandas_postprocessing/geography.py
@@ -21,7 +21,7 @@ from flask_babel import gettext as _
from geopy.point import Point
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import _append_columns
@@ -46,7 +46,7 @@ def geohash_decode(
df, lonlat_df, {"latitude": latitude, "longitude": longitude}
)
except ValueError as ex:
- raise QueryObjectValidationError(_("Invalid geohash string")) from ex
+ raise InvalidPostProcessingError(_("Invalid geohash string")) from ex
def geohash_encode(
@@ -69,7 +69,7 @@ def geohash_encode(
)
return _append_columns(df, encode_df, {"geohash": geohash})
except ValueError as ex:
- raise QueryObjectValidationError(_("Invalid longitude/latitude")) from
ex
+ raise InvalidPostProcessingError(_("Invalid longitude/latitude")) from
ex
def geodetic_parse(
@@ -111,4 +111,4 @@ def geodetic_parse(
columns["altitude"] = altitude
return _append_columns(df, geodetic_df, columns)
except ValueError as ex:
- raise QueryObjectValidationError(_("Invalid geodetic string")) from ex
+ raise InvalidPostProcessingError(_("Invalid geodetic string")) from ex
diff --git a/superset/utils/pandas_postprocessing/pivot.py
b/superset/utils/pandas_postprocessing/pivot.py
index b9d70e9..829329e 100644
--- a/superset/utils/pandas_postprocessing/pivot.py
+++ b/superset/utils/pandas_postprocessing/pivot.py
@@ -20,7 +20,7 @@ from flask_babel import gettext as _
from pandas import DataFrame
from superset.constants import NULL_STRING, PandasAxis
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_flatten_column_after_pivot,
_get_aggregate_funcs,
@@ -64,14 +64,14 @@ def pivot( # pylint:
disable=too-many-arguments,too-many-locals
:param flatten_columns: Convert column names to strings
:param reset_index: Convert index to column
:return: A pivot table
- :raises QueryObjectValidationError: If the request in incorrect
+ :raises InvalidPostProcessingError: If the request is incorrect
"""
if not index:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("Pivot operation requires at least one index")
)
if not aggregates:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("Pivot operation must include at least one aggregate")
)
diff --git a/superset/utils/pandas_postprocessing/prophet.py
b/superset/utils/pandas_postprocessing/prophet.py
index 3ade7f6..8a85e58 100644
--- a/superset/utils/pandas_postprocessing/prophet.py
+++ b/superset/utils/pandas_postprocessing/prophet.py
@@ -20,7 +20,7 @@ from typing import Optional, Union
from flask_babel import gettext as _
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS
from superset.utils.pandas_postprocessing.utils import PROPHET_TIME_GRAIN_MAP
@@ -58,7 +58,7 @@ def _prophet_fit_and_predict( # pylint:
disable=too-many-arguments
prophet_logger.setLevel(logging.CRITICAL)
prophet_logger.setLevel(logging.NOTSET)
except ModuleNotFoundError as ex:
- raise QueryObjectValidationError(_("`prophet` package not installed"))
from ex
+ raise InvalidPostProcessingError(_("`prophet` package not installed"))
from ex
model = Prophet(
interval_width=confidence_interval,
yearly_seasonality=yearly_seasonality,
@@ -111,24 +111,24 @@ def prophet( # pylint: disable=too-many-arguments
index = index or DTTM_ALIAS
# validate inputs
if not time_grain:
- raise QueryObjectValidationError(_("Time grain missing"))
+ raise InvalidPostProcessingError(_("Time grain missing"))
if time_grain not in PROPHET_TIME_GRAIN_MAP:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
)
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
# check type at runtime due to marhsmallow schema not being able to handle
# union types
if not isinstance(periods, int) or periods < 0:
- raise QueryObjectValidationError(_("Periods must be a whole number"))
+ raise InvalidPostProcessingError(_("Periods must be a whole number"))
if not confidence_interval or confidence_interval <= 0 or
confidence_interval >= 1:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("Confidence interval must be between 0 and 1 (exclusive)")
)
if index not in df.columns:
- raise QueryObjectValidationError(_("DataFrame must include temporal
column"))
+ raise InvalidPostProcessingError(_("DataFrame must include temporal
column"))
if len(df.columns) < 2:
- raise QueryObjectValidationError(_("DataFrame include at least one
series"))
+ raise InvalidPostProcessingError(_("DataFrame include at least one
series"))
target_df = DataFrame()
for column in [column for column in df.columns if column != index]:
diff --git a/superset/utils/pandas_postprocessing/resample.py
b/superset/utils/pandas_postprocessing/resample.py
index 54e67ac..a777672 100644
--- a/superset/utils/pandas_postprocessing/resample.py
+++ b/superset/utils/pandas_postprocessing/resample.py
@@ -14,48 +14,35 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
-from pandas import DataFrame
+import pandas as pd
+from flask_babel import gettext as _
-from superset.utils.pandas_postprocessing.utils import validate_column_args
+from superset.exceptions import InvalidPostProcessingError
-@validate_column_args("groupby_columns")
-def resample( # pylint: disable=too-many-arguments
- df: DataFrame,
+def resample(
+ df: pd.DataFrame,
rule: str,
method: str,
- time_column: str,
- groupby_columns: Optional[Tuple[Optional[str], ...]] = None,
fill_value: Optional[Union[float, int]] = None,
-) -> DataFrame:
+) -> pd.DataFrame:
"""
support upsampling in resample
:param df: DataFrame to resample.
:param rule: The offset string representing target conversion.
:param method: How to fill the NaN value after resample.
- :param time_column: existing columns in DataFrame.
- :param groupby_columns: columns except time_column in dataframe
:param fill_value: What values do fill missing.
:return: DataFrame after resample
- :raises QueryObjectValidationError: If the request in incorrect
+ :raises InvalidPostProcessingError: If the request is incorrect
"""
+ if not isinstance(df.index, pd.DatetimeIndex):
+ raise InvalidPostProcessingError(_("Resample operation requires
DatetimeIndex"))
- def _upsampling(_df: DataFrame) -> DataFrame:
- _df = _df.set_index(time_column)
- if method == "asfreq" and fill_value is not None:
- return _df.resample(rule).asfreq(fill_value=fill_value)
- return getattr(_df.resample(rule), method)()
-
- if groupby_columns:
- df = (
- df.set_index(keys=list(groupby_columns))
- .groupby(by=list(groupby_columns))
- .apply(_upsampling)
- )
- df = df.reset_index().set_index(time_column).sort_index()
+ if method == "asfreq" and fill_value is not None:
+ _df = df.resample(rule).asfreq(fill_value=fill_value)
else:
- df = _upsampling(df)
- return df.reset_index()
+ _df = getattr(df.resample(rule), method)()
+ return _df
diff --git a/superset/utils/pandas_postprocessing/rolling.py
b/superset/utils/pandas_postprocessing/rolling.py
index f93b3da..885032e 100644
--- a/superset/utils/pandas_postprocessing/rolling.py
+++ b/superset/utils/pandas_postprocessing/rolling.py
@@ -19,10 +19,9 @@ from typing import Any, Dict, Optional, Union
from flask_babel import gettext as _
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_append_columns,
- _flatten_column_after_pivot,
DENYLIST_ROLLING_FUNCTIONS,
validate_column_args,
)
@@ -32,13 +31,12 @@ from superset.utils.pandas_postprocessing.utils import (
def rolling( # pylint: disable=too-many-arguments
df: DataFrame,
rolling_type: str,
- columns: Optional[Dict[str, str]] = None,
+ columns: Dict[str, str],
window: Optional[int] = None,
rolling_type_options: Optional[Dict[str, Any]] = None,
center: bool = False,
win_type: Optional[str] = None,
min_periods: Optional[int] = None,
- is_pivot_df: bool = False,
) -> DataFrame:
"""
Apply a rolling window on the dataset. See the Pandas docs for further
details:
@@ -58,21 +56,17 @@ def rolling( # pylint: disable=too-many-arguments
:param win_type: Type of window function.
:param min_periods: The minimum amount of periods required for a row to be
included
in the result set.
- :param is_pivot_df: Dataframe is pivoted or not
:return: DataFrame with the rolling columns
- :raises QueryObjectValidationError: If the request in incorrect
+ :raises InvalidPostProcessingError: If the request is incorrect
"""
rolling_type_options = rolling_type_options or {}
- columns = columns or {}
- if is_pivot_df:
- df_rolling = df
- else:
- df_rolling = df[columns.keys()]
+ df_rolling = df.loc[:, columns.keys()]
+
kwargs: Dict[str, Union[str, int]] = {}
if window is None:
- raise QueryObjectValidationError(_("Undefined window for rolling
operation"))
+ raise InvalidPostProcessingError(_("Undefined window for rolling
operation"))
if window == 0:
- raise QueryObjectValidationError(_("Window must be > 0"))
+ raise InvalidPostProcessingError(_("Window must be > 0"))
kwargs["window"] = window
if min_periods is not None:
@@ -86,13 +80,13 @@ def rolling( # pylint: disable=too-many-arguments
if rolling_type not in DENYLIST_ROLLING_FUNCTIONS or not hasattr(
df_rolling, rolling_type
):
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("Invalid rolling_type: %(type)s", type=rolling_type)
)
try:
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
except TypeError as ex:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_(
"Invalid options for %(rolling_type)s: %(options)s",
rolling_type=rolling_type,
@@ -100,15 +94,7 @@ def rolling( # pylint: disable=too-many-arguments
)
) from ex
- if is_pivot_df:
- agg_in_pivot_df =
df.columns.get_level_values(0).drop_duplicates().to_list()
- agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
- df_rolling.columns = [
- _flatten_column_after_pivot(col, agg) for col in df_rolling.columns
- ]
- df_rolling.reset_index(level=0, inplace=True)
- else:
- df_rolling = _append_columns(df, df_rolling, columns)
+ df_rolling = _append_columns(df, df_rolling, columns)
if min_periods:
df_rolling = df_rolling[min_periods:]
diff --git a/superset/utils/pandas_postprocessing/select.py
b/superset/utils/pandas_postprocessing/select.py
index 209d502..59fe886 100644
--- a/superset/utils/pandas_postprocessing/select.py
+++ b/superset/utils/pandas_postprocessing/select.py
@@ -42,7 +42,7 @@ def select(
For instance, `{'y': 'y2'}` will rename the column `y` to
`y2`.
:return: Subset of columns in original DataFrame
- :raises QueryObjectValidationError: If the request in incorrect
+ :raises InvalidPostProcessingError: If the request is incorrect
"""
df_select = df.copy(deep=False)
if columns:
diff --git a/superset/utils/pandas_postprocessing/sort.py
b/superset/utils/pandas_postprocessing/sort.py
index fdf8f94..feacfb6 100644
--- a/superset/utils/pandas_postprocessing/sort.py
+++ b/superset/utils/pandas_postprocessing/sort.py
@@ -30,6 +30,6 @@ def sort(df: DataFrame, columns: Dict[str, bool]) ->
DataFrame:
:param columns: columns by which to sort. The key specifies the column
name,
value specifies if sorting in ascending order.
:return: Sorted DataFrame
- :raises QueryObjectValidationError: If the request in incorrect
+ :raises InvalidPostProcessingError: If the request is incorrect
"""
return df.sort_values(by=list(columns.keys()),
ascending=list(columns.values()))
diff --git a/superset/utils/pandas_postprocessing/utils.py
b/superset/utils/pandas_postprocessing/utils.py
index 7d26994..7aebe1e 100644
--- a/superset/utils/pandas_postprocessing/utils.py
+++ b/superset/utils/pandas_postprocessing/utils.py
@@ -18,10 +18,11 @@ from functools import partial
from typing import Any, Callable, Dict, Tuple, Union
import numpy as np
+import pandas as pd
from flask_babel import gettext as _
from pandas import DataFrame, NamedAgg, Timestamp
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
NUMPY_FUNCTIONS = {
"average": np.average,
@@ -91,6 +92,8 @@ PROPHET_TIME_GRAIN_MAP = {
"P1W/1970-01-04T00:00:00Z": "W",
}
+FLAT_COLUMN_SEPARATOR = ", "
+
def _flatten_column_after_pivot(
column: Union[float, Timestamp, str, Tuple[str, ...]],
@@ -113,21 +116,26 @@ def _flatten_column_after_pivot(
# drop aggregate for single aggregate pivots with multiple groupings
# from column name (aggregates always come first in column name)
column = column[1:]
- return ", ".join([str(col) for col in column])
+ return FLAT_COLUMN_SEPARATOR.join([str(col) for col in column])
+
+
+def _is_multi_index_on_columns(df: DataFrame) -> bool:
+ return isinstance(df.columns, pd.MultiIndex)
def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
- if options.get("is_pivot_df"):
- # skip validation when pivot Dataframe
- return func(df, **options)
- columns = df.columns.tolist()
+ if _is_multi_index_on_columns(df):
+ # MultiIndex column validate first level
+ columns = df.columns.get_level_values(0)
+ else:
+ columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
elem in columns for elem in options.get(name) or []
):
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("Referenced columns not available in DataFrame.")
)
return func(df, **options)
@@ -152,14 +160,14 @@ def _get_aggregate_funcs(
for name, agg_obj in aggregates.items():
column = agg_obj.get("column", name)
if column not in df:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_(
"Column referenced by aggregate is undefined: %(column)s",
column=column,
)
)
if "operator" not in agg_obj:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("Operator undefined for aggregator: %(name)s", name=name,)
)
operator = agg_obj["operator"]
@@ -168,7 +176,7 @@ def _get_aggregate_funcs(
else:
func = NUMPY_FUNCTIONS.get(operator)
if not func:
- raise QueryObjectValidationError(
+ raise InvalidPostProcessingError(
_("Invalid numpy function: %(operator)s",
operator=operator,)
)
options = agg_obj.get("options", {})
@@ -186,6 +194,8 @@ def _append_columns(
assign method, which overwrites the original column in `base_df` if the
column
already exists, and appends the column if the name is not defined.
+ Note that! this is a memory-intensive operation.
+
:param base_df: DataFrame which to use as the base
:param append_df: DataFrame from which to select data.
:param columns: columns on which to append, mapping source column to
@@ -196,6 +206,10 @@ def _append_columns(
in `base_df` unchanged.
:return: new DataFrame with combined data from `base_df` and `append_df`
"""
- return base_df.assign(
- **{target: append_df[source] for source, target in columns.items()}
- )
+ if all(key == value for key, value in columns.items()):
+ # make sure to return a new DataFrame instead of changing the
`base_df`.
+ _base_df = base_df.copy()
+ _base_df.loc[:, columns.keys()] = append_df
+ return _base_df
+ append_df = append_df.rename(columns=columns)
+ return pd.concat([base_df, append_df], axis="columns")
diff --git a/tests/common/query_context_generator.py
b/tests/common/query_context_generator.py
index 1f87c0c..d97b270 100644
--- a/tests/common/query_context_generator.py
+++ b/tests/common/query_context_generator.py
@@ -172,18 +172,21 @@ POSTPROCESSING_OPERATIONS = {
{
"operation": "aggregate",
"options": {
- "groupby": ["gender"],
+ "groupby": ["name"],
"aggregates": {
"q1": {
"operator": "percentile",
"column": "sum__num",
- "options": {"q": 25},
+ # todo: rename "interpolation" to "method" when we
updated
+ # numpy.
+ #
https://numpy.org/doc/stable/reference/generated/numpy.percentile.html
+ "options": {"q": 25, "interpolation": "lower"},
},
"median": {"operator": "median", "column": "sum__num",},
},
},
},
- {"operation": "sort", "options": {"columns": {"q1": False, "gender":
True},},},
+ {"operation": "sort", "options": {"columns": {"q1": False, "name":
True},},},
]
}
diff --git a/tests/integration_tests/query_context_tests.py
b/tests/integration_tests/query_context_tests.py
index b2c9b98..b2f2818 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import datetime
import re
import time
from typing import Any, Dict
@@ -30,7 +29,7 @@ from superset.common.query_object import QueryObject
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlMetric
from superset.extensions import cache_manager
-from superset.utils.core import AdhocMetricExpressionType, backend
+from superset.utils.core import AdhocMetricExpressionType, backend, QueryStatus
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@@ -91,8 +90,9 @@ class TestQueryContext(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_cache(self):
table_name = "birth_names"
- table = self.get_table(name=table_name)
- payload = get_query_context(table_name, table.id)
+ payload = get_query_context(
+ query_name=table_name, add_postprocessing_operations=True,
+ )
payload["force"] = True
query_context = ChartDataQueryContextSchema().load(payload)
@@ -100,6 +100,10 @@ class TestQueryContext(SupersetTestCase):
query_cache_key = query_context.query_cache_key(query_object)
response = query_context.get_payload(cache_query_context=True)
+ # MUST BE a successful query
+ query_dump = response["queries"][0]
+ assert query_dump["status"] == QueryStatus.SUCCESS
+
cache_key = response["cache_key"]
assert cache_key is not None
diff --git a/tests/unit_tests/pandas_postprocessing/test_boxplot.py
b/tests/unit_tests/pandas_postprocessing/test_boxplot.py
index 247aba0..9252b0d 100644
--- a/tests/unit_tests/pandas_postprocessing/test_boxplot.py
+++ b/tests/unit_tests/pandas_postprocessing/test_boxplot.py
@@ -16,7 +16,7 @@
# under the License.
import pytest
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingBoxplotWhiskerType
from superset.utils.pandas_postprocessing import boxplot
from tests.unit_tests.fixtures.dataframes import names_df
@@ -90,7 +90,7 @@ def test_boxplot_percentile():
def test_boxplot_percentile_incorrect_params():
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
@@ -98,7 +98,7 @@ def test_boxplot_percentile_incorrect_params():
metrics=["cars"],
)
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
@@ -107,7 +107,7 @@ def test_boxplot_percentile_incorrect_params():
percentiles=[10],
)
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
@@ -116,7 +116,7 @@ def test_boxplot_percentile_incorrect_params():
percentiles=[90, 10],
)
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
diff --git a/tests/unit_tests/pandas_postprocessing/test_compare.py
b/tests/unit_tests/pandas_postprocessing/test_compare.py
index d9213ca..970fa42 100644
--- a/tests/unit_tests/pandas_postprocessing/test_compare.py
+++ b/tests/unit_tests/pandas_postprocessing/test_compare.py
@@ -14,49 +14,220 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pandas as pd
-from superset.utils.pandas_postprocessing import compare
-from tests.unit_tests.fixtures.dataframes import timeseries_df2
-from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+from superset.constants import PandasPostprocessingCompare as PPC
+from superset.utils import pandas_postprocessing as pp
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
+from tests.unit_tests.fixtures.dataframes import multiple_metrics_df,
timeseries_df2
-def test_compare():
+def test_compare_should_not_side_effect():
+ _timeseries_df2 = timeseries_df2.copy()
+ pp.compare(
+ df=_timeseries_df2,
+ source_columns=["y"],
+ compare_columns=["z"],
+ compare_type=PPC.DIFF,
+ )
+ assert _timeseries_df2.equals(timeseries_df2)
+
+
+def test_compare_diff():
# `difference` comparison
- post_df = compare(
+ post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
- compare_type="difference",
+ compare_type=PPC.DIFF,
+ )
+ """
+ label y z difference__y__z
+ 2019-01-01 x 2.0 2.0 0.0
+ 2019-01-02 y 2.0 4.0 2.0
+ 2019-01-05 z 2.0 10.0 8.0
+ 2019-01-07 q 2.0 8.0 6.0
+ """
+ assert post_df.equals(
+ pd.DataFrame(
+ index=timeseries_df2.index,
+ data={
+ "label": ["x", "y", "z", "q"],
+ "y": [2.0, 2.0, 2.0, 2.0],
+ "z": [2.0, 4.0, 10.0, 8.0],
+ "difference__y__z": [0.0, 2.0, 8.0, 6.0],
+ },
+ )
)
- assert post_df.columns.tolist() == ["label", "y", "z", "difference__y__z"]
- assert series_to_list(post_df["difference__y__z"]) == [0.0, -2.0, -8.0,
-6.0]
# drop original columns
- post_df = compare(
+ post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
- compare_type="difference",
+ compare_type=PPC.DIFF,
drop_original_columns=True,
)
- assert post_df.columns.tolist() == ["label", "difference__y__z"]
+ assert post_df.equals(
+ pd.DataFrame(
+ index=timeseries_df2.index,
+ data={
+ "label": ["x", "y", "z", "q"],
+ "difference__y__z": [0.0, 2.0, 8.0, 6.0],
+ },
+ )
+ )
+
+def test_compare_percentage():
# `percentage` comparison
- post_df = compare(
+ post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
- compare_type="percentage",
+ compare_type=PPC.PCT,
+ )
+ """
+ label y z percentage__y__z
+ 2019-01-01 x 2.0 2.0 0.0
+ 2019-01-02 y 2.0 4.0 1.0
+ 2019-01-05 z 2.0 10.0 4.0
+ 2019-01-07 q 2.0 8.0 3.0
+ """
+ assert post_df.equals(
+ pd.DataFrame(
+ index=timeseries_df2.index,
+ data={
+ "label": ["x", "y", "z", "q"],
+ "y": [2.0, 2.0, 2.0, 2.0],
+ "z": [2.0, 4.0, 10.0, 8.0],
+ "percentage__y__z": [0.0, 1.0, 4.0, 3.0],
+ },
+ )
)
- assert post_df.columns.tolist() == ["label", "y", "z", "percentage__y__z"]
- assert series_to_list(post_df["percentage__y__z"]) == [0.0, -0.5, -0.8,
-0.75]
+
+def test_compare_ratio():
# `ratio` comparison
- post_df = compare(
+ post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
- compare_type="ratio",
+ compare_type=PPC.RAT,
+ )
+ """
+ label y z ratio__y__z
+ 2019-01-01 x 2.0 2.0 1.0
+ 2019-01-02 y 2.0 4.0 2.0
+ 2019-01-05 z 2.0 10.0 5.0
+ 2019-01-07 q 2.0 8.0 4.0
+ """
+ assert post_df.equals(
+ pd.DataFrame(
+ index=timeseries_df2.index,
+ data={
+ "label": ["x", "y", "z", "q"],
+ "y": [2.0, 2.0, 2.0, 2.0],
+ "z": [2.0, 4.0, 10.0, 8.0],
+ "ratio__y__z": [1.0, 2.0, 5.0, 4.0],
+ },
+ )
+ )
+
+
+def test_compare_multi_index_column():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ iterables = [["m1", "m2"], ["a", "b"], ["x", "y"]]
+ columns = pd.MultiIndex.from_product(iterables, names=[None, "level1",
"level2"])
+ df = pd.DataFrame(index=index, columns=columns, data=1)
+ """
+ m1 m2
+ level1 a b a b
+ level2 x y x y x y x y
+ __timestamp
+ 2021-01-01 1 1 1 1 1 1 1 1
+ 2021-01-02 1 1 1 1 1 1 1 1
+ 2021-01-03 1 1 1 1 1 1 1 1
+ """
+ post_df = pp.compare(
+ df,
+ source_columns=["m1"],
+ compare_columns=["m2"],
+ compare_type=PPC.DIFF,
+ drop_original_columns=True,
+ )
+ flat_df = pp.flatten(post_df)
+ """
+ __timestamp difference__m1__m2, a, x difference__m1__m2, a, y
difference__m1__m2, b, x difference__m1__m2, b, y
+ 0 2021-01-01 0 0
0 0
+ 1 2021-01-02 0 0
0 0
+ 2 2021-01-03 0 0
0 0
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "__timestamp": pd.to_datetime(
+ ["2021-01-01", "2021-01-02", "2021-01-03"]
+ ),
+ "difference__m1__m2, a, x": [0, 0, 0],
+ "difference__m1__m2, a, y": [0, 0, 0],
+ "difference__m1__m2, b, x": [0, 0, 0],
+ "difference__m1__m2, b, y": [0, 0, 0],
+ }
+ )
+ )
+
+
+def test_compare_after_pivot():
+ pivot_df = pp.pivot(
+ df=multiple_metrics_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={
+ "sum_metric": {"operator": "sum"},
+ "count_metric": {"operator": "sum"},
+ },
+ flatten_columns=False,
+ reset_index=False,
+ )
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1 2 5 6
+ 2019-01-02 3 4 7 8
+ """
+ compared_df = pp.compare(
+ pivot_df,
+ source_columns=["count_metric"],
+ compare_columns=["sum_metric"],
+ compare_type=PPC.DIFF,
+ drop_original_columns=True,
+ )
+ """
+ difference__count_metric__sum_metric
+ country UK US
+ dttm
+ 2019-01-01 4 4
+ 2019-01-02 4 4
+ """
+ flat_df = pp.flatten(compared_df)
+ """
+ dttm difference__count_metric__sum_metric, UK
difference__count_metric__sum_metric, US
+ 0 2019-01-01 4
4
+ 1 2019-01-02 4
4
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(
+ ["difference__count_metric__sum_metric", "UK"]
+ ): [4, 4],
+ FLAT_COLUMN_SEPARATOR.join(
+ ["difference__count_metric__sum_metric", "US"]
+ ): [4, 4],
+ }
+ )
)
- assert post_df.columns.tolist() == ["label", "y", "z", "ratio__y__z"]
- assert series_to_list(post_df["ratio__y__z"]) == [1.0, 0.5, 0.2, 0.25]
diff --git a/tests/unit_tests/pandas_postprocessing/test_contribution.py
b/tests/unit_tests/pandas_postprocessing/test_contribution.py
index 9d2df76..a385514 100644
--- a/tests/unit_tests/pandas_postprocessing/test_contribution.py
+++ b/tests/unit_tests/pandas_postprocessing/test_contribution.py
@@ -22,7 +22,7 @@ from numpy import nan
from numpy.testing import assert_array_equal
from pandas import DataFrame
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS,
PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing import contribution
@@ -40,10 +40,10 @@ def test_contribution():
"c": [nan, nan, nan],
}
)
- with pytest.raises(QueryObjectValidationError, match="not numeric"):
+ with pytest.raises(InvalidPostProcessingError, match="not numeric"):
contribution(df, columns=[DTTM_ALIAS])
- with pytest.raises(QueryObjectValidationError, match="same length"):
+ with pytest.raises(InvalidPostProcessingError, match="same length"):
contribution(df, columns=["a"], rename_columns=["aa", "bb"])
# cell contribution across row
diff --git a/tests/unit_tests/pandas_postprocessing/test_cum.py
b/tests/unit_tests/pandas_postprocessing/test_cum.py
index b4b8fad..6cc5da2 100644
--- a/tests/unit_tests/pandas_postprocessing/test_cum.py
+++ b/tests/unit_tests/pandas_postprocessing/test_cum.py
@@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pandas as pd
import pytest
-from pandas import to_datetime
-from superset.exceptions import QueryObjectValidationError
-from superset.utils.pandas_postprocessing import cum, pivot
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils import pandas_postprocessing as pp
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.unit_tests.fixtures.dataframes import (
multiple_metrics_df,
single_metric_df,
@@ -27,33 +28,41 @@ from tests.unit_tests.fixtures.dataframes import (
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+def test_cum_should_not_side_effect():
+ _timeseries_df = timeseries_df.copy()
+ pp.cum(
+ df=timeseries_df, columns={"y": "y2"}, operator="sum",
+ )
+ assert _timeseries_df.equals(timeseries_df)
+
+
def test_cum():
# create new column (cumsum)
- post_df = cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
+ post_df = pp.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
assert post_df.columns.tolist() == ["label", "y", "y2"]
assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"]
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0]
# overwrite column (cumprod)
- post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
+ post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
assert post_df.columns.tolist() == ["label", "y"]
assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0]
# overwrite column (cummin)
- post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
+ post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
assert post_df.columns.tolist() == ["label", "y"]
assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0]
# invalid operator
- with pytest.raises(QueryObjectValidationError):
- cum(
+ with pytest.raises(InvalidPostProcessingError):
+ pp.cum(
df=timeseries_df, columns={"y": "y"}, operator="abc",
)
-def test_cum_with_pivot_df_and_single_metric():
- pivot_df = pivot(
+def test_cum_after_pivot_with_single_metric():
+ pivot_df = pp.pivot(
df=single_metric_df,
index=["dttm"],
columns=["country"],
@@ -61,19 +70,40 @@ def test_cum_with_pivot_df_and_single_metric():
flatten_columns=False,
reset_index=False,
)
- cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
- # dttm UK US
- # 0 2019-01-01 5 6
- # 1 2019-01-02 12 14
- assert cum_df["UK"].to_list() == [5.0, 12.0]
- assert cum_df["US"].to_list() == [6.0, 14.0]
- assert (
- cum_df["dttm"].to_list() == to_datetime(["2019-01-01",
"2019-01-02"]).to_list()
+ """
+ sum_metric
+ country UK US
+ dttm
+ 2019-01-01 5 6
+ 2019-01-02 7 8
+ """
+ cum_df = pp.cum(df=pivot_df, operator="sum", columns={"sum_metric":
"sum_metric"})
+ """
+ sum_metric
+ country UK US
+ dttm
+ 2019-01-01 5 6
+ 2019-01-02 12 14
+ """
+ cum_and_flat_df = pp.flatten(cum_df)
+ """
+ dttm sum_metric, UK sum_metric, US
+ 0 2019-01-01 5 6
+ 1 2019-01-02 12 14
+ """
+ assert cum_and_flat_df.equals(
+ pd.DataFrame(
+ {
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
+ }
+ )
)
-def test_cum_with_pivot_df_and_multiple_metrics():
- pivot_df = pivot(
+def test_cum_after_pivot_with_multiple_metrics():
+ pivot_df = pp.pivot(
df=multiple_metrics_df,
index=["dttm"],
columns=["country"],
@@ -84,14 +114,39 @@ def test_cum_with_pivot_df_and_multiple_metrics():
flatten_columns=False,
reset_index=False,
)
- cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
- # dttm count_metric, UK count_metric, US sum_metric, UK
sum_metric, US
- # 0 2019-01-01 1 2 5
6
- # 1 2019-01-02 4 6 12
14
- assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
- assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
- assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
- assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
- assert (
- cum_df["dttm"].to_list() == to_datetime(["2019-01-01",
"2019-01-02"]).to_list()
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1 2 5 6
+ 2019-01-02 3 4 7 8
+ """
+ cum_df = pp.cum(
+ df=pivot_df,
+ operator="sum",
+ columns={"sum_metric": "sum_metric", "count_metric": "count_metric"},
+ )
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1 2 5 6
+ 2019-01-02 4 6 12 14
+ """
+ flat_df = pp.flatten(cum_df)
+ """
+ dttm count_metric, UK count_metric, US sum_metric, UK
sum_metric, US
+ 0 2019-01-01 1 2 5
6
+ 1 2019-01-02 4 6 12
14
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ {
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1, 4],
+ FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2, 6],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
+ }
+ )
)
diff --git a/tests/unit_tests/pandas_postprocessing/test_diff.py
b/tests/unit_tests/pandas_postprocessing/test_diff.py
index abade20..a491d6c 100644
--- a/tests/unit_tests/pandas_postprocessing/test_diff.py
+++ b/tests/unit_tests/pandas_postprocessing/test_diff.py
@@ -16,7 +16,7 @@
# under the License.
import pytest
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import diff
from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
@@ -39,7 +39,7 @@ def test_diff():
assert series_to_list(post_df["y1"]) == [-1.0, -1.0, -1.0, None]
# invalid column reference
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
diff(
df=timeseries_df, columns={"abc": "abc"},
)
diff --git a/tests/unit_tests/pandas_postprocessing/test_flatten.py
b/tests/unit_tests/pandas_postprocessing/test_flatten.py
new file mode 100644
index 0000000..01a180b
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_flatten.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pandas as pd
+
+from superset.utils import pandas_postprocessing as pp
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
+
+
+def test_flat_should_not_change():
+ df = pd.DataFrame(data={"foo": [1, 2, 3], "bar": [4, 5, 6],})
+
+ assert pp.flatten(df).equals(df)
+
+
+def test_flat_should_not_reset_index():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
+
+ assert pp.flatten(df, reset_index=False).equals(df)
+
+
+def test_flat_should_flat_datetime_index():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
+
+ assert pp.flatten(df).equals(
+ pd.DataFrame({"__timestamp": index, "foo": [1, 2, 3], "bar": [4, 5,
6],})
+ )
+
+
+def test_flat_should_flat_multiple_index():
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ iterables = [["foo", "bar"], [1, "two"]]
+ columns = pd.MultiIndex.from_product(iterables, names=["level1", "level2"])
+ df = pd.DataFrame(index=index, columns=columns, data=1)
+
+ assert pp.flatten(df).equals(
+ pd.DataFrame(
+ {
+ "__timestamp": index,
+ FLAT_COLUMN_SEPARATOR.join(["foo", "1"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["foo", "two"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["bar", "1"]): [1, 1, 1],
+ FLAT_COLUMN_SEPARATOR.join(["bar", "two"]): [1, 1, 1],
+ }
+ )
+ )
diff --git a/tests/unit_tests/pandas_postprocessing/test_pivot.py
b/tests/unit_tests/pandas_postprocessing/test_pivot.py
index 55779e3..e775df4 100644
--- a/tests/unit_tests/pandas_postprocessing/test_pivot.py
+++ b/tests/unit_tests/pandas_postprocessing/test_pivot.py
@@ -19,7 +19,7 @@ import numpy as np
import pytest
from pandas import DataFrame, Timestamp, to_datetime
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import _flatten_column_after_pivot,
pivot
from tests.unit_tests.fixtures.dataframes import categories_df,
single_metric_df
from tests.unit_tests.pandas_postprocessing.utils import (
@@ -172,7 +172,7 @@ def test_pivot_exceptions():
pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE)
# invalid index reference
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
pivot(
df=categories_df,
index=["abc"],
@@ -181,7 +181,7 @@ def test_pivot_exceptions():
)
# invalid column reference
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
pivot(
df=categories_df,
index=["dept"],
@@ -190,7 +190,7 @@ def test_pivot_exceptions():
)
# invalid aggregate options
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
pivot(
df=categories_df,
index=["name"],
diff --git a/tests/unit_tests/pandas_postprocessing/test_prophet.py
b/tests/unit_tests/pandas_postprocessing/test_prophet.py
index ce5c45b..f341a5e 100644
--- a/tests/unit_tests/pandas_postprocessing/test_prophet.py
+++ b/tests/unit_tests/pandas_postprocessing/test_prophet.py
@@ -19,7 +19,7 @@ from importlib.util import find_spec
import pytest
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS
from superset.utils.pandas_postprocessing import prophet
from tests.unit_tests.fixtures.dataframes import prophet_df
@@ -75,40 +75,40 @@ def test_prophet_valid_zero_periods():
def test_prophet_import():
dynamic_module = find_spec("prophet")
if dynamic_module is None:
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
prophet(df=prophet_df, time_grain="P1M", periods=3,
confidence_interval=0.9)
def test_prophet_missing_temporal_column():
df = prophet_df.drop(DTTM_ALIAS, axis=1)
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
prophet(
df=df, time_grain="P1M", periods=3, confidence_interval=0.9,
)
def test_prophet_incorrect_confidence_interval():
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="P1M", periods=3,
confidence_interval=0.0,
)
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="P1M", periods=3,
confidence_interval=1.0,
)
def test_prophet_incorrect_periods():
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="P1M", periods=-1,
confidence_interval=0.8,
)
def test_prophet_incorrect_time_grain():
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="yearly", periods=10,
confidence_interval=0.8,
)
diff --git a/tests/unit_tests/pandas_postprocessing/test_resample.py
b/tests/unit_tests/pandas_postprocessing/test_resample.py
index 872f2ed..bd3a36e 100644
--- a/tests/unit_tests/pandas_postprocessing/test_resample.py
+++ b/tests/unit_tests/pandas_postprocessing/test_resample.py
@@ -14,45 +14,80 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pandas as pd
import pytest
-from pandas import DataFrame, to_datetime
-from superset.exceptions import QueryObjectValidationError
-from superset.utils.pandas_postprocessing import resample
-from tests.unit_tests.fixtures.dataframes import timeseries_df
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils import pandas_postprocessing as pp
+from tests.unit_tests.fixtures.dataframes import categories_df, timeseries_df
-def test_resample():
- df = timeseries_df.copy()
- df.index.name = "time_column"
- df.reset_index(inplace=True)
+def test_resample_should_not_side_effect():
+ _timeseries_df = timeseries_df.copy()
+ pp.resample(df=_timeseries_df, rule="1D", method="ffill")
+ assert _timeseries_df.equals(timeseries_df)
+
- post_df = resample(df=df, rule="1D", method="ffill",
time_column="time_column",)
- assert post_df["label"].tolist() == ["x", "y", "y", "y", "z", "z", "q"]
+def test_resample():
+ post_df = pp.resample(df=timeseries_df, rule="1D", method="ffill")
+ """
+ label y
+ 2019-01-01 x 1.0
+ 2019-01-02 y 2.0
+ 2019-01-03 y 2.0
+ 2019-01-04 y 2.0
+ 2019-01-05 z 3.0
+ 2019-01-06 z 3.0
+ 2019-01-07 q 4.0
+ """
+ assert post_df.equals(
+ pd.DataFrame(
+ index=pd.to_datetime(
+ [
+ "2019-01-01",
+ "2019-01-02",
+ "2019-01-03",
+ "2019-01-04",
+ "2019-01-05",
+ "2019-01-06",
+ "2019-01-07",
+ ]
+ ),
+ data={
+ "label": ["x", "y", "y", "y", "z", "z", "q"],
+ "y": [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0],
+ },
+ )
+ )
- assert post_df["y"].tolist() == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]
- post_df = resample(
- df=df, rule="1D", method="asfreq", time_column="time_column",
fill_value=0,
+def test_resample_zero_fill():
+ post_df = pp.resample(df=timeseries_df, rule="1D", method="asfreq",
fill_value=0)
+ assert post_df.equals(
+ pd.DataFrame(
+ index=pd.to_datetime(
+ [
+ "2019-01-01",
+ "2019-01-02",
+ "2019-01-03",
+ "2019-01-04",
+ "2019-01-05",
+ "2019-01-06",
+ "2019-01-07",
+ ]
+ ),
+ data={
+ "label": ["x", "y", 0, 0, "z", 0, "q"],
+ "y": [1.0, 2.0, 0, 0, 3.0, 0, 4.0],
+ },
+ )
)
- assert post_df["label"].tolist() == ["x", "y", 0, 0, "z", 0, "q"]
- assert post_df["y"].tolist() == [1.0, 2.0, 0, 0, 3.0, 0, 4.0]
-def test_resample_with_groupby():
- """
-The Dataframe contains a timestamp column, a string column and a numeric
column.
-__timestamp city val
-0 2022-01-13 Chicago 6.0
-1 2022-01-13 LA 5.0
-2 2022-01-13 NY 4.0
-3 2022-01-11 Chicago 3.0
-4 2022-01-11 LA 2.0
-5 2022-01-11 NY 1.0
- """
- df = DataFrame(
- {
- "__timestamp": to_datetime(
+def test_resample_after_pivot():
+ df = pd.DataFrame(
+ data={
+ "__timestamp": pd.to_datetime(
[
"2022-01-13",
"2022-01-13",
@@ -66,42 +101,53 @@ __timestamp city val
"val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
}
)
- post_df = resample(
+ pivot_df = pp.pivot(
df=df,
- rule="1D",
- method="asfreq",
- fill_value=0,
- time_column="__timestamp",
- groupby_columns=("city",),
+ index=["__timestamp"],
+ columns=["city"],
+ aggregates={"val": {"operator": "sum"},},
+ flatten_columns=False,
+ reset_index=False,
)
- assert list(post_df.columns) == [
- "__timestamp",
- "city",
- "val",
- ]
- assert [str(dt.date()) for dt in post_df["__timestamp"]] == (
- ["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
+ """
+ val
+ city Chicago LA NY
+ __timestamp
+ 2022-01-11 3.0 2.0 1.0
+ 2022-01-13 6.0 5.0 4.0
+ """
+ resample_df = pp.resample(df=pivot_df, rule="1D", method="asfreq",
fill_value=0,)
+ """
+ val
+ city Chicago LA NY
+ __timestamp
+ 2022-01-11 3.0 2.0 1.0
+ 2022-01-12 0.0 0.0 0.0
+ 2022-01-13 6.0 5.0 4.0
+ """
+ flat_df = pp.flatten(resample_df)
+ """
+ __timestamp val, Chicago val, LA val, NY
+ 0 2022-01-11 3.0 2.0 1.0
+ 1 2022-01-12 0.0 0.0 0.0
+ 2 2022-01-13 6.0 5.0 4.0
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "__timestamp": pd.to_datetime(
+ ["2022-01-11", "2022-01-12", "2022-01-13"]
+ ),
+ "val, Chicago": [3.0, 0, 6.0],
+ "val, LA": [2.0, 0, 5.0],
+ "val, NY": [1.0, 0, 4.0],
+ }
+ )
)
- assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]
- # should raise error when get a non-existent column
- with pytest.raises(QueryObjectValidationError):
- resample(
- df=df,
- rule="1D",
- method="asfreq",
- fill_value=0,
- time_column="__timestamp",
- groupby_columns=("city", "unkonw_column",),
- )
- # should raise error when get a None value in groupby list
- with pytest.raises(QueryObjectValidationError):
- resample(
- df=df,
- rule="1D",
- method="asfreq",
- fill_value=0,
- time_column="__timestamp",
- groupby_columns=("city", None,),
+def test_resample_should_raise_ex():
+ with pytest.raises(InvalidPostProcessingError):
+ pp.resample(
+ df=categories_df, rule="1D", method="asfreq",
)
diff --git a/tests/unit_tests/pandas_postprocessing/test_rolling.py
b/tests/unit_tests/pandas_postprocessing/test_rolling.py
index 227b03a..616e4f5 100644
--- a/tests/unit_tests/pandas_postprocessing/test_rolling.py
+++ b/tests/unit_tests/pandas_postprocessing/test_rolling.py
@@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pandas as pd
import pytest
-from pandas import to_datetime
-from superset.exceptions import QueryObjectValidationError
-from superset.utils.pandas_postprocessing import pivot, rolling
+from superset.exceptions import InvalidPostProcessingError
+from superset.utils import pandas_postprocessing as pp
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.unit_tests.fixtures.dataframes import (
multiple_metrics_df,
single_metric_df,
@@ -27,9 +28,21 @@ from tests.unit_tests.fixtures.dataframes import (
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
+def test_rolling_should_not_side_effect():
+ _timeseries_df = timeseries_df.copy()
+ pp.rolling(
+ df=timeseries_df,
+ columns={"y": "y"},
+ rolling_type="sum",
+ window=2,
+ min_periods=0,
+ )
+ assert _timeseries_df.equals(timeseries_df)
+
+
def test_rolling():
# sum rolling type
- post_df = rolling(
+ post_df = pp.rolling(
df=timeseries_df,
columns={"y": "y"},
rolling_type="sum",
@@ -41,7 +54,7 @@ def test_rolling():
assert series_to_list(post_df["y"]) == [1.0, 3.0, 5.0, 7.0]
# mean rolling type with alias
- post_df = rolling(
+ post_df = pp.rolling(
df=timeseries_df,
rolling_type="mean",
columns={"y": "y_mean"},
@@ -52,7 +65,7 @@ def test_rolling():
assert series_to_list(post_df["y_mean"]) == [1.0, 1.5, 2.0, 2.5]
# count rolling type
- post_df = rolling(
+ post_df = pp.rolling(
df=timeseries_df,
rolling_type="count",
columns={"y": "y"},
@@ -63,7 +76,7 @@ def test_rolling():
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
# quantile rolling type
- post_df = rolling(
+ post_df = pp.rolling(
df=timeseries_df,
columns={"y": "q1"},
rolling_type="quantile",
@@ -75,14 +88,14 @@ def test_rolling():
assert series_to_list(post_df["q1"]) == [1.0, 1.25, 1.5, 1.75]
# incorrect rolling type
- with pytest.raises(QueryObjectValidationError):
- rolling(
+ with pytest.raises(InvalidPostProcessingError):
+ pp.rolling(
df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2,
)
# incorrect rolling type options
- with pytest.raises(QueryObjectValidationError):
- rolling(
+ with pytest.raises(InvalidPostProcessingError):
+ pp.rolling(
df=timeseries_df,
columns={"y": "y"},
rolling_type="quantile",
@@ -91,8 +104,8 @@ def test_rolling():
)
-def test_rolling_with_pivot_df_and_single_metric():
- pivot_df = pivot(
+def test_rolling_should_empty_df():
+ pivot_df = pp.pivot(
df=single_metric_df,
index=["dttm"],
columns=["country"],
@@ -100,27 +113,65 @@ def test_rolling_with_pivot_df_and_single_metric():
flatten_columns=False,
reset_index=False,
)
- rolling_df = rolling(
- df=pivot_df, rolling_type="sum", window=2, min_periods=0,
is_pivot_df=True,
- )
- # dttm UK US
- # 0 2019-01-01 5 6
- # 1 2019-01-02 12 14
- assert rolling_df["UK"].to_list() == [5.0, 12.0]
- assert rolling_df["US"].to_list() == [6.0, 14.0]
- assert (
- rolling_df["dttm"].to_list()
- == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
+ rolling_df = pp.rolling(
+ df=pivot_df,
+ rolling_type="sum",
+ window=2,
+ min_periods=2,
+ columns={"sum_metric": "sum_metric"},
)
+ assert rolling_df.empty is True
- rolling_df = rolling(
- df=pivot_df, rolling_type="sum", window=2, min_periods=2,
is_pivot_df=True,
+
+def test_rolling_after_pivot_with_single_metric():
+ pivot_df = pp.pivot(
+ df=single_metric_df,
+ index=["dttm"],
+ columns=["country"],
+ aggregates={"sum_metric": {"operator": "sum"}},
+ flatten_columns=False,
+ reset_index=False,
+ )
+ """
+ sum_metric
+ country UK US
+ dttm
+ 2019-01-01 5 6
+ 2019-01-02 7 8
+ """
+ rolling_df = pp.rolling(
+ df=pivot_df,
+ columns={"sum_metric": "sum_metric"},
+ rolling_type="sum",
+ window=2,
+ min_periods=0,
+ )
+ """
+ sum_metric
+ country UK US
+ dttm
+ 2019-01-01 5.0 6.0
+ 2019-01-02 12.0 14.0
+ """
+ flat_df = pp.flatten(rolling_df)
+ """
+ dttm sum_metric, UK sum_metric, US
+ 0 2019-01-01 5.0 6.0
+ 1 2019-01-02 12.0 14.0
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
+ }
+ )
)
- assert rolling_df.empty is True
-def test_rolling_with_pivot_df_and_multiple_metrics():
- pivot_df = pivot(
+def test_rolling_after_pivot_with_multiple_metrics():
+ pivot_df = pp.pivot(
df=multiple_metrics_df,
index=["dttm"],
columns=["country"],
@@ -131,17 +182,41 @@ def test_rolling_with_pivot_df_and_multiple_metrics():
flatten_columns=False,
reset_index=False,
)
- rolling_df = rolling(
- df=pivot_df, rolling_type="sum", window=2, min_periods=0,
is_pivot_df=True,
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1 2 5 6
+ 2019-01-02 3 4 7 8
+ """
+ rolling_df = pp.rolling(
+ df=pivot_df,
+ columns={"count_metric": "count_metric", "sum_metric": "sum_metric",},
+ rolling_type="sum",
+ window=2,
+ min_periods=0,
)
- # dttm count_metric, UK count_metric, US sum_metric, UK
sum_metric, US
- # 0 2019-01-01 1.0 2.0 5.0
6.0
- # 1 2019-01-02 4.0 6.0 12.0
14.0
- assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0]
- assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0]
- assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0]
- assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0]
- assert (
- rolling_df["dttm"].to_list()
- == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
+ """
+ count_metric sum_metric
+ country UK US UK US
+ dttm
+ 2019-01-01 1.0 2.0 5.0 6.0
+ 2019-01-02 4.0 6.0 12.0 14.0
+ """
+ flat_df = pp.flatten(rolling_df)
+ """
+ dttm count_metric, UK count_metric, US sum_metric, UK
sum_metric, US
+ 0 2019-01-01 1.0 2.0 5.0
6.0
+ 1 2019-01-02 4.0 6.0 12.0
14.0
+ """
+ assert flat_df.equals(
+ pd.DataFrame(
+ data={
+ "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
+ FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1.0, 4.0],
+ FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2.0, 6.0],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
+ FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
+ }
+ )
)
diff --git a/tests/unit_tests/pandas_postprocessing/test_select.py
b/tests/unit_tests/pandas_postprocessing/test_select.py
index aac644d..2ba126f 100644
--- a/tests/unit_tests/pandas_postprocessing/test_select.py
+++ b/tests/unit_tests/pandas_postprocessing/test_select.py
@@ -16,7 +16,7 @@
# under the License.
import pytest
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.select import select
from tests.unit_tests.fixtures.dataframes import timeseries_df
@@ -47,9 +47,9 @@ def test_select():
assert post_df.columns.tolist() == ["y1"]
# invalid columns
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
select(df=timeseries_df, columns=["abc"], rename={"abc": "qwerty"})
# select renamed column by new name
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
select(df=timeseries_df, columns=["label_new"], rename={"label":
"label_new"})
diff --git a/tests/unit_tests/pandas_postprocessing/test_sort.py
b/tests/unit_tests/pandas_postprocessing/test_sort.py
index 43daa9c..c489381 100644
--- a/tests/unit_tests/pandas_postprocessing/test_sort.py
+++ b/tests/unit_tests/pandas_postprocessing/test_sort.py
@@ -16,7 +16,7 @@
# under the License.
import pytest
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import sort
from tests.unit_tests.fixtures.dataframes import categories_df
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
@@ -26,5 +26,5 @@ def test_sort():
df = sort(df=categories_df, columns={"category": True, "asc_idx": False})
assert series_to_list(df["asc_idx"])[1] == 96
- with pytest.raises(QueryObjectValidationError):
+ with pytest.raises(InvalidPostProcessingError):
sort(df=df, columns={"abc": True})