stale[bot] closed pull request #3858: [AIRFLOW-2929] Add get and set for pool class in models. URL: https://github.com/apache/incubator-airflow/pull/3858
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance. As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py index 4b46921e64..194d6d225e 100644 --- a/airflow/api/client/local_client.py +++ b/airflow/api/client/local_client.py @@ -18,7 +18,7 @@ # under the License. from airflow.api.client import api_client -from airflow.api.common.experimental import pool +from airflow.models import Pool from airflow.api.common.experimental import trigger_dag from airflow.api.common.experimental import delete_dag @@ -38,16 +38,16 @@ def delete_dag(self, dag_id): return "Removed {} record(s)".format(count) def get_pool(self, name): - p = pool.get_pool(name=name) + p = Pool.get_pool(name=name) return p.pool, p.slots, p.description def get_pools(self): - return [(p.pool, p.slots, p.description) for p in pool.get_pools()] + return [(p.pool, p.slots, p.description) for p in Pool.get_pools()] def create_pool(self, name, slots, description): - p = pool.create_pool(name=name, slots=slots, description=description) + p = Pool.create_pool(name=name, slots=slots, description=description) return p.pool, p.slots, p.description def delete_pool(self, name): - p = pool.delete_pool(name=name) + p = Pool.delete_pool(name=name) return p.pool, p.slots, p.description diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 4ff1ae3679..c638c3c233 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -57,7 +57,7 @@ from airflow.executors import GetDefaultExecutor from airflow.models import (DagModel, DagBag, TaskInstance, DagPickle, DagRun, Variable, DagStat, - Connection, DAG) + Connection, DAG, Pool) from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS) from airflow.utils import cli as cli_utils @@ -263,29 +263,32 @@ def pool(args): log = 
LoggingMixin().log def _tabulate(pools): - return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'], + return "\n%s" % tabulate([(pool.pool, + pool.slots, + pool.description) for pool in pools], + ['Pool', 'Slots', 'Description'], tablefmt="fancy_grid") try: imp = getattr(args, 'import') if args.get is not None: - pools = [api_client.get_pool(name=args.get)] + pools = [Pool.get_pool(name=args.get)] elif args.set: - pools = [api_client.create_pool(name=args.set[0], - slots=args.set[1], - description=args.set[2])] + pools = [Pool.create_pool(name=args.set[0], + slots=args.set[1], + description=args.set[2])] elif args.delete: - pools = [api_client.delete_pool(name=args.delete)] + pools = [Pool.delete_pool(name=args.delete)] elif imp: if os.path.exists(imp): pools = pool_import_helper(imp) else: print("Missing pools file.") - pools = api_client.get_pools() + pools = Pool.get_pools() elif args.export: pools = pool_export_helper(args.export) else: - pools = api_client.get_pools() + pools = Pool.get_pools() except (AirflowException, IOError) as err: log.error(err) else: @@ -305,9 +308,9 @@ def pool_import_helper(filepath): n = 0 for k, v in d.items(): if isinstance(v, dict) and len(v) == 2: - pools.append(api_client.create_pool(name=k, - slots=v["slots"], - description=v["description"])) + pools.append(Pool.create_pool(name=k, + slots=v["slots"], + description=v["description"])) n += 1 else: pass @@ -320,9 +323,9 @@ def pool_import_helper(filepath): def pool_export_helper(filepath): pool_dict = {} - pools = api_client.get_pools() + pools = Pool.get_pools() for pool in pools: - pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2]} + pool_dict[pool.pool] = {"slots": pool.slots, "description": pool.description} with open(filepath, 'w') as poolfile: poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4)) print("{} pools successfully exported to {}".format(len(pool_dict), filepath)) diff --git a/airflow/models.py b/airflow/models.py index 
4aa00074f5..47edcca236 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -70,7 +70,8 @@ from airflow.executors import GetDefaultExecutor, LocalExecutor from airflow import configuration from airflow.exceptions import ( - AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout + AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout, + AirflowBadRequest, PoolNotFound ) from airflow.dag.base_dag import BaseDag, BaseDagBag from airflow.lineage import apply_lineage, prepare_lineage @@ -5260,6 +5261,60 @@ def open_slots(self, session): queued_slots = self.queued_slots(session=session) return self.slots - used_slots - queued_slots + @classmethod + @provide_session + def get_pool(cls, name, session): + """Get pool by a given name.""" + if not (name and name.strip()): + raise AirflowBadRequest("Pool name shouldn't be empty") + pool = session.query(cls).filter(cls.pool == name).first() + if pool is None: + raise PoolNotFound("Pool '%s' doesn't exist" % name) + return pool + + @classmethod + @provide_session + def get_pools(cls, session): + """Get all pools.""" + return [pool for pool in session.query(cls).all()] + + @classmethod + @provide_session + def create_pool(cls, name, slots, description, session): + """Create a pool with a given parameters.""" + if not (name and name.strip()): + raise AirflowBadRequest("Pool name shouldn't be empty") + try: + slots = int(slots) + except ValueError: + raise AirflowBadRequest("Bad value for `slots`: %s" % slots) + + session.expire_on_commit = False + pool = session.query(cls).filter_by(pool=name).first() + if pool is None: + pool = Pool(pool=name, slots=slots, description=description) + session.add(pool) + else: + pool.slots = slots + pool.description = description + session.commit() + return pool + + @classmethod + @provide_session + def delete_pool(cls, name, session): + """Delete pool by a given name.""" + if not (name and name.strip()): + raise 
AirflowBadRequest("Pool name shouldn't be empty") + + pool = session.query(cls).filter_by(pool=name).first() + if pool is None: + raise PoolNotFound("Pool '%s' doesn't exist" % name) + + session.delete(pool) + session.commit() + return pool + class SlaMiss(Base): """ ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services