This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a91a0bcc2 Add filter task upstream/downstream to grid view (#29885)
0a91a0bcc2 is described below

commit 0a91a0bcc26f402fcd3b09c83554d31908b2d847
Author: Brent Bovenzi <[email protected]>
AuthorDate: Mon Mar 13 10:38:17 2023 -0400

    Add filter task upstream/downstream to grid view (#29885)
    
    * Add filter task upstream/downstream to grid view
    
    * Update button text, move reset root to dropdown
---
 airflow/www/static/js/api/useGraphData.ts          | 34 +++++---
 airflow/www/static/js/api/useGridData.ts           | 29 ++++++-
 airflow/www/static/js/dag/details/FilterTasks.tsx  | 91 ++++++++++++++++++++++
 airflow/www/static/js/dag/details/graph/index.tsx  | 12 ++-
 airflow/www/static/js/dag/details/index.tsx        |  6 +-
 .../www/static/js/dag/details/taskInstance/Nav.tsx | 24 +-----
 .../static/js/dag/details/taskInstance/index.tsx   |  1 -
 airflow/www/static/js/dag/grid/ResetRoot.tsx       | 42 ----------
 airflow/www/static/js/dag/nav/FilterBar.tsx        |  8 +-
 airflow/www/static/js/dag/useFilters.test.tsx      | 69 +++++++++++++++-
 airflow/www/static/js/dag/useFilters.tsx           | 59 ++++++++++++++
 airflow/www/static/js/utils/graph.ts               | 20 ++++-
 airflow/www/views.py                               | 10 ++-
 13 files changed, 309 insertions(+), 96 deletions(-)

diff --git a/airflow/www/static/js/api/useGraphData.ts b/airflow/www/static/js/api/useGraphData.ts
index ab318fbb07..9c0f089e9c 100644
--- a/airflow/www/static/js/api/useGraphData.ts
+++ b/airflow/www/static/js/api/useGraphData.ts
@@ -22,12 +22,16 @@ import axios, { AxiosResponse } from "axios";
 
 import { getMetaValue } from "src/utils";
 import type { DepNode } from "src/types";
+import useFilters, {
+  FILTER_DOWNSTREAM_PARAM,
+  FILTER_UPSTREAM_PARAM,
+  ROOT_PARAM,
+} from "src/dag/useFilters";
 
 const DAG_ID_PARAM = "dag_id";
 
 const dagId = getMetaValue(DAG_ID_PARAM);
 const graphDataUrl = getMetaValue("graph_data_url");
-const urlRoot = getMetaValue("root");
 
 interface GraphData {
   edges: WebserverEdge[];
@@ -40,15 +44,23 @@ export interface WebserverEdge {
   targetId: string;
 }
 
-const useGraphData = () =>
-  useQuery("graphData", async () => {
-    const params = {
-      [DAG_ID_PARAM]: dagId,
-      root: urlRoot || undefined,
-      filter_upstream: true,
-      filter_downstream: true,
-    };
-    return axios.get<AxiosResponse, GraphData>(graphDataUrl, { params });
-  });
+const useGraphData = () => {
+  const {
+    filters: { root, filterDownstream, filterUpstream },
+  } = useFilters();
+
+  return useQuery(
+    ["graphData", root, filterUpstream, filterDownstream],
+    async () => {
+      const params = {
+        [DAG_ID_PARAM]: dagId,
+        [ROOT_PARAM]: root,
+        [FILTER_UPSTREAM_PARAM]: filterUpstream,
+        [FILTER_DOWNSTREAM_PARAM]: filterDownstream,
+      };
+      return axios.get<AxiosResponse, GraphData>(graphDataUrl, { params });
+    }
+  );
+};
 
 export default useGraphData;
diff --git a/airflow/www/static/js/api/useGridData.ts b/airflow/www/static/js/api/useGridData.ts
index 7be2e0d2cf..a0cbfd6b76 100644
--- a/airflow/www/static/js/api/useGridData.ts
+++ b/airflow/www/static/js/api/useGridData.ts
@@ -29,6 +29,9 @@ import useFilters, {
   RUN_STATE_PARAM,
   RUN_TYPE_PARAM,
   now,
+  FILTER_DOWNSTREAM_PARAM,
+  FILTER_UPSTREAM_PARAM,
+  ROOT_PARAM,
 } from "src/dag/useFilters";
 import type { Task, DagRun, RunOrdering } from "src/types";
 import { camelCase } from "lodash";
@@ -38,7 +41,6 @@ const DAG_ID_PARAM = "dag_id";
 // dagId comes from dag.html
 const dagId = getMetaValue(DAG_ID_PARAM);
 const gridDataUrl = getMetaValue("grid_data_url");
-const urlRoot = getMetaValue("root");
 
 export interface GridData {
   dagRuns: DagRun[];
@@ -68,14 +70,33 @@ const useGridData = () => {
   const { isRefreshOn, stopRefresh } = useAutoRefresh();
   const errorToast = useErrorToast();
   const {
-    filters: { baseDate, numRuns, runType, runState },
+    filters: {
+      baseDate,
+      numRuns,
+      runType,
+      runState,
+      root,
+      filterDownstream,
+      filterUpstream,
+    },
   } = useFilters();
 
   const query = useQuery(
-    ["gridData", baseDate, numRuns, runType, runState],
+    [
+      "gridData",
+      baseDate,
+      numRuns,
+      runType,
+      runState,
+      root,
+      filterUpstream,
+      filterDownstream,
+    ],
     async () => {
       const params = {
-        root: urlRoot || undefined,
+        [ROOT_PARAM]: root,
+        [FILTER_UPSTREAM_PARAM]: filterUpstream,
+        [FILTER_DOWNSTREAM_PARAM]: filterDownstream,
         [DAG_ID_PARAM]: dagId,
         [BASE_DATE_PARAM]: baseDate === now ? undefined : baseDate,
         [NUM_RUNS_PARAM]: numRuns,
diff --git a/airflow/www/static/js/dag/details/FilterTasks.tsx b/airflow/www/static/js/dag/details/FilterTasks.tsx
new file mode 100644
index 0000000000..f21e812b3a
--- /dev/null
+++ b/airflow/www/static/js/dag/details/FilterTasks.tsx
@@ -0,0 +1,91 @@
+/*!
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import React from "react";
+import {
+  Flex,
+  Button,
+  Menu,
+  MenuButton,
+  MenuItem,
+  MenuList,
+} from "@chakra-ui/react";
+import useFilters from "src/dag/useFilters";
+import { MdArrowDropDown } from "react-icons/md";
+
+interface Props {
+  taskId: string;
+}
+
+const FilterTasks = ({ taskId }: Props) => {
+  const {
+    filters: { root },
+    onFilterTasksChange,
+    resetRoot,
+  } = useFilters();
+
+  const onFilterUpstream = () =>
+    onFilterTasksChange({
+      root: taskId,
+      filterUpstream: true,
+      filterDownstream: false,
+    });
+
+  const onFilterDownstream = () =>
+    onFilterTasksChange({
+      root: taskId,
+      filterUpstream: false,
+      filterDownstream: true,
+    });
+
+  const onFilterAll = () =>
+    onFilterTasksChange({
+      root: taskId,
+      filterUpstream: true,
+      filterDownstream: true,
+    });
+
+  const label = "Filter upstream and/or downstream of a task";
+
+  return (
+    <Menu>
+      <MenuButton
+        as={Button}
+        variant="outline"
+        colorScheme="blue"
+        transition="all 0.2s"
+        title={label}
+        aria-label={label}
+      >
+        <Flex>
+          {!root ? "Filter Tasks " : "Clear Task Filter "}
+          <MdArrowDropDown size="16px" />
+        </Flex>
+      </MenuButton>
+      <MenuList>
+        <MenuItem onClick={onFilterUpstream}>Filter Upstream</MenuItem>
+        <MenuItem onClick={onFilterDownstream}>Filter Downstream</MenuItem>
+        <MenuItem onClick={onFilterAll}>Filter Upstream & Downstream</MenuItem>
+        {!!root && <MenuItem onClick={resetRoot}>Reset Root</MenuItem>}
+      </MenuList>
+    </Menu>
+  );
+};
+
+export default FilterTasks;
diff --git a/airflow/www/static/js/dag/details/graph/index.tsx b/airflow/www/static/js/dag/details/graph/index.tsx
index da92a06108..b25d689301 100644
--- a/airflow/www/static/js/dag/details/graph/index.tsx
+++ b/airflow/www/static/js/dag/details/graph/index.tsx
@@ -37,6 +37,7 @@ import { useOffsetTop } from "src/utils";
 import { useGraphLayout } from "src/utils/graph";
 import Tooltip from "src/components/Tooltip";
 import { useContainerRef } from "src/context/containerRef";
+import useFilters from "src/dag/useFilters";
 
 import Edge from "./Edge";
 import Node, { CustomNodeProps } from "./Node";
@@ -56,6 +57,10 @@ const Graph = ({ openGroupIds, onToggleGroups }: Props) => {
   const { data } = useGraphData();
   const [arrange, setArrange] = useState(data?.arrange || "LR");
 
+  const {
+    filters: { root, filterDownstream, filterUpstream },
+  } = useFilters();
+
   useEffect(() => {
     setArrange(data?.arrange || "LR");
   }, [data?.arrange]);
@@ -71,9 +76,14 @@ const Graph = ({ openGroupIds, onToggleGroups }: Props) => {
     data: { dagRuns, groups },
   } = useGridData();
   const { colors } = useTheme();
-  const { setCenter } = useReactFlow();
+  const { setCenter, setViewport } = useReactFlow();
   const latestDagRunId = dagRuns[dagRuns.length - 1]?.runId;
 
+  // Reset viewport when tasks are filtered
+  useEffect(() => {
+    setViewport({ x: 0, y: 0, zoom: 1 });
+  }, [root, filterDownstream, filterUpstream, setViewport]);
+
   const offsetTop = useOffsetTop(graphRef);
 
   let nodes: ReactFlowNode<CustomNodeProps>[] = [];
diff --git a/airflow/www/static/js/dag/details/index.tsx b/airflow/www/static/js/dag/details/index.tsx
index 6ae79363db..d9fe8d5890 100644
--- a/airflow/www/static/js/dag/details/index.tsx
+++ b/airflow/www/static/js/dag/details/index.tsx
@@ -45,6 +45,7 @@ import Graph from "./graph";
 import MappedInstances from "./taskInstance/MappedInstances";
 import Logs from "./taskInstance/Logs";
 import BackToTaskSummary from "./taskInstance/BackToTaskSummary";
+import FilterTasks from "./FilterTasks";
 
 const dagId = getMetaValue("dag_id")!;
 
@@ -145,7 +146,10 @@ const Details = ({ openGroupIds, onToggleGroups }: Props) => {
 
   return (
     <Flex flexDirection="column" pl={3} height="100%">
-      <Header />
+      <Flex alignItems="center" justifyContent="space-between">
+        <Header />
+        <Flex>{taskId && runId && <FilterTasks taskId={taskId} />}</Flex>
+      </Flex>
       <Divider my={2} />
       <Tabs
         size="lg"
diff --git a/airflow/www/static/js/dag/details/taskInstance/Nav.tsx b/airflow/www/static/js/dag/details/taskInstance/Nav.tsx
index 192f4eda38..3541240ce3 100644
--- a/airflow/www/static/js/dag/details/taskInstance/Nav.tsx
+++ b/airflow/www/static/js/dag/details/taskInstance/Nav.tsx
@@ -22,23 +22,19 @@ import { Flex } from "@chakra-ui/react";
 
 import { getMetaValue, appendSearchParams } from "src/utils";
 import LinkButton from "src/components/LinkButton";
-import type { Task, DagRun } from "src/types";
+import type { Task } from "src/types";
 import URLSearchParamsWrapper from "src/utils/URLSearchParamWrapper";
 
 const dagId = getMetaValue("dag_id");
 const isK8sExecutor = getMetaValue("k8s_or_k8scelery_executor") === "True";
-const numRuns = getMetaValue("num_runs");
-const baseDate = getMetaValue("base_date");
 const taskInstancesUrl = getMetaValue("task_instances_list_url");
 const renderedK8sUrl = getMetaValue("rendered_k8s_url");
 const renderedTemplatesUrl = getMetaValue("rendered_templates_url");
 const xcomUrl = getMetaValue("xcom_url");
 const taskUrl = getMetaValue("task_url");
 const gridUrl = getMetaValue("grid_url");
-const gridUrlNoRoot = getMetaValue("grid_url_no_root");
 
 interface Props {
-  runId: DagRun["runId"];
   taskId: Task["id"];
   executionDate: string;
   operator?: string;
@@ -47,10 +43,7 @@ interface Props {
 }
 
 const Nav = forwardRef<HTMLDivElement, Props>(
-  (
-    { runId, taskId, executionDate, operator, isMapped = false, mapIndex },
-    ref
-  ) => {
+  ({ taskId, executionDate, operator, isMapped = false, mapIndex }, ref) => {
     if (!taskId) return null;
     const params = new URLSearchParamsWrapper({
       task_id: taskId,
@@ -70,23 +63,11 @@ const Nav = forwardRef<HTMLDivElement, Props>(
       execution_date: executionDate,
     }).toString();
 
-    const filterParams = new URLSearchParamsWrapper({
-      task_id: taskId,
-      dag_run_id: runId,
-      root: taskId,
-    });
-
     if (mapIndex !== undefined && mapIndex >= 0)
       listParams.append("_flt_0_map_index", mapIndex.toString());
-    if (baseDate) filterParams.append("base_date", baseDate);
-    if (numRuns) filterParams.append("num_runs", numRuns);
 
     const allInstancesLink = `${taskInstancesUrl}?${listParams.toString()}`;
 
-    const filterUpstreamLink = appendSearchParams(
-      gridUrlNoRoot,
-      filterParams.toString()
-    );
     const subDagLink = appendSearchParams(
       gridUrl.replace(dagId, `${dagId}.${taskId}`),
       subDagParams
@@ -116,7 +97,6 @@ const Nav = forwardRef<HTMLDivElement, Props>(
         >
           List Instances, all runs
         </LinkButton>
-        <LinkButton href={filterUpstreamLink}>Filter Upstream</LinkButton>
       </Flex>
     );
   }
diff --git a/airflow/www/static/js/dag/details/taskInstance/index.tsx b/airflow/www/static/js/dag/details/taskInstance/index.tsx
index e95389fc65..7d584aa07c 100644
--- a/airflow/www/static/js/dag/details/taskInstance/index.tsx
+++ b/airflow/www/static/js/dag/details/taskInstance/index.tsx
@@ -92,7 +92,6 @@ const TaskInstance = ({ taskId, runId, mapIndex }: Props) => {
       {!isGroup && (
         <TaskNav
           taskId={taskId}
-          runId={runId}
           isMapped={isMapped}
           mapIndex={mapIndex}
           executionDate={executionDate}
diff --git a/airflow/www/static/js/dag/grid/ResetRoot.tsx b/airflow/www/static/js/dag/grid/ResetRoot.tsx
deleted file mode 100644
index 6d1b870bf1..0000000000
--- a/airflow/www/static/js/dag/grid/ResetRoot.tsx
+++ /dev/null
@@ -1,42 +0,0 @@
-/*!
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-import React from "react";
-import { Button, Link } from "@chakra-ui/react";
-
-import { getMetaValue } from "src/utils";
-
-const root = getMetaValue("root");
-const url = getMetaValue("grid_url_no_root");
-
-const ResetRoot = () =>
-  root ? (
-    <Button
-      as={Link}
-      variant="outline"
-      href={url}
-      colorScheme="blue"
-      mx={2}
-      title="Reset root to show the whole DAG"
-    >
-      Reset Root
-    </Button>
-  ) : null;
-
-export default ResetRoot;
diff --git a/airflow/www/static/js/dag/nav/FilterBar.tsx b/airflow/www/static/js/dag/nav/FilterBar.tsx
index d442a67fb9..4b5e2f3be3 100644
--- a/airflow/www/static/js/dag/nav/FilterBar.tsx
+++ b/airflow/www/static/js/dag/nav/FilterBar.tsx
@@ -26,8 +26,7 @@ import AutoRefresh from "src/components/AutoRefresh";
 
 import { useTimezone } from "src/context/timezone";
 import { isoFormatWithoutTZ } from "src/datetime_utils";
-import useFilters from "../useFilters";
-import ResetRoot from "../grid/ResetRoot";
+import useFilters from "src/dag/useFilters";
 
 declare const filtersOptions: {
   dagStates: RunState[];
@@ -123,10 +122,7 @@ const FilterBar = () => {
           </Button>
         </Box>
       </Flex>
-      <Flex>
-        <AutoRefresh />
-        <ResetRoot />
-      </Flex>
+      <AutoRefresh />
     </Flex>
   );
 };
diff --git a/airflow/www/static/js/dag/useFilters.test.tsx b/airflow/www/static/js/dag/useFilters.test.tsx
index a3da4ace59..8081dad20e 100644
--- a/airflow/www/static/js/dag/useFilters.test.tsx
+++ b/airflow/www/static/js/dag/useFilters.test.tsx
@@ -38,6 +38,7 @@ jest.useFakeTimers().setSystemTime(date);
 import useFilters, {
   FilterHookReturn,
   Filters,
+  FilterTasksProps,
   UtilFunctions,
 } from "./useFilters";
 
@@ -48,13 +49,24 @@ describe("Test useFilters hook", () => {
       { wrapper: RouterWrapper }
     );
     const {
-      filters: { baseDate, numRuns, runType, runState },
+      filters: {
+        baseDate,
+        numRuns,
+        runType,
+        runState,
+        root,
+        filterUpstream,
+        filterDownstream,
+      },
     } = result.current;
 
     expect(baseDate).toBe(date.toISOString());
     expect(numRuns).toBe(global.defaultDagRunDisplayNumber.toString());
     expect(runType).toBeNull();
     expect(runState).toBeNull();
+    expect(root).toBeUndefined();
+    expect(filterUpstream).toBeUndefined();
+    expect(filterDownstream).toBeUndefined();
   });
 
   test.each([
@@ -85,7 +97,7 @@ describe("Test useFilters hook", () => {
     );
 
     await act(async () => {
-      result.current[fnName](paramValue);
+      result.current[fnName](paramValue as "string" & FilterTasksProps);
     });
 
     expect(result.current.filters[paramName]).toBe(paramValue);
@@ -105,4 +117,57 @@ describe("Test useFilters hook", () => {
       expect(result.current.filters[paramName]).toBeNull();
     }
   });
+
+  test("Test onFilterTasksChange ", async () => {
+    const { result } = renderHook<FilterHookReturn, undefined>(
+      () => useFilters(),
+      { wrapper: RouterWrapper }
+    );
+
+    await act(async () => {
+      result.current.onFilterTasksChange({
+        root: "test",
+        filterUpstream: true,
+        filterDownstream: false,
+      });
+    });
+
+    expect(result.current.filters.root).toBe("test");
+    expect(result.current.filters.filterUpstream).toBe(true);
+    expect(result.current.filters.filterDownstream).toBe(false);
+
+    // sending same info clears filters
+    await act(async () => {
+      result.current.onFilterTasksChange({
+        root: "test",
+        filterUpstream: true,
+        filterDownstream: false,
+      });
+    });
+
+    expect(result.current.filters.root).toBeUndefined();
+    expect(result.current.filters.filterUpstream).toBeUndefined();
+    expect(result.current.filters.filterDownstream).toBeUndefined();
+
+    await act(async () => {
+      result.current.onFilterTasksChange({
+        root: "test",
+        filterUpstream: true,
+        filterDownstream: false,
+      });
+    });
+
+    expect(result.current.filters.root).toBe("test");
+    expect(result.current.filters.filterUpstream).toBe(true);
+    expect(result.current.filters.filterDownstream).toBe(false);
+
+    // clearFilters
+    await act(async () => {
+      result.current.resetRoot();
+    });
+
+    expect(result.current.filters.root).toBeUndefined();
+    expect(result.current.filters.filterUpstream).toBeUndefined();
+    expect(result.current.filters.filterDownstream).toBeUndefined();
+  });
 });
diff --git a/airflow/www/static/js/dag/useFilters.tsx b/airflow/www/static/js/dag/useFilters.tsx
index 0ed728d04a..26f97dca54 100644
--- a/airflow/www/static/js/dag/useFilters.tsx
+++ b/airflow/www/static/js/dag/useFilters.tsx
@@ -25,18 +25,29 @@ import URLSearchParamsWrapper from "src/utils/URLSearchParamWrapper";
 declare const defaultDagRunDisplayNumber: number;
 
 export interface Filters {
+  root: string | undefined;
+  filterUpstream: boolean | undefined;
+  filterDownstream: boolean | undefined;
   baseDate: string | null;
   numRuns: string | null;
   runType: string | null;
   runState: string | null;
 }
 
+export interface FilterTasksProps {
+  root: string;
+  filterUpstream: boolean;
+  filterDownstream: boolean;
+}
+
 export interface UtilFunctions {
   onBaseDateChange: (value: string) => void;
   onNumRunsChange: (value: string) => void;
   onRunTypeChange: (value: string) => void;
   onRunStateChange: (value: string) => void;
+  onFilterTasksChange: (args: FilterTasksProps) => void;
   clearFilters: () => void;
+  resetRoot: () => void;
 }
 
 export interface FilterHookReturn extends UtilFunctions {
@@ -49,6 +60,10 @@ export const NUM_RUNS_PARAM = "num_runs";
 export const RUN_TYPE_PARAM = "run_type";
 export const RUN_STATE_PARAM = "run_state";
 
+export const ROOT_PARAM = "root";
+export const FILTER_UPSTREAM_PARAM = "filter_upstream";
+export const FILTER_DOWNSTREAM_PARAM = "filter_downstream";
+
 const date = new Date();
 date.setMilliseconds(0);
 
@@ -57,6 +72,14 @@ export const now = date.toISOString();
 const useFilters = (): FilterHookReturn => {
   const [searchParams, setSearchParams] = useSearchParams();
 
+  const root = searchParams.get(ROOT_PARAM) || undefined;
+  const filterUpstream = root
+    ? searchParams.get(FILTER_UPSTREAM_PARAM) === "true"
+    : undefined;
+  const filterDownstream = root
+    ? searchParams.get(FILTER_DOWNSTREAM_PARAM) === "true"
+    : undefined;
+
   const baseDate = searchParams.get(BASE_DATE_PARAM) || now;
   const numRuns =
     searchParams.get(NUM_RUNS_PARAM) || defaultDagRunDisplayNumber.toString();
@@ -83,6 +106,30 @@ const useFilters = (): FilterHookReturn => {
   const onRunTypeChange = makeOnChangeFn(RUN_TYPE_PARAM);
   const onRunStateChange = makeOnChangeFn(RUN_STATE_PARAM);
 
+  const onFilterTasksChange = ({
+    root: newRoot,
+    filterUpstream: newUpstream,
+    filterDownstream: newDownstream,
+  }: FilterTasksProps) => {
+    const params = new URLSearchParamsWrapper(searchParams);
+
+    if (
+      root === newRoot &&
+      newUpstream === filterUpstream &&
+      newDownstream === filterDownstream
+    ) {
+      params.delete(ROOT_PARAM);
+      params.delete(FILTER_UPSTREAM_PARAM);
+      params.delete(FILTER_DOWNSTREAM_PARAM);
+    } else {
+      params.set(ROOT_PARAM, newRoot);
+      params.set(FILTER_UPSTREAM_PARAM, newUpstream.toString());
+      params.set(FILTER_DOWNSTREAM_PARAM, newDownstream.toString());
+    }
+
+    setSearchParams(params);
+  };
+
   const clearFilters = () => {
     searchParams.delete(BASE_DATE_PARAM);
     searchParams.delete(NUM_RUNS_PARAM);
@@ -91,8 +138,18 @@ const useFilters = (): FilterHookReturn => {
     setSearchParams(searchParams);
   };
 
+  const resetRoot = () => {
+    searchParams.delete(ROOT_PARAM);
+    searchParams.delete(FILTER_UPSTREAM_PARAM);
+    searchParams.delete(FILTER_DOWNSTREAM_PARAM);
+    setSearchParams(searchParams);
+  };
+
   return {
     filters: {
+      root,
+      filterUpstream,
+      filterDownstream,
       baseDate,
       numRuns,
       runType,
@@ -102,7 +159,9 @@ const useFilters = (): FilterHookReturn => {
     onNumRunsChange,
     onRunTypeChange,
     onRunStateChange,
+    onFilterTasksChange,
     clearFilters,
+    resetRoot,
   };
 };
 
diff --git a/airflow/www/static/js/utils/graph.ts b/airflow/www/static/js/utils/graph.ts
index 8c8da293ce..24d3bb3750 100644
--- a/airflow/www/static/js/utils/graph.ts
+++ b/airflow/www/static/js/utils/graph.ts
@@ -22,6 +22,7 @@ import ELK, { ElkExtendedEdge, ElkShape } from "elkjs";
 import type { DepNode } from "src/types";
 import type { NodeType } from "src/datasets/Graph/Node";
 import { useQuery } from "react-query";
+import useFilters from "src/dag/useFilters";
 
 interface GenerateProps {
   nodes: DepNode[];
@@ -194,9 +195,21 @@ export const useGraphLayout = ({
   nodes,
   openGroupIds,
   arrange = "LR",
-}: LayoutProps) =>
-  useQuery(
-    ["graphLayout", !!nodes?.children, openGroupIds, arrange],
+}: LayoutProps) => {
+  const {
+    filters: { root, filterDownstream, filterUpstream },
+  } = useFilters();
+
+  return useQuery(
+    [
+      "graphLayout",
+      !!nodes?.children,
+      openGroupIds,
+      arrange,
+      root,
+      filterUpstream,
+      filterDownstream,
+    ],
     async () => {
       const font = `bold ${16}px ${
         window.getComputedStyle(document.body).fontFamily
@@ -214,3 +227,4 @@ export const useGraphLayout = ({
       return data as Graph;
     }
   );
+};
diff --git a/airflow/www/views.py b/airflow/www/views.py
index c0125ed272..8b7e53f748 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2889,8 +2889,8 @@ class Airflow(AirflowBaseView):
 
         root = request.args.get("root")
         if root:
-            filter_upstream = True if request.args.get("filter_upstream") == "true" else False
-            filter_downstream = True if request.args.get("filter_downstream") == "true" else False
+            filter_upstream = request.args.get("filter_upstream") == "true"
+            filter_downstream = request.args.get("filter_downstream") == "true"
             dag = dag.partial_subset(
                 task_ids_or_regex=root, include_upstream=filter_upstream, include_downstream=filter_downstream
             )
@@ -3611,7 +3611,11 @@ class Airflow(AirflowBaseView):
 
         root = request.args.get("root")
         if root:
-            dag = dag.partial_subset(task_ids_or_regex=root, include_downstream=False, include_upstream=True)
+            filter_upstream = request.args.get("filter_upstream") == "true"
+            filter_downstream = request.args.get("filter_downstream") == "true"
+            dag = dag.partial_subset(
+                task_ids_or_regex=root, include_upstream=filter_upstream, include_downstream=filter_downstream
+            )
 
         num_runs = request.args.get("num_runs", type=int)
         if num_runs is None:

Reply via email to