This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 8ca49d4e6f07efd00c745178d66f117507a53c85
Author: Zef Lin <[email protected]>
AuthorDate: Mon Sep 18 11:30:52 2023 -0700

    fix(chart): Supporting custom SQL as temporal x-axis column with filter (#25126)
    
    Co-authored-by: Kamil Gabryjelski <[email protected]>
---
 .../src/explore/actions/exploreActions.test.js     | 13 ++---
 .../src/explore/reducers/exploreReducer.js         | 18 +++---
 superset/common/query_context_factory.py           |  7 ++-
 superset/connectors/sqla/models.py                 |  2 +
 tests/integration_tests/charts/data/api_tests.py   | 66 ++++++++++++++++++++++
 5 files changed, 90 insertions(+), 16 deletions(-)

diff --git a/superset-frontend/src/explore/actions/exploreActions.test.js b/superset-frontend/src/explore/actions/exploreActions.test.js
index cdf76bdcc6..9dd5375680 100644
--- a/superset-frontend/src/explore/actions/exploreActions.test.js
+++ b/superset-frontend/src/explore/actions/exploreActions.test.js
@@ -22,20 +22,19 @@ import exploreReducer from 'src/explore/reducers/exploreReducer';
 import * as actions from 'src/explore/actions/exploreActions';
 
 describe('reducers', () => {
-  it('sets correct control value given an arbitrary key and value', () => {
+  it('Does not set a control value if control does not exist', () => {
     const newState = exploreReducer(
       defaultState,
       actions.setControlValue('NEW_FIELD', 'x', []),
     );
-    expect(newState.controls.NEW_FIELD.value).toBe('x');
-    expect(newState.form_data.NEW_FIELD).toBe('x');
+    expect(newState.controls.NEW_FIELD).toBeUndefined();
   });
-  it('setControlValue works as expected with a checkbox', () => {
+  it('setControlValue works as expected with a Select control', () => {
     const newState = exploreReducer(
       defaultState,
-      actions.setControlValue('show_legend', true, []),
+      actions.setControlValue('y_axis_format', '$,.2f', []),
     );
-    expect(newState.controls.show_legend.value).toBe(true);
-    expect(newState.form_data.show_legend).toBe(true);
+    expect(newState.controls.y_axis_format.value).toBe('$,.2f');
+    expect(newState.form_data.y_axis_format).toBe('$,.2f');
   });
 });
diff --git a/superset-frontend/src/explore/reducers/exploreReducer.js b/superset-frontend/src/explore/reducers/exploreReducer.js
index ac451ade3d..d5565a0dad 100644
--- a/superset-frontend/src/explore/reducers/exploreReducer.js
+++ b/superset-frontend/src/explore/reducers/exploreReducer.js
@@ -112,7 +112,7 @@ export default function exploreReducer(state = {}, action) {
       const vizType = new_form_data.viz_type;
 
       // if the controlName is metrics, and the metric column name is updated,
-      // need to update column config as well to keep the previou config.
+      // need to update column config as well to keep the previous config.
       if (controlName === 'metrics' && old_metrics_data && new_column_config) {
         value.forEach((item, index) => {
           if (
@@ -129,11 +129,11 @@ export default function exploreReducer(state = {}, action) {
       }
 
       // Use the processed control config (with overrides and everything)
-      // if `controlName` does not existing in current controls,
+      // if `controlName` does not exist in current controls,
       const controlConfig =
         state.controls[action.controlName] ||
         getControlConfig(action.controlName, vizType) ||
-        {};
+        null;
 
       // will call validators again
       const control = {
@@ -149,7 +149,7 @@ export default function exploreReducer(state = {}, action) {
         ...state,
         controls: {
           ...state.controls,
-          [controlName]: control,
+          ...(controlConfig && { [controlName]: control }),
           ...(controlName === 'metrics' && { column_config }),
         },
       };
@@ -196,10 +196,12 @@ export default function exploreReducer(state = {}, action) {
         triggerRender: control.renderTrigger && !hasErrors,
         controls: {
           ...currentControlsState,
-          [action.controlName]: {
-            ...control,
-            validationErrors: errors,
-          },
+          ...(controlConfig && {
+            [action.controlName]: {
+              ...control,
+              validationErrors: errors,
+            },
+          }),
           ...rerenderedControls,
         },
       };
diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py
index a6fe549894..e4680ed5ed 100644
--- a/superset/common/query_context_factory.py
+++ b/superset/common/query_context_factory.py
@@ -26,7 +26,7 @@ from superset.common.query_object_factory import QueryObjectFactory
 from superset.daos.chart import ChartDAO
 from superset.daos.datasource import DatasourceDAO
 from superset.models.slice import Slice
-from superset.utils.core import DatasourceDict, DatasourceType
+from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column
 
 if TYPE_CHECKING:
     from superset.connectors.base.models import BaseDatasource
@@ -129,6 +129,8 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods
 
         if granularity := query_object.granularity:
             filter_to_remove = None
+            if is_adhoc_column(x_axis):  # type: ignore
+                x_axis = x_axis.get("sqlExpression")
             if x_axis and x_axis in temporal_columns:
                 filter_to_remove = x_axis
                 x_axis_column = next(
@@ -176,6 +178,9 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods
             # another temporal filter. A new filter based on the value of
             # the granularity will be added later in the code.
             # In practice, this is replacing the previous default temporal 
filter.
+            if is_adhoc_column(filter_to_remove):  # type: ignore
+                filter_to_remove = filter_to_remove.get("sqlExpression")
+
             if filter_to_remove:
                 query_object.filter = [
                     filter
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c7ea336ded..79203256f1 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1007,6 +1007,8 @@ class SqlaTable(
                     qry = sa.select([sqla_column]).limit(1).select_from(tbl)
                     sql = self.database.compile_sqla_query(qry)
                     col_desc = get_columns_description(self.database, sql)
+                    if not col_desc:
+                        raise SupersetGenericDBErrorException("Column not 
found")
                     is_dttm = col_desc[0]["is_dttm"]  # type: ignore
                 except SupersetGenericDBErrorException as ex:
                     raise ColumnNotFoundException(message=str(ex)) from ex
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index dc82026986..ab91cce55e 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -51,6 +51,7 @@ from superset.models.slice import Slice
 from superset.superset_typing import AdhocColumn
 from superset.utils.core import (
     AnnotationType,
+    backend,
     get_example_default_schema,
     AdhocMetricExpressionType,
     ExtraFiltersReasonType,
@@ -943,6 +944,71 @@ class TestGetChartDataApi(BaseTestChartDataApi):
         assert data["result"][0]["status"] == "success"
         assert data["result"][0]["rowcount"] == 2
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_get_with_x_axis_using_custom_sql(self):
+        """
+        Chart data API: Test GET endpoint
+        """
+        chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
+        chart.query_context = json.dumps(
+            {
+                "datasource": {"id": chart.table.id, "type": "table"},
+                "force": False,
+                "queries": [
+                    {
+                        "time_range": "1900-01-01T00:00:00 : 
2000-01-01T00:00:00",
+                        "granularity": "ds",
+                        "filters": [
+                            {"col": "ds", "op": "TEMPORAL_RANGE", "val": "No 
filter"}
+                        ],
+                        "extras": {
+                            "having": "",
+                            "where": "",
+                        },
+                        "applied_time_extras": {},
+                        "columns": [
+                            {
+                                "columnType": "BASE_AXIS",
+                                "datasourceWarning": False,
+                                "expressionType": "SQL",
+                                "label": "My column",
+                                "sqlExpression": "ds",
+                                "timeGrain": "P1W",
+                            }
+                        ],
+                        "metrics": ["sum__num"],
+                        "orderby": [["sum__num", False]],
+                        "annotation_layers": [],
+                        "row_limit": 50000,
+                        "timeseries_limit": 0,
+                        "order_desc": True,
+                        "url_params": {},
+                        "custom_params": {},
+                        "custom_form_data": {},
+                    }
+                ],
+                "form_data": {
+                    "x_axis": {
+                        "datasourceWarning": False,
+                        "expressionType": "SQL",
+                        "label": "My column",
+                        "sqlExpression": "ds",
+                    }
+                },
+                "result_format": "json",
+                "result_type": "full",
+            }
+        )
+        rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", 
"get_data")
+        assert rv.mimetype == "application/json"
+        data = json.loads(rv.data.decode("utf-8"))
+        assert data["result"][0]["status"] == "success"
+
+        if backend() == "presto":
+            assert data["result"][0]["rowcount"] == 41
+        else:
+            assert data["result"][0]["rowcount"] == 40
+
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_chart_data_get_forced(self):
         """

Reply via email to