uranusjr commented on a change in pull request #17324:
URL: https://github.com/apache/airflow/pull/17324#discussion_r679648318



##########
File path: tests/conftest.py
##########
@@ -444,33 +445,44 @@ def __exit__(self, type, value, traceback):
             dag.__exit__(type, value, traceback)
             if type is None:
                 dag.clear()
-                self.dag_run = dag.create_dagrun(
-                    run_id=self.kwargs.get("run_id", "test"),
-                    state=self.kwargs.get('state', State.RUNNING),
-                    execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
-                    start_date=self.kwargs['start_date'],
-                )
+
+        def _set_default_args(self, kwargs, defaults):
+            for k, v in defaults.items():
+                if k not in kwargs:
+                    kwargs.setdefault(k, v)
+
+        @provide_session
+        def make_dagmodel(self, session=None, **kwargs):
+            dag = self.dag
+            defaults = dict(dag_id=dag.dag_id, next_dagrun=dag.start_date, is_active=True)
+            self._set_default_args(kwargs, defaults)
+            dag_model = DagModel(**kwargs)
+            session.add(dag_model)
+            session.flush()
+            return dag_model
+
+        def create_dagrun(self, **kwargs):
+            dag = self.dag
+            defaults = dict(
+                run_id='test',
+                state=State.RUNNING,
+                execution_date=self.start_date,
+                start_date=self.start_date,
+            )
+            self._set_default_args(kwargs, defaults)
+            self.dag_run = dag.create_dagrun(**kwargs)
+            return self.dag_run
 
         def __call__(self, dag_id='test_dag', **kwargs):
             self.kwargs = kwargs
-            if "start_date" not in kwargs:
+            self.start_date = self.kwargs.get('start_date', None)
+            if not self.start_date:
                 if hasattr(request.module, 'DEFAULT_DATE'):
-                    kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE')
+                    self.start_date = getattr(request.module, 'DEFAULT_DATE')
                 else:
-                    kwargs['start_date'] = DEFAULT_DATE
-            dagrun_fields_not_in_dag = [
-                'state',
-                'execution_date',
-                'run_type',
-                'queued_at',
-                "run_id",
-                "creating_job_id",
-                "external_trigger",
-                "last_scheduling_decision",
-                "dag_hash",
-            ]
-            kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag}
-            self.dag = DAG(dag_id, **kwargs)
+                    self.start_date = DEFAULT_DATE
+            self.kwargs.update(dict(start_date=self.start_date))

Review comment:
   ```suggestion
               self.kwargs['start_date'] = self.start_date
   ```
   
   But if `start_date` is unconditionally overridden, maybe we should instead…

##########
File path: tests/conftest.py
##########
@@ -444,33 +445,44 @@ def __exit__(self, type, value, traceback):
             dag.__exit__(type, value, traceback)
             if type is None:
                 dag.clear()
-                self.dag_run = dag.create_dagrun(
-                    run_id=self.kwargs.get("run_id", "test"),
-                    state=self.kwargs.get('state', State.RUNNING),
-                    execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
-                    start_date=self.kwargs['start_date'],
-                )
+
+        def _set_default_args(self, kwargs, defaults):
+            for k, v in defaults.items():
+                if k not in kwargs:
+                    kwargs.setdefault(k, v)
+
+        @provide_session
+        def make_dagmodel(self, session=None, **kwargs):
+            dag = self.dag
+            defaults = dict(dag_id=dag.dag_id, next_dagrun=dag.start_date, is_active=True)
+            self._set_default_args(kwargs, defaults)

Review comment:
       I think this could instead be
   
   ```suggestion
               kwargs = {**defaults, **kwargs}
   ```
   
   [PEP 448](https://www.python.org/dev/peps/pep-0448/#specification):
   
   > The keys in a dictionary remain in a right-to-left priority order, so 
`{**{'a': 1}, 'a': 2, **{'a': 3}}` evaluates to `{'a': 3}`. There is no 
restriction on the number or position of unpackings.

##########
File path: tests/conftest.py
##########
@@ -444,33 +445,44 @@ def __exit__(self, type, value, traceback):
             dag.__exit__(type, value, traceback)
             if type is None:
                 dag.clear()
-                self.dag_run = dag.create_dagrun(
-                    run_id=self.kwargs.get("run_id", "test"),
-                    state=self.kwargs.get('state', State.RUNNING),
-                    execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
-                    start_date=self.kwargs['start_date'],
-                )
+
+        def _set_default_args(self, kwargs, defaults):
+            for k, v in defaults.items():
+                if k not in kwargs:
+                    kwargs.setdefault(k, v)
+
+        @provide_session
+        def make_dagmodel(self, session=None, **kwargs):
+            dag = self.dag
+            defaults = dict(dag_id=dag.dag_id, next_dagrun=dag.start_date, is_active=True)
+            self._set_default_args(kwargs, defaults)
+            dag_model = DagModel(**kwargs)
+            session.add(dag_model)
+            session.flush()
+            return dag_model
+
+        def create_dagrun(self, **kwargs):
+            dag = self.dag
+            defaults = dict(
+                run_id='test',
+                state=State.RUNNING,
+                execution_date=self.start_date,
+                start_date=self.start_date,
+            )
+            self._set_default_args(kwargs, defaults)
+            self.dag_run = dag.create_dagrun(**kwargs)
+            return self.dag_run
 
         def __call__(self, dag_id='test_dag', **kwargs):
             self.kwargs = kwargs
-            if "start_date" not in kwargs:
+            self.start_date = self.kwargs.get('start_date', None)
+            if not self.start_date:
                 if hasattr(request.module, 'DEFAULT_DATE'):
-                    kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE')
+                    self.start_date = getattr(request.module, 'DEFAULT_DATE')
                 else:
-                    kwargs['start_date'] = DEFAULT_DATE
-            dagrun_fields_not_in_dag = [
-                'state',
-                'execution_date',
-                'run_type',
-                'queued_at',
-                "run_id",
-                "creating_job_id",
-                "external_trigger",
-                "last_scheduling_decision",
-                "dag_hash",
-            ]
-            kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag}
-            self.dag = DAG(dag_id, **kwargs)
+                    self.start_date = DEFAULT_DATE
+            self.kwargs.update(dict(start_date=self.start_date))
+            self.dag = DAG(dag_id, **self.kwargs)

Review comment:
   ```suggestion
               self.dag = DAG(dag_id, start_date=self.start_date, **self.kwargs)
   ```
   This makes the error clearer if the caller accidentally passes `start_date` 
into `dag_maker`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to