This is an automated email from the ASF dual-hosted git repository.
joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d4fe325f84 Adding params. to create_auto_ml_forecasting_training_job in AutoMl hook (#39767)
d4fe325f84 is described below
commit d4fe325f8489aa19858b68ea42b71d99e80410a4
Author: Siddesh M G <[email protected]>
AuthorDate: Sun May 26 19:08:18 2024 +0530
Adding params. to create_auto_ml_forecasting_training_job in AutoMl hook (#39767)
* Update auto_ml.py
Added window_stride_length & window_max_count
* Update auto_ml.py
* Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
Co-authored-by: Shahar Epstein <[email protected]>
* Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
Co-authored-by: Shahar Epstein <[email protected]>
* Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
Co-authored-by: Shahar Epstein <[email protected]>
* Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
Co-authored-by: Shahar Epstein <[email protected]>
* Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
Co-authored-by: Shahar Epstein <[email protected]>
* Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
Co-authored-by: Shahar Epstein <[email protected]>
* Update test_vertex_ai.py
* Update test_vertex_ai.py
* Update auto_ml.py
* Update airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
Co-authored-by: Josh Fell <[email protected]>
* Update airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
Co-authored-by: Andrey Anshin <[email protected]>
* Update auto_ml.py
* Update test_vertex_ai.py
---------
Co-authored-by: Shahar Epstein <[email protected]>
Co-authored-by: Josh Fell <[email protected]>
Co-authored-by: Andrey Anshin <[email protected]>
---
airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py | 8 ++++++++
airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py | 6 ++++++
tests/providers/google/cloud/operators/test_vertex_ai.py | 4 ++++
3 files changed, 18 insertions(+)
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
index 5c6d56529d..b1ad7d1a07 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
@@ -551,6 +551,8 @@ class AutoMLHook(GoogleBaseHook):
is_default_version: bool | None = None,
model_version_aliases: list[str] | None = None,
model_version_description: str | None = None,
+ window_stride_length: int | None = None,
+ window_max_count: int | None = None,
) -> tuple[models.Model | None, str]:
"""
Create an AutoML Forecasting Training Job.
@@ -703,6 +705,10 @@ class AutoMLHook(GoogleBaseHook):
:param sync: Whether to execute this method synchronously. If False, this method will be executed in
concurrent Future and any downstream object will be immediately returned and synced when the
Future has completed.
+ :param window_stride_length: Optional. Step length used to generate input examples. Every
+ ``window_stride_length`` rows will be used to generate a sliding window.
+ :param window_max_count: Optional. Number of rows that should be used to generate input examples. If the
+ total row count is larger than this number, the input data will be randomly sampled to hit the count.
"""
if column_transformations:
warnings.warn(
@@ -758,6 +764,8 @@ class AutoMLHook(GoogleBaseHook):
is_default_version=is_default_version,
model_version_aliases=model_version_aliases,
model_version_description=model_version_description,
+ window_stride_length=window_stride_length,
+ window_max_count=window_max_count,
)
training_id = self.extract_training_id(self._job.resource_name)
if model:
diff --git a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
index 1475232012..7e3d8bb083 100644
--- a/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
+++ b/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -138,6 +138,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
region: str,
impersonation_chain: str | Sequence[str] | None = None,
parent_model: str | None = None,
+ window_stride_length: int | None = None,
+ window_max_count: int | None = None,
**kwargs,
) -> None:
super().__init__(
@@ -170,6 +172,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
self.quantiles = quantiles
self.validation_options = validation_options
self.budget_milli_node_hours = budget_milli_node_hours
+ self.window_stride_length = window_stride_length
+ self.window_max_count = window_max_count
def execute(self, context: Context):
self.hook = AutoMLHook(
@@ -220,6 +224,8 @@ class CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
model_display_name=self.model_display_name,
model_labels=self.model_labels,
sync=self.sync,
+ window_stride_length=self.window_stride_length,
+ window_max_count=self.window_max_count,
)
if model:
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index 3f8649f588..4b8264d615 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -1340,6 +1340,8 @@ class TestVertexAICreateAutoMLForecastingTrainingJobOperator:
is_default_version=None,
model_version_aliases=None,
model_version_description=None,
+ window_stride_length=None,
+ window_max_count=None,
)
@mock.patch("google.cloud.aiplatform.datasets.TimeSeriesDataset")
@@ -1405,6 +1407,8 @@ class TestVertexAICreateAutoMLForecastingTrainingJobOperator:
is_default_version=None,
model_version_aliases=None,
model_version_description=None,
+ window_stride_length=None,
+ window_max_count=None,
)