This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 365770e  feat: request ids on API related endpoints (#12663)
365770e is described below

commit 365770e7c3326db5f86d275dbd2bb13526359b4f
Author: Daniel Vaz Gaspar <[email protected]>
AuthorDate: Wed Jan 27 20:24:49 2021 +0000

    feat: request ids on API related endpoints (#12663)
    
    * feat: request ids on API related endpoints
    
    * rename ids to include_ids
---
 superset/charts/schemas.py |  5 +++-
 superset/views/base_api.py | 64 +++++++++++++++++++++++++++++++++++-----------
 tests/base_api_tests.py    | 58 +++++++++++++++++++++++++++++++++--------
 3 files changed, 101 insertions(+), 26 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index d62707e..c44ae75 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -110,7 +110,10 @@ openapi_spec_methods_override = {
         }
     },
     "related": {
-        "get": {"description": "Get a list of all possible owners for a chart."}
+        "get": {
+            "description": "Get a list of all possible owners for a chart. "
+            "Use `owners` has the `column_name` parameter"
+        }
     },
 }
 
diff --git a/superset/views/base_api.py b/superset/views/base_api.py
index 3e11f92..2956058 100644
--- a/superset/views/base_api.py
+++ b/superset/views/base_api.py
@@ -46,6 +46,7 @@ get_related_schema = {
     "properties": {
         "page_size": {"type": "integer"},
         "page": {"type": "integer"},
+        "include_ids": {"type": "array", "items": {"type": "integer"}},
         "filter": {"type": "string"},
     },
 }
@@ -213,7 +214,10 @@ class BaseSupersetModelRestApi(ModelRestApi):
         super().__init__()
 
     def add_apispec_components(self, api_spec: APISpec) -> None:
-
+        """
+        Adds extra OpenApi schema spec components, these are declared
+        on the `openapi_spec_component_schemas` class property
+        """
         for schema in self.openapi_spec_component_schemas:
             try:
                 api_spec.components.schema(
@@ -271,6 +275,40 @@ class BaseSupersetModelRestApi(ModelRestApi):
             )
         return filters
 
+    def _get_text_for_model(self, model: Model, column_name: str) -> str:
+        if column_name in self.text_field_rel_fields:
+            model_column_name = self.text_field_rel_fields.get(column_name)
+            if model_column_name:
+                return getattr(model, model_column_name)
+        return str(model)
+
+    def _get_result_from_rows(
+        self, datamodel: SQLAInterface, rows: List[Model], column_name: str
+    ) -> List[Dict[str, Any]]:
+        return [
+            {
+                "value": datamodel.get_pk_value(row),
+                "text": self._get_text_for_model(row, column_name),
+            }
+            for row in rows
+        ]
+
+    def _add_extra_ids_to_result(
+        self,
+        datamodel: SQLAInterface,
+        column_name: str,
+        ids: List[int],
+        result: List[Dict[str, Any]],
+    ) -> None:
+        if ids:
+            # Filter out already present values on the result
+            values = [row["value"] for row in result]
+            ids = [id_ for id_ in ids if id_ not in values]
+            pk_col = datamodel.get_pk()
+            # Fetch requested values from ids
+            extra_rows = db.session.query(datamodel.obj).filter(pk_col.in_(ids)).all()
+            result += self._get_result_from_rows(datamodel, extra_rows, column_name)
+
     def incr_stats(self, action: str, func_name: str) -> None:
         """
         Proxy function for statsd.incr to impose a key structure for REST API's
@@ -424,18 +462,11 @@ class BaseSupersetModelRestApi(ModelRestApi):
             500:
               $ref: '#/components/responses/500'
         """
-
-        def get_text_for_model(model: Model) -> str:
-            if column_name in self.text_field_rel_fields:
-                model_column_name = self.text_field_rel_fields.get(column_name)
-                if model_column_name:
-                    return getattr(model, model_column_name)
-            return str(model)
-
         if column_name not in self.allowed_rel_fields:
             self.incr_stats("error", self.related.__name__)
             return self.response_404()
         args = kwargs.get("rison", {})
