ephraimbuddy commented on code in PR #26165:
URL: https://github.com/apache/airflow/pull/26165#discussion_r1001845249
##########
tests/api_connexion/endpoints/test_task_instance_endpoint.py:
##########
@@ -1582,3 +1582,226 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se
)
assert response.status_code == 400
assert response.json['detail'] == expected
+
+
+class TestPatchTaskInstance(TestTaskInstanceEndpoint):
+ ENDPOINT_URL = (
+ "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
+ )
+
+ @mock.patch('airflow.models.dag.DAG.set_task_instance_state')
+ def test_should_call_mocked_api(self, mock_set_task_instance_state, session):
+ self.create_task_instances(session)
+
+ NEW_STATE = "failed"
+ mock_set_task_instance_state.return_value = session.query(TaskInstance).get(
+ {
+ "task_id": "print_the_context",
+ "dag_id": "example_python_operator",
+ "run_id": "TEST_DAG_RUN_ID",
+ "map_index": -1,
+ }
+ )
+ response = self.client.patch(
+ self.ENDPOINT_URL,
+ environ_overrides={'REMOTE_USER': "test"},
+ json={
+ "dry_run": False,
+ "new_state": NEW_STATE,
+ },
+ )
+ assert response.status_code == 200
+ assert response.json == {
+ 'dag_id': 'example_python_operator',
+ 'dag_run_id': 'TEST_DAG_RUN_ID',
+ 'execution_date': '2020-01-01T00:00:00+00:00',
+ 'task_id': 'print_the_context',
+ }
+
+ mock_set_task_instance_state.assert_called_once_with(
+ task_id="print_the_context",
+ run_id="TEST_DAG_RUN_ID",
+ map_indexes=[-1],
+ state=NEW_STATE,
+ commit=True,
+ session=session,
+ )
+
+ @mock.patch('airflow.models.dag.DAG.set_task_instance_state')
+ def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_state, session):
+ self.create_task_instances(session)
+
+ NEW_STATE = "failed"
+ mock_set_task_instance_state.return_value = session.query(TaskInstance).get(
+ {
+ "task_id": "print_the_context",
+ "dag_id": "example_python_operator",
+ "run_id": "TEST_DAG_RUN_ID",
+ "map_index": -1,
+ }
+ )
+ response = self.client.patch(
+ self.ENDPOINT_URL,
+ environ_overrides={'REMOTE_USER': "test"},
+ json={
+ "dry_run": True,
+ "new_state": NEW_STATE,
+ },
+ )
+ assert response.status_code == 200
+ print(response.status_code)
+ assert response.json == {
+ 'dag_id': 'example_python_operator',
+ 'dag_run_id': 'TEST_DAG_RUN_ID',
+ 'execution_date': '2020-01-01T00:00:00+00:00',
+ 'task_id': 'print_the_context',
+ }
+
+ mock_set_task_instance_state.assert_not_called()
+
+ def test_should_update_task_instance_state(self, session):
+ self.create_task_instances(session)
+
+ NEW_STATE = "failed"
+
+ self.client.patch(
+ self.ENDPOINT_URL,
+ environ_overrides={'REMOTE_USER': "test"},
+ json={
+ "dry_run": False,
+ "new_state": NEW_STATE,
+ },
+ )
+
+ response2 = self.client.get(
+ self.ENDPOINT_URL,
+ environ_overrides={'REMOTE_USER': "test"},
+ json={},
+ )
+ assert response2.status_code == 200
+ assert response2.json["state"] == NEW_STATE
+
+ def test_should_update_mapped_task_instance_state(self, session):
+ tis = self.create_task_instances(session)
+ session.query()
Review Comment:
Why do we make this empty `session.query()` call? It takes no arguments and its result is never used.
##########
tests/api_connexion/endpoints/test_task_instance_endpoint.py:
##########
@@ -1582,3 +1582,226 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se
)
assert response.status_code == 400
assert response.json['detail'] == expected
+
+
+class TestPatchTaskInstance(TestTaskInstanceEndpoint):
+ ENDPOINT_URL = (
+ "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
+ )
+
+ @mock.patch('airflow.models.dag.DAG.set_task_instance_state')
+ def test_should_call_mocked_api(self, mock_set_task_instance_state, session):
+ self.create_task_instances(session)
+
+ NEW_STATE = "failed"
+ mock_set_task_instance_state.return_value = session.query(TaskInstance).get(
+ {
+ "task_id": "print_the_context",
+ "dag_id": "example_python_operator",
+ "run_id": "TEST_DAG_RUN_ID",
+ "map_index": -1,
+ }
+ )
+ response = self.client.patch(
+ self.ENDPOINT_URL,
+ environ_overrides={'REMOTE_USER': "test"},
+ json={
+ "dry_run": False,
+ "new_state": NEW_STATE,
+ },
+ )
+ assert response.status_code == 200
+ assert response.json == {
+ 'dag_id': 'example_python_operator',
+ 'dag_run_id': 'TEST_DAG_RUN_ID',
+ 'execution_date': '2020-01-01T00:00:00+00:00',
+ 'task_id': 'print_the_context',
+ }
+
+ mock_set_task_instance_state.assert_called_once_with(
+ task_id="print_the_context",
+ run_id="TEST_DAG_RUN_ID",
+ map_indexes=[-1],
+ state=NEW_STATE,
+ commit=True,
+ session=session,
+ )
+
+ @mock.patch('airflow.models.dag.DAG.set_task_instance_state')
+ def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_state, session):
+ self.create_task_instances(session)
+
+ NEW_STATE = "failed"
+ mock_set_task_instance_state.return_value = session.query(TaskInstance).get(
+ {
+ "task_id": "print_the_context",
+ "dag_id": "example_python_operator",
+ "run_id": "TEST_DAG_RUN_ID",
+ "map_index": -1,
+ }
+ )
+ response = self.client.patch(
+ self.ENDPOINT_URL,
+ environ_overrides={'REMOTE_USER': "test"},
+ json={
+ "dry_run": True,
+ "new_state": NEW_STATE,
+ },
+ )
+ assert response.status_code == 200
+ print(response.status_code)
+ assert response.json == {
+ 'dag_id': 'example_python_operator',
+ 'dag_run_id': 'TEST_DAG_RUN_ID',
+ 'execution_date': '2020-01-01T00:00:00+00:00',
+ 'task_id': 'print_the_context',
+ }
+
+ mock_set_task_instance_state.assert_not_called()
+
+ def test_should_update_task_instance_state(self, session):
+ self.create_task_instances(session)
+
+ NEW_STATE = "failed"
+
+ self.client.patch(
+ self.ENDPOINT_URL,
+ environ_overrides={'REMOTE_USER': "test"},
+ json={
+ "dry_run": False,
+ "new_state": NEW_STATE,
+ },
+ )
+
+ response2 = self.client.get(
+ self.ENDPOINT_URL,
+ environ_overrides={'REMOTE_USER': "test"},
+ json={},
+ )
+ assert response2.status_code == 200
+ assert response2.json["state"] == NEW_STATE
+
+ def test_should_update_mapped_task_instance_state(self, session):
+ tis = self.create_task_instances(session)
+ session.query()
+ ti = tis[0]
+ ti.map_index = 1
+ rendered_fields = RTIF(ti, render_templates=False)
+ session.add(rendered_fields)
+ session.commit()
+
+ NEW_STATE = "failed"
+
+ self.client.patch(
+ self.ENDPOINT_URL,
+ environ_overrides={'REMOTE_USER': "test"},
+ json={
+ "dry_run": False,
+ "map_index": 1,
+ "new_state": NEW_STATE,
+ },
+ )
+
+ response2 = self.client.get(
+ f"{self.ENDPOINT_URL}/1",
+ environ_overrides={'REMOTE_USER': "test"},
+ json={},
+ )
Review Comment:
Is this second request necessary? If it is, what does the `/1` suffix on the URL mean?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]