OmairK commented on a change in pull request #9097:
URL: https://github.com/apache/airflow/pull/9097#discussion_r437471258



##########
File path: airflow/api_connexion/endpoints/pool_endpoint.py
##########
@@ -26,18 +30,33 @@ def delete_pool():
     raise NotImplementedError("Not implemented yet.")
 
 
-def get_pool():
+@provide_session
+def get_pool(pool_name, session):
     """
     Get a pool
     """
-    raise NotImplementedError("Not implemented yet.")
+    pool_id = pool_name
+    query = session.query(Pool)
+    pool = query.filter(Pool.pool == pool_id).one_or_none()
+
+    if pool is None:
+        raise NotFound("Pool not found")
+    return pool_schema.dump(pool)
 
 
-def get_pools():
+@provide_session
+def get_pools(session):
     """
     Get all pools
     """
-    raise NotImplementedError("Not implemented yet.")
+    offset = request.args.get(parameters.page_offset, 0)
+    limit = min(int(request.args.get(parameters.page_limit, 100)), 100)
+
+    query = session.query(Pool)
+    query = query.offset(offset).limit(limit)
+
+    pools = query.all()
+    return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=len(pools))).data

Review comment:
       Thanks.
   Here is the change `ce9b79f` 

##########
File path: tests/api_connexion/endpoints/test_pool_endpoint.py
##########
@@ -16,51 +16,143 @@
 # under the License.
 import unittest
 
-import pytest
-
+from airflow.models.pool import Pool
+from airflow.utils.session import provide_session
 from airflow.www import app
