This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 1.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit c7284fe2421caf8b088c0e71d3d40b1bd38528d7
Author: Duy Nguyen Hoang <[email protected]>
AuthorDate: Tue Jan 19 23:15:16 2021 +0700

    fix: error while parsing invalid json form_data (#12586)
    
    * Fix error while parsing invalid JSON form_data
    
    * Refine the returned error message
---
 superset/charts/api.py    | 11 ++++++++---
 superset/views/utils.py   | 13 ++++++++++---
 tests/charts/api_tests.py | 17 ++++++++++++++++-
 tests/utils_tests.py      | 16 +++++++++++++++-
 4 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/superset/charts/api.py b/superset/charts/api.py
index 2d46cce..cce76da 100644
--- a/superset/charts/api.py
+++ b/superset/charts/api.py
@@ -534,13 +534,18 @@ class ChartRestApi(BaseSupersetModelRestApi):
             500:
               $ref: '#/components/responses/500'
         """
+        json_body = None
         if request.is_json:
             json_body = request.json
         elif request.form.get("form_data"):
             # CSV export submits regular form data
-            json_body = json.loads(request.form["form_data"])
-        else:
-            return self.response_400(message="Request is not JSON")
+            try:
+                json_body = json.loads(request.form["form_data"])
+            except (TypeError, json.JSONDecodeError):
+                json_body = None
+
+        if json_body is None:
+            return self.response_400(message=_("Request is not JSON"))
 
         try:
             command = ChartDataCommand()
diff --git a/superset/views/utils.py b/superset/views/utils.py
index 28104aa..3ea253c 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -126,6 +126,13 @@ def get_viz(
     return viz_obj
 
 
+def loads_request_json(request_json_data: str) -> Dict[Any, Any]:
+    try:
+        return json.loads(request_json_data)
+    except (TypeError, json.JSONDecodeError):
+        return {}
+
+
 def get_form_data(
     slice_id: Optional[int] = None, use_slice_data: bool = False
 ) -> Tuple[Dict[str, Any], Optional[Slice]]:
@@ -141,10 +148,10 @@ def get_form_data(
     if request_json_data:
         form_data.update(request_json_data)
     if request_form_data:
-        form_data.update(json.loads(request_form_data))
+        form_data.update(loads_request_json(request_form_data))
     # request params can overwrite the body
     if request_args_data:
-        form_data.update(json.loads(request_args_data))
+        form_data.update(loads_request_json(request_args_data))
 
     # Fallback to using the Flask globals (used for cache warmup) if defined.
     if not form_data and hasattr(g, "form_data"):
@@ -157,7 +164,7 @@ def get_form_data(
             url_str = parse.unquote_plus(
                 saved_url.url.split("?")[1][10:], encoding="utf-8"
             )
-            url_form_data = json.loads(url_str)
+            url_form_data = loads_request_json(url_str)
             # allow form_date in request override saved url
             url_form_data.update(form_data)
             form_data = url_form_data
diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py
index f7ba39a..8e22074 100644
--- a/tests/charts/api_tests.py
+++ b/tests/charts/api_tests.py
@@ -1176,6 +1176,21 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
         self.assertEqual(rv.status_code, 400)
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_invalid_form_data(self):
+        """
+        Chart data API: Test chart data with invalid form_data json
+        """
+        self.login(username="admin")
+        data = {"form_data": "NOT VALID JSON"}
+
+        rv = self.client.post(
+            CHART_DATA_URI, data=data, content_type="multipart/form-data"
+        )
+        response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 400)
+        self.assertEqual(response["message"], "Request is not JSON")
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_chart_data_query_result_type(self):
         """
         Chart data API: Test chart data with query result format
@@ -1592,7 +1607,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
         assert rv.status_code == 422
         assert response == {
             "message": {
-                "charts/imported_chart.yaml": "Chart already exists and `overwrite=true` was not passed",
+                "charts/imported_chart.yaml": "Chart already exists and `overwrite=true` was not passed"
             }
         }
 
diff --git a/tests/utils_tests.py b/tests/utils_tests.py
index 1c6e6b2..fbdf131 100644
--- a/tests/utils_tests.py
+++ b/tests/utils_tests.py
@@ -925,7 +925,7 @@ class TestUtils(SupersetTestCase):
 
             self.assertEqual(
                 form_data,
-                {"time_range_endpoints": get_time_range_endpoints(form_data={}),},
+                {"time_range_endpoints": get_time_range_endpoints(form_data={})},
             )
 
             self.assertEqual(slc, None)
@@ -994,6 +994,20 @@ class TestUtils(SupersetTestCase):
 
             self.assertEqual(slc, None)
 
+    def test_get_form_data_corrupted_json(self) -> None:
+        with app.test_request_context(
+            data={"form_data": "{x: '2324'}"},
+            query_string={"form_data": '{"baz": "bar"'},
+        ):
+            form_data, slc = get_form_data()
+
+            self.assertEqual(
+                form_data,
+                {"time_range_endpoints": get_time_range_endpoints(form_data={})},
+            )
+
+            self.assertEqual(slc, None)
+
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_log_this(self) -> None:
         # TODO: Add additional scenarios.

Reply via email to