kaxil commented on code in PR #63501:
URL: https://github.com/apache/airflow/pull/63501#discussion_r2934042352
##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -66,6 +68,13 @@
"required": ["sql"],
}
+# DataFusion python bindings don't expose any native exception types, it uses rust exceptions.
+# So we have to rely on error message parsing with regex.
+_RETRYABLE_QUERY_ERROR_PATTERNS = (
+ re.compile(r"""column\s+['"][^'"]+['"]\s+not\s+found""", re.IGNORECASE),
Review Comment:
The regex patterns require the identifier to be quoted (`'name'` or
`"name"`). If a future DataFusion version changes its error format to unquoted
identifiers (e.g., `column name not found`), the pattern won't match, no retry
will be triggered, and the error will silently be returned as a JSON response
instead. Since DataFusion is a Rust library whose error messages aren't part of
a stable API, you might want to also match the unquoted form. Something like
`column\s+(?:['"][^'"]+['"]|\w+)\s+not\s+found` would cover both.
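A minimal sketch of how the combined pattern behaves. The sample messages are made up for illustration and are not taken from DataFusion's actual output; `COLUMN_NOT_FOUND` is a hypothetical name.

```python
import re

# Pattern suggested above: accept either a quoted or a bare identifier.
COLUMN_NOT_FOUND = re.compile(
    r"""column\s+(?:['"][^'"]+['"]|\w+)\s+not\s+found""",
    re.IGNORECASE,
)

# Hypothetical error strings (assumed shapes, real wording may differ),
# mapped to whether the pattern is expected to match.
samples = {
    "Schema error: column 'name' not found": True,
    'Schema error: column "name" not found': True,
    "Schema error: column name not found": True,
    "Execution failed: out of memory": False,
}

results = {msg: bool(COLUMN_NOT_FOUND.search(msg)) for msg in samples}
for msg, expected in samples.items():
    assert results[msg] is expected, msg
```

Note that `\w+` does not match dotted names such as `trips.name`; whether that matters depends on what the messages actually look like.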
##########
providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py:
##########
@@ -285,9 +288,75 @@ def test_query_execution_exception_returns_error_json(self):
)
)
data = json.loads(result)
- assert "column x not found" in data["error"]
+ assert "execution failed" in data["error"]
assert data["query"] == "SELECT x FROM t"
+ def test_column_not_found_raises_model_retry(self):
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+ f.write("pickup_date,id\n2024-01-01,1\n")
+ csv_path = f.name
Review Comment:
nit: Consider using pytest's `tmp_path` fixture instead of manual
`tempfile.NamedTemporaryFile` + `os.unlink` in a try/finally. `tmp_path`
handles cleanup automatically and is the standard pytest pattern.
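A rough sketch of the `tmp_path` shape, kept free of the provider imports so it runs on its own. `write_trips_csv` and the test name are hypothetical; the real test would pass the resulting URI to `DataSourceConfig` as in the diff.

```python
from pathlib import Path


def write_trips_csv(tmp_path: Path) -> Path:
    """Hypothetical helper: write the sample CSV under pytest's tmp_path."""
    csv_path = tmp_path / "trips.csv"
    csv_path.write_text("pickup_date,id\n2024-01-01,1\n")
    return csv_path


def test_csv_written_under_tmp_path(tmp_path):
    # pytest injects tmp_path as a pathlib.Path unique to this test and
    # manages the directory itself, so no try/finally or os.unlink is needed.
    csv_path = write_trips_csv(tmp_path)
    assert csv_path.read_text().splitlines()[0] == "pickup_date,id"
    # Same URI shape the test in the diff builds from csv_path.
    uri = f"file://{csv_path}"
    assert uri.endswith("trips.csv")
```

With this shape the `delete=False` flag and the `finally: os.unlink(csv_path)` block both go away.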
##########
providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py:
##########
@@ -285,9 +288,75 @@ def test_query_execution_exception_returns_error_json(self):
)
)
data = json.loads(result)
- assert "column x not found" in data["error"]
+ assert "execution failed" in data["error"]
assert data["query"] == "SELECT x FROM t"
+ def test_column_not_found_raises_model_retry(self):
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+ f.write("pickup_date,id\n2024-01-01,1\n")
+ csv_path = f.name
+
+ try:
+ ts = DataFusionToolset(
+ [
+ DataSourceConfig(
+ conn_id="",
+ table_name="trips",
+ uri=f"file://{csv_path}",
+ format="csv",
+ )
+ ]
+ )
+
+ with pytest.raises(ModelRetry) as exc_info:
+ asyncio.run(
+ ts.call_tool(
+ "query",
+ {"sql": "SELECT name FROM trips"},
+ ctx=MagicMock(spec=RunContext),
+ tool=MagicMock(spec=ToolsetTool),
+ )
+ )
+
+ assert "column 'name' not found" in exc_info.value.message
+ assert "get_schema" in exc_info.value.message
+ assert "list_tables" in exc_info.value.message
+ finally:
+ os.unlink(csv_path)
+
+ def test_table_not_found_raises_model_retry(self):
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
Review Comment:
nit: `test_table_not_found_raises_model_retry` is nearly identical to
`test_column_not_found_raises_model_retry`. You could collapse these into a
single `@pytest.mark.parametrize` test with `(sql, expected_substring)` pairs.
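One possible shape for the parametrized version, as a sketch only. `FakeRetry` and `run_query` are hypothetical stand-ins for `ModelRetry` and the `ts.call_tool("query", ...)` call so the example runs without the provider installed, and the table-not-found message is a placeholder, not the real DataFusion wording.

```python
import pytest


class FakeRetry(Exception):
    """Stand-in for ModelRetry (assumption, for illustration only)."""


def run_query(sql: str) -> None:
    """Hypothetical stand-in for the toolset call; always raises FakeRetry."""
    if "missing_table" in sql:
        raise FakeRetry("table 'missing_table' not found")
    raise FakeRetry("column 'name' not found")


@pytest.mark.parametrize(
    ("sql", "expected_substring"),
    [
        ("SELECT name FROM trips", "column 'name' not found"),
        ("SELECT id FROM missing_table", "table 'missing_table' not found"),
    ],
)
def test_missing_object_raises_retry(sql, expected_substring):
    # One test body covers both cases; only the inputs differ.
    with pytest.raises(FakeRetry) as exc_info:
        run_query(sql)
    assert expected_substring in str(exc_info.value)
```

The shared CSV setup could then live in the single test body (or a fixture) instead of being duplicated across two tests.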