+from tests.test_utils.db import clear_db_pools
 
 
-class TestPoolEndpoint(unittest.TestCase):
+class TestBasePoolEndpoints(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
         cls.app = app.create_app(testing=True)  # type:ignore
 
     def setUp(self) -> None:
         self.client = self.app.test_client()  # type:ignore
+        super().setUp()
+        clear_db_pools()
 
-
-class TestDeletePool(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.delete("/api/v1/pools/TEST_POOL_NAME")
-        assert response.status_code == 204
+    def tearDown(self) -> None:
+        clear_db_pools()
 
 
-class TestGetPool(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.get("/api/v1/pools/TEST_POOL_NAME")
+class TestGetPools(TestBasePoolEndpoints):
+    @provide_session
+    def test_response_200(self, session):
+        pool_model = Pool(pool="test_pool_a", slots=3)
+        session.add(pool_model)
+        session.commit()
+        result = session.query(Pool).all()
+        assert len(result) == 2  # accounts for the default pool as well
+        response = self.client.get("/api/v1/pools")
         assert response.status_code == 200
+        self.assertEqual(
+            {"pools": [{"name": "default_pool",
+                        "slots": 128,
+                        "occupied_slots": 0,
+                        "running_slots": 0,
+                        "queued_slots": 0,
+                        "open_slots": 128,
+                        },
+                       {"name": "test_pool_a",
+                        "slots": 3,
+                        "occupied_slots": 0,
+                        "running_slots": 0,
+                        "queued_slots": 0,
+                        "open_slots": 3,
+                        },
+                       ],
+                "total_entries": 2,
+             },
+            response.json,
+        )
 
-
-class TestGetPools(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.get("/api/v1/pools")
+    @provide_session
+    def test_limit(self, session):
+        pools = [Pool(pool=f"test_pool{i}", slots=1) for i in range(120)]
+        session.add_all(pools)
+        session.commit()
+        result = session.query(Pool).count()
+        self.assertEqual(result, 121)  # accounts for default pool as well
+        # Test for limit under 100
+        response = self.client.get("/api/v1/pools?limit=10")
+        assert response.status_code == 200
+        self.assertEqual(response.json.get("total_entries"), 10)

Review comment:
       Here is the change ce9b79f 

##########
File path: tests/api_connexion/endpoints/test_pool_endpoint.py
##########
@@ -16,51 +16,143 @@
 # under the License.
 import unittest
 
-import pytest
-
+from airflow.models.pool import Pool
+from airflow.utils.session import provide_session
 from airflow.www import app
+from tests.test_utils.db import clear_db_pools
 
 
-class TestPoolEndpoint(unittest.TestCase):
+class TestBasePoolEndpoints(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
         cls.app = app.create_app(testing=True)  # type:ignore
 
     def setUp(self) -> None:
         self.client = self.app.test_client()  # type:ignore
+        super().setUp()
+        clear_db_pools()
 
-
-class TestDeletePool(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.delete("/api/v1/pools/TEST_POOL_NAME")
-        assert response.status_code == 204
+    def tearDown(self) -> None:
+        clear_db_pools()
 
 
-class TestGetPool(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.get("/api/v1/pools/TEST_POOL_NAME")
+class TestGetPools(TestBasePoolEndpoints):
+    @provide_session
+    def test_response_200(self, session):
+        pool_model = Pool(pool="test_pool_a", slots=3)
+        session.add(pool_model)
+        session.commit()
+        result = session.query(Pool).all()
+        assert len(result) == 2  # accounts for the default pool as well
+        response = self.client.get("/api/v1/pools")
         assert response.status_code == 200
+        self.assertEqual(
+            {"pools": [{"name": "default_pool",
+                        "slots": 128,
+                        "occupied_slots": 0,
+                        "running_slots": 0,
+                        "queued_slots": 0,
+                        "open_slots": 128,
+                        },
+                       {"name": "test_pool_a",
+                        "slots": 3,
+                        "occupied_slots": 0,
+                        "running_slots": 0,
+                        "queued_slots": 0,
+                        "open_slots": 3,
+                        },
+                       ],
+                "total_entries": 2,
+             },
+            response.json,
+        )
 
-
-class TestGetPools(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.get("/api/v1/pools")
+    @provide_session
+    def test_limit(self, session):
+        pools = [Pool(pool=f"test_pool{i}", slots=1) for i in range(120)]
+        session.add_all(pools)
+        session.commit()
+        result = session.query(Pool).count()
+        self.assertEqual(result, 121)  # accounts for default pool as well
+        # Test for limit under 100
+        response = self.client.get("/api/v1/pools?limit=10")
+        assert response.status_code == 200
+        self.assertEqual(response.json.get("total_entries"), 10)
+        # Test for limit over 100
+        response = self.client.get("/api/v1/pools?limit=110")
         assert response.status_code == 200
+        self.assertEqual(response.json.get('total_entries'), 100)

Review comment:
       Here is the change ce9b79f 

##########
File path: tests/api_connexion/endpoints/test_pool_endpoint.py
##########
@@ -16,51 +16,143 @@
 # under the License.
 import unittest
 
-import pytest
-
+from airflow.models.pool import Pool
+from airflow.utils.session import provide_session
 from airflow.www import app
+from tests.test_utils.db import clear_db_pools
 
 
-class TestPoolEndpoint(unittest.TestCase):
+class TestBasePoolEndpoints(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
         cls.app = app.create_app(testing=True)  # type:ignore
 
     def setUp(self) -> None:
         self.client = self.app.test_client()  # type:ignore
+        super().setUp()
+        clear_db_pools()
 
-
-class TestDeletePool(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.delete("/api/v1/pools/TEST_POOL_NAME")
-        assert response.status_code == 204
+    def tearDown(self) -> None:
+        clear_db_pools()
 
 
-class TestGetPool(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.get("/api/v1/pools/TEST_POOL_NAME")
+class TestGetPools(TestBasePoolEndpoints):
+    @provide_session
+    def test_response_200(self, session):
+        pool_model = Pool(pool="test_pool_a", slots=3)
+        session.add(pool_model)
+        session.commit()
+        result = session.query(Pool).all()
+        assert len(result) == 2  # accounts for the default pool as well
+        response = self.client.get("/api/v1/pools")
         assert response.status_code == 200
+        self.assertEqual(
+            {"pools": [{"name": "default_pool",
+                        "slots": 128,
+                        "occupied_slots": 0,
+                        "running_slots": 0,
+                        "queued_slots": 0,
+                        "open_slots": 128,
+                        },
+                       {"name": "test_pool_a",
+                        "slots": 3,
+                        "occupied_slots": 0,
+                        "running_slots": 0,
+                        "queued_slots": 0,
+                        "open_slots": 3,
+                        },
+                       ],
+                "total_entries": 2,
+             },
+            response.json,
+        )
 
-
-class TestGetPools(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.get("/api/v1/pools")
+    @provide_session
+    def test_limit(self, session):
+        pools = [Pool(pool=f"test_pool{i}", slots=1) for i in range(120)]
+        session.add_all(pools)
+        session.commit()
+        result = session.query(Pool).count()
+        self.assertEqual(result, 121)  # accounts for default pool as well
+        # Test for limit under 100
+        response = self.client.get("/api/v1/pools?limit=10")
+        assert response.status_code == 200
+        self.assertEqual(response.json.get("total_entries"), 10)
+        # Test for limit over 100
+        response = self.client.get("/api/v1/pools?limit=110")
         assert response.status_code == 200
+        self.assertEqual(response.json.get('total_entries'), 100)

Review comment:
       Here is the change ~ce9b79f~  36d3a3f

##########
File path: tests/api_connexion/endpoints/test_pool_endpoint.py
##########
@@ -16,51 +16,143 @@
 # under the License.
 import unittest
 
-import pytest
-
+from airflow.models.pool import Pool
+from airflow.utils.session import provide_session
 from airflow.www import app
+from tests.test_utils.db import clear_db_pools
 
 
-class TestPoolEndpoint(unittest.TestCase):
+class TestBasePoolEndpoints(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         super().setUpClass()
         cls.app = app.create_app(testing=True)  # type:ignore
 
     def setUp(self) -> None:
         self.client = self.app.test_client()  # type:ignore
+        super().setUp()
+        clear_db_pools()
 
-
-class TestDeletePool(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.delete("/api/v1/pools/TEST_POOL_NAME")
-        assert response.status_code == 204
+    def tearDown(self) -> None:
+        clear_db_pools()
 
 
-class TestGetPool(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.get("/api/v1/pools/TEST_POOL_NAME")
+class TestGetPools(TestBasePoolEndpoints):
+    @provide_session
+    def test_response_200(self, session):
+        pool_model = Pool(pool="test_pool_a", slots=3)
+        session.add(pool_model)
+        session.commit()
+        result = session.query(Pool).all()
+        assert len(result) == 2  # accounts for the default pool as well
+        response = self.client.get("/api/v1/pools")
         assert response.status_code == 200
+        self.assertEqual(
+            {"pools": [{"name": "default_pool",
+                        "slots": 128,
+                        "occupied_slots": 0,
+                        "running_slots": 0,
+                        "queued_slots": 0,
+                        "open_slots": 128,
+                        },
+                       {"name": "test_pool_a",
+                        "slots": 3,
+                        "occupied_slots": 0,
+                        "running_slots": 0,
+                        "queued_slots": 0,
+                        "open_slots": 3,
+                        },
+                       ],
+                "total_entries": 2,
+             },
+            response.json,
+        )
 
-
-class TestGetPools(TestPoolEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
-        response = self.client.get("/api/v1/pools")
+    @provide_session
+    def test_limit(self, session):
+        pools = [Pool(pool=f"test_pool{i}", slots=1) for i in range(120)]
+        session.add_all(pools)
+        session.commit()
+        result = session.query(Pool).count()
+        self.assertEqual(result, 121)  # accounts for default pool as well
+        # Test for limit under 100
+        response = self.client.get("/api/v1/pools?limit=10")
+        assert response.status_code == 200
+        self.assertEqual(response.json.get("total_entries"), 10)

Review comment:
       Here is the change ~ce9b79f~  36d3a3f

##########
File path: airflow/api_connexion/endpoints/pool_endpoint.py
##########
@@ -26,18 +30,33 @@ def delete_pool():
     raise NotImplementedError("Not implemented yet.")
 
 
-def get_pool():
+@provide_session
+def get_pool(pool_name, session):
     """
     Get a pool
     """
-    raise NotImplementedError("Not implemented yet.")
+    pool_id = pool_name
+    query = session.query(Pool)
+    pool = query.filter(Pool.pool == pool_id).one_or_none()
+
+    if pool is None:
+        raise NotFound("Pool not found")
+    return pool_schema.dump(pool)
 
 
-def get_pools():
+@provide_session
+def get_pools(session):
     """
     Get all pools
     """
-    raise NotImplementedError("Not implemented yet.")
+    offset = request.args.get(parameters.page_offset, 0)
+    limit = min(int(request.args.get(parameters.page_limit, 100)), 100)
+
+    query = session.query(Pool)
+    query = query.offset(offset).limit(limit)
+
+    pools = query.all()
+    return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=len(pools))).data

Review comment:
       Thanks.
   Here is the change ~ce9b79f~  36d3a3f




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to