omkar-foss commented on code in PR #44223:
URL: https://github.com/apache/airflow/pull/44223#discussion_r1857267249


##########
tests/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -2318,3 +2319,542 @@ def test_raises_404_for_non_existent_dag(self, test_client):
         )
         assert response.status_code == 404
         assert "DAG non-existent-dag not found" in response.text
+
+
+class TestPatchTaskInstance(TestTaskInstanceEndpoint):
+    ENDPOINT_URL = (
+        
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
+    )
+
+    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+    def test_should_call_mocked_api(self, mock_set_task_instance_state, 
test_client, session):
+        self.create_task_instances(session)
+
+        NEW_STATE = "failed"
+        mock_set_task_instance_state.return_value = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == "example_python_operator",
+                TaskInstance.task_id == "print_the_context",
+                TaskInstance.run_id == "TEST_DAG_RUN_ID",
+                TaskInstance.map_index == -1,
+            )
+        ).one_or_none()
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "dry_run": False,
+                "new_state": NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+        assert response.json() == {
+            "dag_id": "example_python_operator",
+            "dag_run_id": "TEST_DAG_RUN_ID",
+            "logical_date": "2020-01-01T00:00:00Z",
+            "task_id": "print_the_context",
+            "duration": 10000.0,
+            "end_date": "2020-01-03T00:00:00Z",
+            "executor": None,
+            "executor_config": "{}",
+            "hostname": "",
+            "id": mock.ANY,
+            "map_index": -1,
+            "max_tries": 0,
+            "note": "placeholder-note",
+            "operator": "PythonOperator",
+            "pid": 100,
+            "pool": "default_pool",
+            "pool_slots": 1,
+            "priority_weight": 9,
+            "queue": "default_queue",
+            "queued_when": None,
+            "start_date": "2020-01-02T00:00:00Z",
+            "state": "running",
+            "task_display_name": "print_the_context",
+            "try_number": 0,
+            "unixname": getuser(),
+            "rendered_fields": {},
+            "rendered_map_index": None,
+            "trigger": None,
+            "triggerer_job": None,
+        }
+
+        mock_set_task_instance_state.assert_called_once()

Review Comment:
   Done



##########
tests/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -2318,3 +2319,542 @@ def test_raises_404_for_non_existent_dag(self, test_client):
         )
         assert response.status_code == 404
         assert "DAG non-existent-dag not found" in response.text
