This is an automated email from the ASF dual-hosted git repository.
dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 365770e feat: request ids on API related endpoints (#12663)
365770e is described below
commit 365770e7c3326db5f86d275dbd2bb13526359b4f
Author: Daniel Vaz Gaspar <[email protected]>
AuthorDate: Wed Jan 27 20:24:49 2021 +0000
feat: request ids on API related endpoints (#12663)
* feat: request ids on API related endpoints
* rename ids to include_ids
---
superset/charts/schemas.py | 5 +++-
superset/views/base_api.py | 64 +++++++++++++++++++++++++++++++++++-----------
tests/base_api_tests.py | 58 +++++++++++++++++++++++++++++++++--------
3 files changed, 101 insertions(+), 26 deletions(-)
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index d62707e..c44ae75 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -110,7 +110,10 @@ openapi_spec_methods_override = {
}
},
"related": {
- "get": {"description": "Get a list of all possible owners for a chart."}
+ "get": {
+ "description": "Get a list of all possible owners for a chart. "
+ "Use `owners` has the `column_name` parameter"
+ }
},
}
diff --git a/superset/views/base_api.py b/superset/views/base_api.py
index 3e11f92..2956058 100644
--- a/superset/views/base_api.py
+++ b/superset/views/base_api.py
@@ -46,6 +46,7 @@ get_related_schema = {
"properties": {
"page_size": {"type": "integer"},
"page": {"type": "integer"},
+ "include_ids": {"type": "array", "items": {"type": "integer"}},
"filter": {"type": "string"},
},
}
@@ -213,7 +214,10 @@ class BaseSupersetModelRestApi(ModelRestApi):
super().__init__()
def add_apispec_components(self, api_spec: APISpec) -> None:
-
+ """
+ Adds extra OpenApi schema spec components, these are declared
+ on the `openapi_spec_component_schemas` class property
+ """
for schema in self.openapi_spec_component_schemas:
try:
api_spec.components.schema(
@@ -271,6 +275,40 @@ class BaseSupersetModelRestApi(ModelRestApi):
)
return filters
+ def _get_text_for_model(self, model: Model, column_name: str) -> str:
+ if column_name in self.text_field_rel_fields:
+ model_column_name = self.text_field_rel_fields.get(column_name)
+ if model_column_name:
+ return getattr(model, model_column_name)
+ return str(model)
+
+ def _get_result_from_rows(
+ self, datamodel: SQLAInterface, rows: List[Model], column_name: str
+ ) -> List[Dict[str, Any]]:
+ return [
+ {
+ "value": datamodel.get_pk_value(row),
+ "text": self._get_text_for_model(row, column_name),
+ }
+ for row in rows
+ ]
+
+ def _add_extra_ids_to_result(
+ self,
+ datamodel: SQLAInterface,
+ column_name: str,
+ ids: List[int],
+ result: List[Dict[str, Any]],
+ ) -> None:
+ if ids:
+ # Filter out already present values on the result
+ values = [row["value"] for row in result]
+ ids = [id_ for id_ in ids if id_ not in values]
+ pk_col = datamodel.get_pk()
+ # Fetch requested values from ids
+ extra_rows = db.session.query(datamodel.obj).filter(pk_col.in_(ids)).all()
+ result += self._get_result_from_rows(datamodel, extra_rows, column_name)
+
def incr_stats(self, action: str, func_name: str) -> None:
"""
Proxy function for statsd.incr to impose a key structure for REST API's
@@ -424,18 +462,11 @@ class BaseSupersetModelRestApi(ModelRestApi):
500:
$ref: '#/components/responses/500'
"""
-
- def get_text_for_model(model: Model) -> str:
- if column_name in self.text_field_rel_fields:
- model_column_name = self.text_field_rel_fields.get(column_name)
- if model_column_name:
- return getattr(model, model_column_name)
- return str(model)
-
if column_name not in self.allowed_rel_fields:
self.incr_stats("error", self.related.__name__)
return self.response_404()
args = kwargs.get("rison", {})
+
# handle pagination
page, page_size = self._handle_page_args(args)
try:
@@ -452,15 +483,18 @@ class BaseSupersetModelRestApi(ModelRestApi):
# handle filters
filters = self._get_related_filter(datamodel, column_name, args.get("filter"))
# Make the query
- count, values = datamodel.query(
+ _, rows = datamodel.query(
filters, order_column, order_direction, page=page, page_size=page_size
)
+
# produce response
- result = [
- {"value": datamodel.get_pk_value(value), "text": get_text_for_model(value)}
- for value in values
- ]
- return self.response(200, count=count, result=result)
+ result = self._get_result_from_rows(datamodel, rows, column_name)
+
+ # If ids are specified make sure we fetch and include them on the response
+ ids = args.get("include_ids")
+ self._add_extra_ids_to_result(datamodel, column_name, ids, result)
+
+ return self.response(200, count=len(result), result=result)
@expose("/distinct/<column_name>", methods=["GET"])
@protect()
diff --git a/tests/base_api_tests.py b/tests/base_api_tests.py
index 3dd21dc..f23b01e 100644
--- a/tests/base_api_tests.py
+++ b/tests/base_api_tests.py
@@ -184,48 +184,86 @@ class ApiOwnersTestCaseMixin:
def test_get_related_owners(self):
"""
- API: Test get related owners
+ API: Test get related owners
"""
self.login(username="admin")
uri = f"api/v1/{self.resource_name}/related/owners"
rv = self.client.get(uri)
- self.assertEqual(rv.status_code, 200)
+ assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
users = db.session.query(security_manager.user_model).all()
expected_users = [str(user) for user in users]
- self.assertEqual(response["count"], len(users))
+ assert response["count"] == len(users)
# This needs to be implemented like this, because ordering varies between
# postgres and mysql
response_users = [result["text"] for result in response["result"]]
for expected_user in expected_users:
- self.assertIn(expected_user, response_users)
+ assert expected_user in response_users
def test_get_filter_related_owners(self):
"""
- API: Test get filter related owners
+ API: Test get filter related owners
"""
self.login(username="admin")
argument = {"filter": "gamma"}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
rv = self.client.get(uri)
- self.assertEqual(rv.status_code, 200)
+ assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
- self.assertEqual(3, response["count"])
+ assert 3 == response["count"]
sorted_results = sorted(response["result"], key=lambda value: value["text"])
expected_results = [
{"text": "gamma user", "value": 2},
{"text": "gamma2 user", "value": 3},
{"text": "gamma_sqllab user", "value": 4},
]
- self.assertEqual(expected_results, sorted_results)
+ assert expected_results == sorted_results
+
+ def test_get_ids_related_owners(self):
+ """
+ API: Test get filter related owners
+ """
+ self.login(username="admin")
+ argument = {"filter": "gamma_sqllab", "include_ids": [2]}
+ uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
+
+ rv = self.client.get(uri)
+ response = json.loads(rv.data.decode("utf-8"))
+ assert rv.status_code == 200
+ assert 2 == response["count"]
+ sorted_results = sorted(response["result"], key=lambda value: value["text"])
+ expected_results = [
+ {"text": "gamma user", "value": 2},
+ {"text": "gamma_sqllab user", "value": 4},
+ ]
+ assert expected_results == sorted_results
+
+ def test_get_repeated_ids_related_owners(self):
+ """
+ API: Test get filter related owners
+ """
+ self.login(username="admin")
+ argument = {"filter": "gamma_sqllab", "include_ids": [2, 4]}
+ uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
+
+ rv = self.client.get(uri)
+ response = json.loads(rv.data.decode("utf-8"))
+ assert rv.status_code == 200
+ assert 2 == response["count"]
+ sorted_results = sorted(response["result"], key=lambda value: value["text"])
+ expected_results = [
+ {"text": "gamma user", "value": 2},
+ {"text": "gamma_sqllab user", "value": 4},
+ ]
+ assert expected_results == sorted_results
def test_get_related_fail(self):
"""
- API: Test get related fail
+ API: Test get related fail
"""
self.login(username="admin")
uri = f"api/v1/{self.resource_name}/related/owner"
rv = self.client.get(uri)
- self.assertEqual(rv.status_code, 404)
+ assert rv.status_code == 404