+
         # handle pagination
         page, page_size = self._handle_page_args(args)
         try:
@@ -452,15 +483,18 @@ class BaseSupersetModelRestApi(ModelRestApi):
         # handle filters
         filters = self._get_related_filter(datamodel, column_name, args.get("filter"))
         # Make the query
-        count, values = datamodel.query(
+        _, rows = datamodel.query(
             filters, order_column, order_direction, page=page, page_size=page_size
         )
+
         # produce response
-        result = [
-            {"value": datamodel.get_pk_value(value), "text": get_text_for_model(value)}
-            for value in values
-        ]
-        return self.response(200, count=count, result=result)
+        result = self._get_result_from_rows(datamodel, rows, column_name)
+
+        # If ids are specified make sure we fetch and include them on the response
+        ids = args.get("include_ids")
+        self._add_extra_ids_to_result(datamodel, column_name, ids, result)
+
+        return self.response(200, count=len(result), result=result)
 
     @expose("/distinct/<column_name>", methods=["GET"])
     @protect()
diff --git a/tests/base_api_tests.py b/tests/base_api_tests.py
index 3dd21dc..f23b01e 100644
--- a/tests/base_api_tests.py
+++ b/tests/base_api_tests.py
@@ -184,48 +184,86 @@ class ApiOwnersTestCaseMixin:
 
     def test_get_related_owners(self):
         """
-            API: Test get related owners
+        API: Test get related owners
         """
         self.login(username="admin")
         uri = f"api/v1/{self.resource_name}/related/owners"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         users = db.session.query(security_manager.user_model).all()
         expected_users = [str(user) for user in users]
-        self.assertEqual(response["count"], len(users))
+        assert response["count"] == len(users)
         # This needs to be implemented like this, because ordering varies between
         # postgres and mysql
         response_users = [result["text"] for result in response["result"]]
         for expected_user in expected_users:
-            self.assertIn(expected_user, response_users)
+            assert expected_user in response_users
 
     def test_get_filter_related_owners(self):
         """
-            API: Test get filter related owners
+        API: Test get filter related owners
         """
         self.login(username="admin")
         argument = {"filter": "gamma"}
         uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
 
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(3, response["count"])
+        assert 3 == response["count"]
         sorted_results = sorted(response["result"], key=lambda value: value["text"])
         expected_results = [
             {"text": "gamma user", "value": 2},
             {"text": "gamma2 user", "value": 3},
             {"text": "gamma_sqllab user", "value": 4},
         ]
-        self.assertEqual(expected_results, sorted_results)
+        assert expected_results == sorted_results
+
+    def test_get_ids_related_owners(self):
+        """
+        API: Test get filter related owners
+        """
+        self.login(username="admin")
+        argument = {"filter": "gamma_sqllab", "include_ids": [2]}
+        uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
+
+        rv = self.client.get(uri)
+        response = json.loads(rv.data.decode("utf-8"))
+        assert rv.status_code == 200
+        assert 2 == response["count"]
+        sorted_results = sorted(response["result"], key=lambda value: value["text"])
+        expected_results = [
+            {"text": "gamma user", "value": 2},
+            {"text": "gamma_sqllab user", "value": 4},
+        ]
+        assert expected_results == sorted_results
+
+    def test_get_repeated_ids_related_owners(self):
+        """
+        API: Test get filter related owners
+        """
+        self.login(username="admin")
+        argument = {"filter": "gamma_sqllab", "include_ids": [2, 4]}
+        uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
+
+        rv = self.client.get(uri)
+        response = json.loads(rv.data.decode("utf-8"))
+        assert rv.status_code == 200
+        assert 2 == response["count"]
+        sorted_results = sorted(response["result"], key=lambda value: value["text"])
+        expected_results = [
+            {"text": "gamma user", "value": 2},
+            {"text": "gamma_sqllab user", "value": 4},
+        ]
+        assert expected_results == sorted_results
 
     def test_get_related_fail(self):
         """
-            API: Test get related fail
+        API: Test get related fail
         """
         self.login(username="admin")
         uri = f"api/v1/{self.resource_name}/related/owner"
 
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

Reply via email to