+
+
+class TestPatchTaskInstance(TestTaskInstanceEndpoint):
+    ENDPOINT_URL = (
+        
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
+    )
+
+    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+    def test_should_call_mocked_api(self, mock_set_task_instance_state, 
test_client, session):
+        self.create_task_instances(session)
+
+        NEW_STATE = "failed"
+        mock_set_task_instance_state.return_value = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == "example_python_operator",
+                TaskInstance.task_id == "print_the_context",
+                TaskInstance.run_id == "TEST_DAG_RUN_ID",
+                TaskInstance.map_index == -1,
+            )
+        ).one_or_none()
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "dry_run": False,
+                "new_state": NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+        assert response.json() == {
+            "dag_id": "example_python_operator",
+            "dag_run_id": "TEST_DAG_RUN_ID",
+            "logical_date": "2020-01-01T00:00:00Z",
+            "task_id": "print_the_context",
+            "duration": 10000.0,
+            "end_date": "2020-01-03T00:00:00Z",
+            "executor": None,
+            "executor_config": "{}",
+            "hostname": "",
+            "id": mock.ANY,
+            "map_index": -1,
+            "max_tries": 0,
+            "note": "placeholder-note",
+            "operator": "PythonOperator",
+            "pid": 100,
+            "pool": "default_pool",
+            "pool_slots": 1,
+            "priority_weight": 9,
+            "queue": "default_queue",
+            "queued_when": None,
+            "start_date": "2020-01-02T00:00:00Z",
+            "state": "running",
+            "task_display_name": "print_the_context",
+            "try_number": 0,
+            "unixname": getuser(),
+            "rendered_fields": {},
+            "rendered_map_index": None,
+            "trigger": None,
+            "triggerer_job": None,
+        }
+
+        mock_set_task_instance_state.assert_called_once()
+
+    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+    def test_should_not_call_mocked_api_for_dry_run(self, 
mock_set_task_instance_state, test_client, session):
+        self.create_task_instances(session)
+
+        NEW_STATE = "failed"
+        mock_set_task_instance_state.return_value = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == "example_python_operator",
+                TaskInstance.task_id == "print_the_context",
+                TaskInstance.run_id == "TEST_DAG_RUN_ID",
+                TaskInstance.map_index == -1,
+            )
+        ).one_or_none()
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "dry_run": True,
+                "new_state": NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+        assert response.json() == {
+            "dag_id": "example_python_operator",
+            "dag_run_id": "TEST_DAG_RUN_ID",
+            "logical_date": "2020-01-01T00:00:00Z",
+            "task_id": "print_the_context",
+            "duration": 10000.0,
+            "end_date": "2020-01-03T00:00:00Z",
+            "executor": None,
+            "executor_config": "{}",
+            "hostname": "",
+            "id": mock.ANY,
+            "map_index": -1,
+            "max_tries": 0,
+            "note": "placeholder-note",
+            "operator": "PythonOperator",
+            "pid": 100,
+            "pool": "default_pool",
+            "pool_slots": 1,
+            "priority_weight": 9,
+            "queue": "default_queue",
+            "queued_when": None,
+            "start_date": "2020-01-02T00:00:00Z",
+            "state": "running",
+            "task_display_name": "print_the_context",
+            "try_number": 0,
+            "unixname": getuser(),
+            "rendered_fields": {},
+            "rendered_map_index": None,
+            "trigger": None,
+            "triggerer_job": None,
+        }
+
+        mock_set_task_instance_state.assert_not_called()
+
+    def test_should_update_task_instance_state(self, test_client, session):
+        self.create_task_instances(session)
+
+        NEW_STATE = "failed"
+
+        test_client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "dry_run": False,
+                "new_state": NEW_STATE,
+            },
+        )
+
+        response2 = test_client.get(self.ENDPOINT_URL)
+        assert response2.status_code == 200
+        assert response2.json()["state"] == NEW_STATE
+
+    def test_should_update_task_instance_state_default_dry_run_to_true(self, 
test_client, session):
+        self.create_task_instances(session)
+
+        NEW_STATE = "running"
+
+        test_client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "new_state": NEW_STATE,
+            },
+        )
+
+        response2 = test_client.get(self.ENDPOINT_URL)
+        assert response2.status_code == 200
+        assert response2.json()["state"] == NEW_STATE
+
+    def test_should_update_mapped_task_instance_state(self, test_client, 
session):
+        NEW_STATE = "failed"
+        map_index = 1
+        tis = self.create_task_instances(session)
+        ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id, 
map_index=map_index)
+        ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
+        session.add(ti)
+        session.commit()
+
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/{map_index}",
+            json={
+                "dry_run": False,
+                "new_state": NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+
+        response2 = test_client.get(f"{self.ENDPOINT_URL}/{map_index}")
+        assert response2.status_code == 200
+        assert response2.json()["state"] == NEW_STATE
+
+    @pytest.mark.parametrize(
+        "error, code, payload",
+        [
+            [
+                (
+                    "Task Instance not found for 
dag_id=example_python_operator"
+                    ", run_id=TEST_DAG_RUN_ID, task_id=print_the_context"
+                ),
+                404,
+                {
+                    "dry_run": True,
+                    "new_state": "failed",
+                },
+            ]
+        ],
+    )
+    def test_should_handle_errors(self, error, code, payload, test_client, 
session):
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json=payload,
+        )
+        assert response.status_code == code
+        assert response.json()["detail"] == error
+
+    def test_should_200_for_unknown_fields(self, test_client, session):
+        self.create_task_instances(session)
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "dryrun": True,
+                "new_state": "failed",
+            },
+        )
+        assert response.status_code == 200
+
+    def test_should_raise_404_for_non_existent_dag(self, test_client):
+        response = test_client.patch(
+            
"/public/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
+            json={
+                "dry_run": False,
+                "new_state": "failed",
+            },
+        )
+        assert response.status_code == 404
+        assert response.json() == {"detail": "DAG non-existent-dag not found"}
+
+    def test_should_raise_404_for_non_existent_task_in_dag(self, test_client):
+        response = test_client.patch(
+            
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task",
+            json={
+                "dry_run": False,
+                "new_state": "failed",
+            },
+        )
+        assert response.status_code == 404
+        assert response.json() == {
+            "detail": "Task 'non_existent_task' not found in DAG 
'example_python_operator'"
+        }
+
+    def test_should_raise_404_not_found_dag(self, test_client):
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "dry_run": True,
+                "new_state": "failed",
+            },
+        )
+        assert response.status_code == 404
+
+    def test_should_raise_404_not_found_task(self, test_client):
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "dry_run": True,
+                "new_state": "failed",
+            },
+        )
+        assert response.status_code == 404
+
+    @pytest.mark.parametrize(
+        "payload, expected",
+        [
+            (
+                {
+                    "dry_run": True,
+                    "new_state": "failede",
+                },
+                f"'failede' is not one of ['{State.SUCCESS}', 
'{State.FAILED}', '{State.SKIPPED}']",
+            ),
+            (
+                {
+                    "dry_run": True,
+                    "new_state": "queued",
+                },
+                f"'queued' is not one of ['{State.SUCCESS}', '{State.FAILED}', 
'{State.SKIPPED}']",
+            ),
+        ],
+    )
+    def test_should_raise_422_for_invalid_task_instance_state(self, payload, 
expected, test_client, session):
+        self.create_task_instances(session)
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json=payload,
+        )
+        assert response.status_code == 422
+        assert response.json() == {
+            "detail": [
+                {
+                    "type": "value_error",
+                    "loc": ["body", "new_state"],
+                    "msg": f"Value error, {expected}",
+                    "input": payload["new_state"],
+                    "ctx": {"error": {}},
+                }
+            ]
+        }
+
+    @pytest.mark.parametrize(
+        "new_state,expected_status_code,expected_json,set_ti_state_call_count",
+        [
+            (
+                "failed",
+                200,
+                {
+                    "dag_id": "example_python_operator",
+                    "dag_run_id": "TEST_DAG_RUN_ID",
+                    "logical_date": "2020-01-01T00:00:00Z",
+                    "task_id": "print_the_context",
+                    "duration": 10000.0,
+                    "end_date": "2020-01-03T00:00:00Z",
+                    "executor": None,
+                    "executor_config": "{}",
+                    "hostname": "",
+                    "id": mock.ANY,
+                    "map_index": -1,
+                    "max_tries": 0,
+                    "note": "placeholder-note",
+                    "operator": "PythonOperator",
+                    "pid": 100,
+                    "pool": "default_pool",
+                    "pool_slots": 1,
+                    "priority_weight": 9,
+                    "queue": "default_queue",
+                    "queued_when": None,
+                    "start_date": "2020-01-02T00:00:00Z",
+                    "state": "running",
+                    "task_display_name": "print_the_context",
+                    "try_number": 0,
+                    "unixname": getuser(),
+                    "rendered_fields": {},
+                    "rendered_map_index": None,
+                    "trigger": None,
+                    "triggerer_job": None,
+                },
+                1,
+            ),
+            (
+                None,
+                422,
+                {
+                    "detail": [
+                        {
+                            "type": "value_error",
+                            "loc": ["body", "new_state"],
+                            "msg": "Value error, 'new_state' should not be 
empty",
+                            "input": None,
+                            "ctx": {"error": {}},
+                        }
+                    ]
+                },
+                0,
+            ),
+        ],
+    )
+    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+    def test_update_mask_should_call_mocked_api(
+        self,
+        mock_set_ti_state,
+        test_client,
+        session,
+        new_state,
+        expected_status_code,
+        expected_json,
+        set_ti_state_call_count,
+    ):
+        self.create_task_instances(session)
+
+        mock_set_ti_state.return_value = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == "example_python_operator",
+                TaskInstance.task_id == "print_the_context",
+                TaskInstance.run_id == "TEST_DAG_RUN_ID",
+                TaskInstance.map_index == -1,
+            )
+        ).one_or_none()
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            params={"update_mask": "new_state"},
+            json={
+                "dry_run": False,
+                "new_state": new_state,
+            },
+        )
+        assert response.status_code == expected_status_code
+        assert response.json() == expected_json
+        assert mock_set_ti_state.call_count == set_ti_state_call_count
+
+    @pytest.mark.parametrize(
+        "new_note_value",
+        [
+            "My super cool TaskInstance note.",
+            None,
+        ],
+    )
+    def test_update_mask_set_note_should_respond_200(self, test_client, 
session, new_note_value):
+        self.create_task_instances(session)
+        response = test_client.patch(
+            
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
+            params={"update_mask": "note"},
+            json={"note": new_note_value},
+        )
+        assert response.status_code == 200, response.text
+        assert response.json() == {
+            "dag_id": "example_python_operator",
+            "duration": 10000.0,
+            "end_date": "2020-01-03T00:00:00Z",
+            "logical_date": "2020-01-01T00:00:00Z",
+            "id": mock.ANY,
+            "executor": None,
+            "executor_config": "{}",
+            "hostname": "",
+            "map_index": -1,
+            "max_tries": 0,
+            "note": new_note_value,
+            "operator": "PythonOperator",
+            "pid": 100,
+            "pool": "default_pool",
+            "pool_slots": 1,
+            "priority_weight": 9,
+            "queue": "default_queue",
+            "queued_when": None,
+            "start_date": "2020-01-02T00:00:00Z",
+            "state": "running",
+            "task_id": "print_the_context",
+            "task_display_name": "print_the_context",
+            "try_number": 0,
+            "unixname": getuser(),
+            "dag_run_id": "TEST_DAG_RUN_ID",
+            "rendered_fields": {},
+            "rendered_map_index": None,
+            "trigger": None,
+            "triggerer_job": None,
+        }
+
+    def test_set_note_should_respond_200(self, test_client, session):
+        self.create_task_instances(session)
+        new_note_value = "My super cool TaskInstance note."
+        response = test_client.patch(
+            
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
+            json={"note": new_note_value},
+        )
+        assert response.status_code == 200, response.text
+        assert response.json() == {
+            "dag_id": "example_python_operator",
+            "duration": 10000.0,
+            "end_date": "2020-01-03T00:00:00Z",
+            "logical_date": "2020-01-01T00:00:00Z",
+            "id": mock.ANY,
+            "executor": None,
+            "executor_config": "{}",
+            "hostname": "",
+            "map_index": -1,
+            "max_tries": 0,
+            "note": new_note_value,
+            "operator": "PythonOperator",
+            "pid": 100,
+            "pool": "default_pool",
+            "pool_slots": 1,
+            "priority_weight": 9,
+            "queue": "default_queue",
+            "queued_when": None,
+            "start_date": "2020-01-02T00:00:00Z",
+            "state": "running",
+            "task_id": "print_the_context",
+            "task_display_name": "print_the_context",
+            "try_number": 0,
+            "unixname": getuser(),
+            "dag_run_id": "TEST_DAG_RUN_ID",
+            "rendered_fields": {},
+            "rendered_map_index": None,
+            "trigger": None,
+            "triggerer_job": None,
+        }
+        # @TODO: Uncomment the 2 lines below when permissions and auth is in 
place.
+        # ti = session.scalars(select(TaskInstance).where(TaskInstance.task_id 
== "print_the_context")).one()
+        # assert ti.task_instance_note.user_id is not None

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to