This is an automated email from the ASF dual-hosted git repository.

ebenizzy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hamilton.git


The following commit(s) were added to refs/heads/main by this push:
     new d601a92f Update extract fields (#1305)
d601a92f is described below

commit d601a92fa026b3aed76984d66c8a5c6eb95632f6
Author: Charles Swartz <[email protected]>
AuthorDate: Tue Jul 15 00:26:24 2025 -0400

    Update extract fields (#1305)
---
 docs/concepts/function-modifiers.rst       |  70 +++++-
 hamilton/function_modifiers/expanders.py   | 206 +++++++++++------
 tests/function_modifiers/test_expanders.py | 340 ++++++++++++++++++++++-------
 3 files changed, 468 insertions(+), 148 deletions(-)

diff --git a/docs/concepts/function-modifiers.rst 
b/docs/concepts/function-modifiers.rst
index 63b206eb..6bbed9eb 100644
--- a/docs/concepts/function-modifiers.rst
+++ b/docs/concepts/function-modifiers.rst
@@ -140,7 +140,7 @@ The ``@check_output`` function modifiers are applied on the 
**node output / func
 
 .. note::
 
-    In the future, validatation capabailities may be added to ``@schema``. For 
now, it's only added metadata.
+    In the future, validation capabilities may be added to ``@schema``. For 
now, it's only added metadata.
 
 @check_output*
 ~~~~~~~~~~~~~~
@@ -201,7 +201,7 @@ A good example is splitting a dataset into training, 
validation, and test splits
     from typing import Tuple
     from hamilton.function_modifiers import unpack_fields
 
-    @unpack_fields("X_train" "X_validation", "X_test")
+    @unpack_fields("X_train", "X_validation", "X_test")
     def dataset_splits(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray, 
np.ndarray]:
         """Randomly split data into train, validation, test"""
         X_train, X_validation, X_test = random_split(X)
@@ -216,14 +216,14 @@ Now, ``X_train``, ``X_validation``, and ``X_test`` are 
available to other nodes
 @extract_fields
 ~~~~~~~~~~~~~~~
 
-Additionally, we can extract fields from an output dictionary using 
``@extract_fields``. In this case, you must specify the dictionary keys and 
their types. The function must return a dictionary that contains, at a minimum, 
those keys specified in the decorator.
+Additionally, we can extract fields from an output dictionary using 
``@extract_fields``. The function must return a dictionary that contains, at a 
minimum, those keys specified in the decorator. In this case, you can specify a 
dictionary of fields and their types:
 
 .. code-block:: python
 
     from typing import Dict
     from hamilton.function_modifiers import extract_fields
 
-    @extract_fields(dict(  # don't forget the dictionary
+    @extract_fields(dict(  # fields specified as a dictionary
         X_train=np.ndarray,
         X_validation=np.ndarray,
         X_test=np.ndarray,
@@ -240,6 +240,68 @@ Additionally, we can extract fields from an output 
dictionary using ``@extract_f
 .. image:: ./_function-modifiers/extract_fields.png
     :height: 250px
 
+Or if you are using a generic dictionary, you can specify solely the field 
names.
+
+.. code-block:: python
+
+    from typing import Dict
+    from hamilton.function_modifiers import extract_fields
+
+    @extract_fields("X_train", "X_validation", "X_test")  # field names only
+    def dataset_splits(X: np.ndarray) -> Dict[str, np.ndarray]:  # generic dict
+        """Randomly split data into train, validation, test"""
+        X_train, X_validation, X_test = random_split(X)
+        return dict(
+            X_train=X_train,
+            X_validation=X_validation,
+            X_test=X_test,
+        )
+
+If you are using a `TypedDict`, you can specify just the field names.
+
+.. code-block:: python
+
+    from typing import TypedDict
+    from hamilton.function_modifiers import extract_fields
+
+    class DatasetSplits(TypedDict):
+        X_train: np.ndarray
+        X_validation: np.ndarray
+        X_test: np.ndarray
+
+    @extract_fields("X_train", "X_validation", "X_test")
+    def dataset_splits(X: np.ndarray) -> DatasetSplits:
+        """Randomly split data into train, validation, test"""
+        X_train, X_validation, X_test = random_split(X)
+        return dict(
+            X_train=X_train,
+            X_validation=X_validation,
+            X_test=X_test,
+        )
+
+
+Or you can leave the field names empty and extract all fields from the 
`TypedDict`.
+
+.. code-block:: python
+
+    from typing import TypedDict
+    from hamilton.function_modifiers import extract_fields
+
+    class DatasetSplits(TypedDict):
+        X_train: np.ndarray
+        X_validation: np.ndarray
+        X_test: np.ndarray
+
+    @extract_fields(DatasetSplits)  # extract all fields from the TypedDict
+    def dataset_splits(X: np.ndarray) -> DatasetSplits:
+        """Randomly split data into train, validation, test"""
+        X_train, X_validation, X_test = random_split(X)
+        return dict(
+            X_train=X_train,
+            X_validation=X_validation,
+            X_test=X_test,
+        )
+
 
 Again, ``X_train``, ``X_validation``, and ``X_test`` are now available to 
other nodes, or you can query the ``dataset_splits`` node to retrieve all 
splits in a dictionary.
 
diff --git a/hamilton/function_modifiers/expanders.py 
b/hamilton/function_modifiers/expanders.py
index d04cf6ee..a74f2aa2 100644
--- a/hamilton/function_modifiers/expanders.py
+++ b/hamilton/function_modifiers/expanders.py
@@ -3,7 +3,7 @@ import dataclasses
 import functools
 import inspect
 import typing
-from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union
+from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, 
Type, Union
 
 import typing_extensions
 import typing_inspect
@@ -699,6 +699,89 @@ class extract_columns(base.SingleNodeNodeTransformer):
         return output_nodes
 
 
+def _determine_fields_to_extract(
+    fields: Optional[Union[Dict[str, Any], List[str]]], output_type: Any
+) -> Dict[str, Any]:
+    """Determines which fields to extract based on user requested fields and 
the output type of
+    the return type of the function.
+
+    :param fields: Dict of fields to extract.
+    :param output_type: The output type of the node function.
+    :return: Dict of field names to types.
+    """
+
+    output_type_error = (
+        f"For extracting fields, the decorated function output type must be a 
`dict` or a "
+        f"`typing.Dict` with or without type parameters (i.e. `dict[str, int]` 
or "
+        f"`typing.Dict[str, int]`), not: {output_type}"
+    )
+
+    if output_type == dict or output_type == Dict:
+        # NOTE: typing_inspect.is_generic_type(typing.Dict) without type 
parameters returns True,
+        #       so we need to address the bare dictionaries first before 
generics.
+        if fields is None or not isinstance(fields, dict):
+            raise base.InvalidDecoratorException(
+                "When extracting fields from a function that returns a bare 
`dict` output without "
+                "type parameters, you must supply a `dict` mapping field names 
to types."
+            )
+    elif typing_inspect.is_generic_type(output_type):
+        base_type = typing_inspect.get_origin(output_type)
+        if base_type != dict and base_type != Dict:
+            raise base.InvalidDecoratorException(output_type_error)
+        if fields is None:
+            raise base.InvalidDecoratorException(
+                "When extracting fields from a function that returns a generic 
`dict`, you must "
+                "supply either a `dict` (`typing.Dict`) mapping field names to 
types or "
+                "alternatively a `list` (`typing.List`) of field names."
+            )
+        output_args = typing_inspect.get_args(output_type)
+        if len(output_args) != 2:
+            raise base.InvalidDecoratorException(
+                f"When extracting fields from a function that returns a 
generic `dict`, you "
+                f"must specify only two type parameters (key, value), not 
{output_args}."
+            )
+        if isinstance(fields, list):
+            fields = {field: output_args[1] for field in fields}  # Infer type 
from annotation
+    elif typing_extensions.is_typeddict(output_type):
+        typed_dict_fields = typing.get_type_hints(output_type)  # Dict of 
field name -> type
+        errors = []
+        if fields is None:
+            fields = typed_dict_fields  # Infer fields and types from 
annotation
+        elif isinstance(fields, list):
+            reduced_fields = {}
+            for field in fields:
+                if field not in typed_dict_fields:
+                    errors.append(f"{field} is not a field in the `TypedDict` 
{output_type}.")
+                reduced_fields[field] = typed_dict_fields[field]
+            fields = reduced_fields
+        elif isinstance(fields, dict):
+            for field_name, field_type in fields.items():
+                expected_type = typed_dict_fields.get(field_name, None)
+                if expected_type is None:
+                    errors.append(f"{field_name} is not a field in the 
`TypedDict` {output_type}.")
+                    continue
+                elif expected_type == field_type or 
htypes.custom_subclass_check(
+                    field_type, expected_type
+                ):
+                    continue
+                errors.append(
+                    f"Error {field_name} did not match the TypedDict 
annotation's field "
+                    f"{field_type}. Expected {expected_type}."
+                )
+        if errors:
+            raise base.InvalidDecoratorException(
+                f"Error {fields} did not match a subset of the TypedDict 
annotation's fields "
+                f"{typed_dict_fields}. The following fields were not valid: 
{errors}."
+            )
+    else:
+        raise base.InvalidDecoratorException(output_type_error)
+
+    assert isinstance(fields, dict), "Internal error: fields should be a dict 
at this point."
+    _validate_extract_fields(fields)
+
+    return fields
+
+
 def _validate_extract_fields(fields: dict):
     """Validates the fields dict for extract field.
     Rules are:
@@ -739,61 +822,43 @@ def _validate_extract_fields(fields: dict):
 class extract_fields(base.SingleNodeNodeTransformer):
     """Extracts fields from a dictionary of output."""
 
-    def __init__(self, fields: dict = None, fill_with: Any = None):
+    output_type: Any
+    resolved_fields: Dict[str, Type]
+
+    def __init__(
+        self,
+        fields: Optional[Union[Dict[str, Any], List[str], Any]] = None,
+        *others,
+        fill_with: Any = None,
+    ):
         """Constructor for a modifier that expands a single function into the 
following nodes:
 
         - n functions, each of which take in the original dict and output a 
specific field
         - 1 function that outputs the original dict
 
-        :param fields: Fields to extract. A dict of 'field_name' -> 
'field_type'.
+        :param fields: Fields to extract. Can be a dict of field names to 
types, a list of field names, or a single field name.
+        :param others: Additional field names to extract - argument 
unpacking. Ignored if `fields` is a dict.
         :param fill_with: If you want to extract a field that doesn't exist, 
do you want to fill it with a default \
         value? Or do you want to error out? Leave empty/None to error out, set 
fill_value to dynamically create a \
         field value.
         """
         super(extract_fields, self).__init__()
+        if isinstance(fields, list):
+            fields = fields + list(others)
+        elif fields and not isinstance(fields, dict):
+            fields = [fields] + list(others)
         self.fields = fields
         self.fill_with = fill_with
 
     def validate(self, fn: Callable):
-        """A function is invalid if it is not annotated with a dict or 
typing.Dict return type.
+        """A function is invalid if it is not annotated with a dict or 
typing.Dict return type or if the
+        fields to extract are not valid.
 
         :param fn: Function to validate.
         :raises: InvalidDecoratorException If the function is not annotated 
with a dict or typing.Dict type as output.
         """
-        output_type = typing.get_type_hints(fn).get("return")
-        if typing_inspect.is_generic_type(output_type):
-            base_type = typing_inspect.get_origin(output_type)
-            if base_type == dict or base_type == Dict:
-                _validate_extract_fields(self.fields)
-            else:
-                raise base.InvalidDecoratorException(
-                    f"For extracting fields, output type must be a dict or 
typing.Dict, not: {output_type}"
-                )
-        elif output_type == dict:
-            _validate_extract_fields(self.fields)
-        elif typing_extensions.is_typeddict(output_type):
-            if self.fields is None:
-                self.fields = typing.get_type_hints(output_type)
-            else:
-                # check that fields is a subset of TypedDict that is defined
-                typed_dict_fields = typing.get_type_hints(output_type)
-                for field_name, field_type in self.fields.items():
-                    expected_type = typed_dict_fields.get(field_name, None)
-                    if expected_type == field_type:
-                        pass  # we're definitely good
-                    elif expected_type is not None and 
htypes.custom_subclass_check(
-                        field_type, expected_type
-                    ):
-                        pass
-                    else:
-                        raise base.InvalidDecoratorException(
-                            f"Error {self.fields} did not match a subset of 
the TypedDict annotation's fields {typed_dict_fields}."
-                        )
-            _validate_extract_fields(self.fields)
-        else:
-            raise base.InvalidDecoratorException(
-                f"For extracting fields, output type must be a dict or 
typing.Dict, not: {output_type}"
-            )
+        self.output_type = typing.get_type_hints(fn).get("return")
+        self.resolved_fields = _determine_fields_to_extract(self.fields, 
self.output_type)
 
     def transform_node(
         self, node_: node.Node, config: Dict[str, Any], fn: Callable
@@ -813,53 +878,38 @@ class extract_fields(base.SingleNodeNodeTransformer):
         # if fn is async
         if inspect.iscoroutinefunction(fn):
 
-            async def dict_generator(*args, **kwargs):
+            async def dict_generator(*args, **kwargs):  # type: ignore
                 dict_generated = await fn(*args, **kwargs)
                 if self.fill_with is not None:
-                    for field in self.fields:
+                    for field in self.resolved_fields:
                         if field not in dict_generated:
                             dict_generated[field] = self.fill_with
                 return dict_generated
 
         else:
 
-            def dict_generator(*args, **kwargs):
+            def dict_generator(*args, **kwargs):  # type: ignore
                 dict_generated = fn(*args, **kwargs)
                 if self.fill_with is not None:
-                    for field in self.fields:
+                    for field in self.resolved_fields:
                         if field not in dict_generated:
                             dict_generated[field] = self.fill_with
                 return dict_generated
 
         output_nodes = [node_.copy_with(callabl=dict_generator)]
 
-        for field, field_type in self.fields.items():
+        for field, field_type in self.resolved_fields.items():
             doc_string = base_doc  # default doc string of base function.
 
-            # if fn is async
-            if inspect.iscoroutinefunction(fn):
-
-                async def extractor_fn(field_to_extract: str = field, 
**kwargs) -> field_type:
-                    dt = kwargs[node_.name]
-                    if field_to_extract not in dt:
-                        raise base.InvalidDecoratorException(
-                            f"No such field: {field_to_extract} produced by 
{node_.name}. "
-                            f"It only produced {list(dt.keys())}"
-                        )
-                    return kwargs[node_.name][field_to_extract]
-
-            else:
-
-                def extractor_fn(
-                    field_to_extract: str = field, **kwargs
-                ) -> field_type:  # avoiding problems with closures
-                    dt = kwargs[node_.name]
-                    if field_to_extract not in dt:
-                        raise base.InvalidDecoratorException(
-                            f"No such field: {field_to_extract} produced by 
{node_.name}. "
-                            f"It only produced {list(dt.keys())}"
-                        )
-                    return kwargs[node_.name][field_to_extract]
+            # This extractor is constructed to avoid closure issues.
+            def extractor_fn(field_to_extract: str = field, **kwargs) -> 
field_type:  # type: ignore
+                dt = kwargs[node_.name]
+                if field_to_extract not in dt:
+                    raise base.InvalidDecoratorException(
+                        f"No such field: {field_to_extract} produced by 
{node_.name}. "
+                        f"It only produced {list(dt.keys())}"
+                    )
+                return kwargs[node_.name][field_to_extract]
 
             output_nodes.append(
                 node.Node(
@@ -867,15 +917,16 @@ class extract_fields(base.SingleNodeNodeTransformer):
                     field_type,
                     doc_string,
                     extractor_fn,
-                    input_types={node_.name: dict},
+                    input_types={node_.name: self.output_type},
                     tags=node_.tags.copy(),
                 )
             )
         return output_nodes
 
 
-def _process_unpack_fields(fields: List[str], output_type: Any) -> List[Type]:
-    """Processes the fields and base output type to extract a list of field 
types.
+def _determine_fields_to_unpack(fields: List[str], output_type: Any) -> 
List[Type]:
+    """Determines which fields to unpack based on user requested fields and 
the output type of
+    the return type of the function.
 
     :param fields: List of fields to to unpack.
     :param output_type: The output type of the node function.
@@ -957,8 +1008,13 @@ class unpack_fields(base.SingleNodeNodeTransformer):
 
     @override
     def validate(self, fn: Callable):
+        """Validates that the return type of the function is a tuple or 
typing.Tuple.
+
+        :param fn: Function to validate
+        :raises: InvalidDecoratorException If the function does not output a 
tuple or typing.Tuple type.
+        """
         output_type = typing.get_type_hints(fn).get("return")
-        field_types = _process_unpack_fields(self.fields, output_type)
+        field_types = _determine_fields_to_unpack(self.fields, output_type)
         self.field_types = field_types
         self.output_type = output_type
 
@@ -966,6 +1022,14 @@ class unpack_fields(base.SingleNodeNodeTransformer):
     def transform_node(
         self, node_: node.Node, config: Dict[str, Any], fn: Callable
     ) -> Collection[node.Node]:
+        """Unpacks the specified fields from the tuple output into separate 
nodes.
+
+        :param node_: Node to transform
+        :param config: Config to use
+        :param fn: Function to unpack fields from. Must output a tuple.
+        :return: A collection of nodes --
+                one for the original tuple generator, and another for each 
field to unpack.
+        """
         fn = node_.callable
         base_doc = node_.documentation
         base_tags = node_.tags.copy()
diff --git a/tests/function_modifiers/test_expanders.py 
b/tests/function_modifiers/test_expanders.py
index f272cfd5..33292fe4 100644
--- a/tests/function_modifiers/test_expanders.py
+++ b/tests/function_modifiers/test_expanders.py
@@ -333,24 +333,6 @@ class MyDictBad(TypedDict):
     test2: str
 
 
[email protected](
-    "return_type",
-    [
-        dict,
-        Dict,
-        Dict[str, str],
-        Dict[str, Any],
-        MyDict,
-    ],
-)
-def test_extract_fields_validate_happy(return_type):
-    def return_dict() -> return_type:
-        return {}
-
-    annotation = function_modifiers.extract_fields({"test": int})
-    annotation.validate(return_dict)
-
-
 class SomeObject:
     pass
 
@@ -369,95 +351,306 @@ class MyDictInheritanceBadCase(TypedDict):
     test2: str
 
 
-def test_extract_fields_validate_happy_inheritance():
-    def return_dict() -> MyDictInheritance:
-        return {}
-
-    annotation = function_modifiers.extract_fields({"test": InheritedObject})
-    annotation.validate(return_dict)
-
-
-def test_extract_fields_validate_not_subclass():
-    def return_dict() -> MyDictInheritanceBadCase:
-        return {}
-
-    annotation = function_modifiers.extract_fields({"test": SomeObject})
-    with pytest.raises(base.InvalidDecoratorException):
-        annotation.validate(return_dict)
-
-
 @pytest.mark.parametrize(
-    "return_type",
-    [(int), (list), (np.ndarray), (pd.DataFrame), (MyDictBad)],
+    "return_type_str,fields",
+    [
+        ("Dict[str, int]", ("A", "B")),
+        ("Dict[str, int]", (["A", "B"])),
+        ("Dict", {"A": str, "B": int}),
+        ("MyDict", ()),
+        ("MyDict", {"test2": str}),
+        ("MyDictInheritance", {"test": InheritedObject}),
+        pytest.param("dict[str, int]", ("A", "B"), 
marks=skipif(**prior_to_py39)),
+        pytest.param("dict[str, int]", (["A", "B"]), 
marks=skipif(**prior_to_py39)),
+        pytest.param("dict", {"A": str, "B": int}, 
marks=skipif(**prior_to_py39)),
+    ],
 )
-def test_extract_fields_validate_errors(return_type):
-    def return_dict() -> return_type:
-        return {}
-
-    annotation = function_modifiers.extract_fields({"test": int})
-    with 
pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException):
-        annotation.validate(return_dict)
+def test_extract_fields_valid_annotations_for_inferred_types(return_type_str, 
fields):
+    return_type = eval(return_type_str)
 
+    def function() -> return_type:  # type: ignore
+        return {}  # Only testing validation, so return value doesn't matter
 
-def test_extract_fields_typeddict_empty_fields():
-    def return_dict() -> MyDict:
-        return {}
+    if isinstance(fields, tuple):
+        annotation = function_modifiers.extract_fields(*fields)
+    else:
+        annotation = function_modifiers.extract_fields(fields)
+    annotation.validate(function)
 
-    # don't need fields for TypedDict
-    annotation = function_modifiers.extract_fields()
-    annotation.validate(return_dict)
 
[email protected](
+    "return_type_str,fields",
+    [
+        ("Dict", ("A", "B")),
+        ("Dict", (["A", "B"])),
+        ("Dict", (["A"])),
+        ("Dict", (["A", "B", "C"])),
+        ("int", {"A": int}),
+        ("list", {"A": int}),
+        ("np.ndarray", {"A": int}),
+        ("pd.DataFrame", {"A": int}),
+        ("MyDictBad", {"A": int}),
+        ("MyDictInheritanceBadCase", {"A": SomeObject}),
+        pytest.param("dict", ("A", "B"), marks=skipif(**prior_to_py39)),
+        pytest.param("dict", (["A", "B"]), marks=skipif(**prior_to_py39)),
+        pytest.param("dict", (["A"]), marks=skipif(**prior_to_py39)),
+        pytest.param("dict", (["A", "B", "C"]), marks=skipif(**prior_to_py39)),
+    ],
+)
+def 
test_extract_fields_invalid_annotations_for_inferred_types(return_type_str, 
fields):
+    return_type = eval(return_type_str)
 
-def test_extract_fields_typeddict_subset():
-    def return_dict() -> MyDict:
-        return {}
+    def function() -> return_type:  # type: ignore
+        return {}  # Only testing validation, so return value doesn't matter
 
-    # test that a subset of fields is fine
-    annotation = function_modifiers.extract_fields({"test2": str})
-    annotation.validate(return_dict)
+    if isinstance(fields, tuple):
+        annotation = function_modifiers.extract_fields(*fields)
+    else:
+        annotation = function_modifiers.extract_fields(fields)
+    with 
pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException):
+        annotation.validate(function)
 
 
-def test_valid_extract_fields():
-    """Tests whole extract_fields decorator."""
+def test_extract_fields_transform_on_bare_dict_with_explicit_types():
+    """Tests whole extract_fields decorator using a bare, non-generic, dict 
and explicit types."""
     annotation = function_modifiers.extract_fields(
         {"col_1": list, "col_2": int, "col_3": np.ndarray}
     )
 
-    def dummy_dict_generator() -> dict:
+    def dummy_dict() -> dict:  # bare dict, not generic
         """dummy doc"""
         return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2, 
3, 4])}
 
-    nodes = list(
-        annotation.transform_node(node.Node.from_fn(dummy_dict_generator), {}, 
dummy_dict_generator)
-    )
+    annotation.validate(dummy_dict)
+    nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, 
dummy_dict))
+
     assert len(nodes) == 4
     assert nodes[0] == node.Node(
-        name=dummy_dict_generator.__name__,
+        name=dummy_dict.__name__,
         typ=dict,
-        doc_string=dummy_dict_generator.__doc__,
-        callabl=dummy_dict_generator,
+        doc_string=getattr(dummy_dict, "__doc__", ""),
+        callabl=dummy_dict,
         tags={"module": "tests.function_modifiers.test_expanders"},
     )
     assert nodes[1].name == "col_1"
     assert nodes[1].type == list
     assert nodes[1].documentation == "dummy doc"  # we default to base 
function doc.
-    assert nodes[1].input_types == {dummy_dict_generator.__name__: (dict, 
DependencyType.REQUIRED)}
+    assert nodes[1].input_types == {dummy_dict.__name__: (dict, 
DependencyType.REQUIRED)}
     assert nodes[2].name == "col_2"
     assert nodes[2].type == int
     assert nodes[2].documentation == "dummy doc"
-    assert nodes[2].input_types == {dummy_dict_generator.__name__: (dict, 
DependencyType.REQUIRED)}
+    assert nodes[2].input_types == {dummy_dict.__name__: (dict, 
DependencyType.REQUIRED)}
     assert nodes[3].name == "col_3"
     assert nodes[3].type == np.ndarray
     assert nodes[3].documentation == "dummy doc"
-    assert nodes[3].input_types == {dummy_dict_generator.__name__: (dict, 
DependencyType.REQUIRED)}
+    assert nodes[3].input_types == {dummy_dict.__name__: (dict, 
DependencyType.REQUIRED)}
+
+
+def test_extract_fields_transform_on_generic_dict_with_explicit_types():
+    """Tests whole extract_fields decorator using a generic dict and explicit 
types."""
+    annotation = function_modifiers.extract_fields({"col_1": int, "col_2": 
int})
+
+    def dummy_dict() -> Dict[str, int]:
+        """dummy doc"""
+        return {"col_1": 1, "col_2": 2}
+
+    annotation.validate(dummy_dict)
+    nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, 
dummy_dict))
+
+    assert len(nodes) == 3
+    assert nodes[0] == node.Node(
+        name=dummy_dict.__name__,
+        typ=Dict[str, int],
+        doc_string=getattr(dummy_dict, "__doc__", ""),
+        callabl=dummy_dict,
+        tags={"module": "tests.function_modifiers.test_expanders"},
+    )
+
+    assert nodes[1].name == "col_1"
+    assert nodes[1].type == int
+    assert nodes[1].documentation == "dummy doc"  # we default to base 
function doc.
+    assert nodes[1].input_types == {dummy_dict.__name__: (Dict[str, int], 
DependencyType.REQUIRED)}
+    assert nodes[2].name == "col_2"
+    assert nodes[2].type == int
+    assert nodes[2].documentation == "dummy doc"
+    assert nodes[2].input_types == {dummy_dict.__name__: (Dict[str, int], 
DependencyType.REQUIRED)}
+
+
+def test_extract_fields_transform_on_generic_dict_with_field_list():
+    """Tests whole extract_fields decorator using a generic dict and a list of 
field names."""
+    annotation = function_modifiers.extract_fields(["col_1", "col_2"])
+
+    def dummy_dict() -> Dict[str, int]:
+        """dummy doc"""
+        return {"col_1": 1, "col_2": 2}
+
+    annotation.validate(dummy_dict)
+    nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, 
dummy_dict))
+
+    assert len(nodes) == 3
+    assert nodes[0] == node.Node(
+        name=dummy_dict.__name__,
+        typ=Dict[str, int],
+        doc_string=getattr(dummy_dict, "__doc__", ""),
+        callabl=dummy_dict,
+        tags={"module": "tests.function_modifiers.test_expanders"},
+    )
+
+    assert nodes[1].name == "col_1"
+    assert nodes[1].type == int
+    assert nodes[1].documentation == "dummy doc"  # we default to base 
function doc.
+    assert nodes[1].input_types == {dummy_dict.__name__: (Dict[str, int], 
DependencyType.REQUIRED)}
+    assert nodes[2].name == "col_2"
+    assert nodes[2].type == int
+    assert nodes[2].documentation == "dummy doc"
+    assert nodes[2].input_types == {dummy_dict.__name__: (Dict[str, int], 
DependencyType.REQUIRED)}
+
+
def test_extract_fields_transform_on_generic_dict_with_unpacked_fields():
    """Tests whole extract_fields decorator using a generic dict and unpacked field names."""
    annotation = function_modifiers.extract_fields("col_1", "col_2")

    def dummy_dict() -> Dict[str, int]:
        """dummy doc"""
        return {"col_1": 1, "col_2": 2}

    annotation.validate(dummy_dict)
    base_node, *field_nodes = annotation.transform_node(
        node.Node.from_fn(dummy_dict), {}, dummy_dict
    )

    # The original function is preserved as-is as the first node.
    assert base_node == node.Node(
        name=dummy_dict.__name__,
        typ=Dict[str, int],
        doc_string=getattr(dummy_dict, "__doc__", ""),
        callabl=dummy_dict,
        tags={"module": "tests.function_modifiers.test_expanders"},
    )

    # One extractor node per requested field, each depending on the base node.
    assert len(field_nodes) == 2
    for extracted, expected_name in zip(field_nodes, ("col_1", "col_2")):
        assert extracted.name == expected_name
        assert extracted.type == int
        # Field nodes default to the base function's docstring.
        assert extracted.documentation == "dummy doc"
        assert extracted.input_types == {
            dummy_dict.__name__: (Dict[str, int], DependencyType.REQUIRED)
        }
+
+
def test_extract_fields_transform_on_typed_dict_with_explicit_types():
    """Tests whole extract_fields decorator using a TypedDict and explicit types."""
    annotation = function_modifiers.extract_fields({"test2": str})

    def dummy_dict() -> MyDict:
        """dummy doc"""
        return {"test": 1, "test2": "2"}

    annotation.validate(dummy_dict)
    base_node, field_node = annotation.transform_node(
        node.Node.from_fn(dummy_dict), {}, dummy_dict
    )

    # The original function is preserved as-is as the first node.
    assert base_node == node.Node(
        name=dummy_dict.__name__,
        typ=MyDict,
        doc_string=getattr(dummy_dict, "__doc__", ""),
        callabl=dummy_dict,
        tags={"module": "tests.function_modifiers.test_expanders"},
    )

    # Only the explicitly requested field is extracted, with its declared type.
    assert field_node.name == "test2"
    assert field_node.type == str
    # Field nodes default to the base function's docstring.
    assert field_node.documentation == "dummy doc"
    assert field_node.input_types == {dummy_dict.__name__: (MyDict, DependencyType.REQUIRED)}
+
+
def test_extract_fields_transform_on_typed_dict_with_field_list():
    """Tests whole extract_fields decorator using a TypedDict and a list of field names."""
    annotation = function_modifiers.extract_fields(["test2"])

    def dummy_dict() -> MyDict:
        """dummy doc"""
        return {"test": 1, "test2": "2"}

    annotation.validate(dummy_dict)
    base_node, field_node = annotation.transform_node(
        node.Node.from_fn(dummy_dict), {}, dummy_dict
    )

    # The original function is preserved as-is as the first node.
    assert base_node == node.Node(
        name=dummy_dict.__name__,
        typ=MyDict,
        doc_string=getattr(dummy_dict, "__doc__", ""),
        callabl=dummy_dict,
        tags={"module": "tests.function_modifiers.test_expanders"},
    )

    # The field's type is looked up from the TypedDict annotation.
    assert field_node.name == "test2"
    assert field_node.type == str
    # Field nodes default to the base function's docstring.
    assert field_node.documentation == "dummy doc"
    assert field_node.input_types == {dummy_dict.__name__: (MyDict, DependencyType.REQUIRED)}
+
+
def test_extract_fields_transform_on_typed_dict_with_unpacked_fields():
    """Tests whole extract_fields decorator using a TypedDict and unpacked field names."""
    annotation = function_modifiers.extract_fields("test2")

    def dummy_dict() -> MyDict:
        """dummy doc"""
        return {"test": 1, "test2": "2"}

    annotation.validate(dummy_dict)
    nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict))

    # One node for the original function plus one per extracted field.
    assert len(nodes) == 2
    assert nodes[0] == node.Node(
        name=dummy_dict.__name__,
        typ=MyDict,
        doc_string=getattr(dummy_dict, "__doc__", ""),
        callabl=dummy_dict,
        tags={"module": "tests.function_modifiers.test_expanders"},
    )

    # Field type is inferred from the TypedDict annotation.
    assert nodes[1].name == "test2"
    assert nodes[1].type == str
    assert nodes[1].documentation == "dummy doc"  # we default to base function doc.
    assert nodes[1].input_types == {dummy_dict.__name__: (MyDict, DependencyType.REQUIRED)}
+
+
def test_extract_fields_transform_on_typed_dict_with_inferred_types():
    """Tests whole extract_fields decorator using a TypedDict and inferred types."""
    annotation = function_modifiers.extract_fields()

    def dummy_dict() -> MyDict:
        """dummy doc"""
        return {"test": 1, "test2": "2"}

    annotation.validate(dummy_dict)
    base_node, *field_nodes = annotation.transform_node(
        node.Node.from_fn(dummy_dict), {}, dummy_dict
    )

    # The original function is preserved as-is as the first node.
    assert base_node == node.Node(
        name=dummy_dict.__name__,
        typ=MyDict,
        doc_string=getattr(dummy_dict, "__doc__", ""),
        callabl=dummy_dict,
        tags={"module": "tests.function_modifiers.test_expanders"},
    )

    # With no fields given, every TypedDict entry is extracted with its
    # annotated type.
    assert len(field_nodes) == 2
    for extracted, (expected_name, expected_type) in zip(
        field_nodes, (("test", int), ("test2", str))
    ):
        assert extracted.name == expected_name
        assert extracted.type == expected_type
        # Field nodes default to the base function's docstring.
        assert extracted.documentation == "dummy doc"
        assert extracted.input_types == {dummy_dict.__name__: (MyDict, DependencyType.REQUIRED)}
 
 
-def test_extract_fields_fill_with():
+def test_extract_fields_transform_using_fill_with():
     def dummy_dict() -> dict:
         """dummy doc"""
         return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2, 
3, 4])}
 
     annotation = function_modifiers.extract_fields({"col_2": int, "col_4": 
float}, fill_with=1.0)
+    annotation.validate(dummy_dict)
     original_node, extracted_field_node, missing_field_node = 
annotation.transform_node(
         node.Node.from_fn(dummy_dict), {}, dummy_dict
     )
@@ -468,18 +661,19 @@ def test_extract_fields_fill_with():
     assert missing_field == 1.0
 
 
def test_extract_fields_transform_not_using_fill_with():
    """A field missing from the output dict, with no fill_with default, errors at runtime."""

    def dummy_dict() -> dict:
        """dummy doc"""
        return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2, 3, 4])}

    annotation = function_modifiers.extract_fields({"col_4": int})
    annotation.validate(dummy_dict)
    extracted_nodes = list(
        annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)
    )
    # "col_4" is never produced by dummy_dict and no fill_with was supplied,
    # so executing the extractor node must raise.
    with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException):
        extracted_nodes[1].callable(dummy_dict=dummy_dict())
 
 
-def test_unpack_fields_valid_explicit_tuple():
+def test_unpack_fields_transform_on_explicit_tuple():
     def dummy() -> Tuple[int, str, int]:
         """dummy doc"""
         return 1, "2", 3
@@ -510,7 +704,7 @@ def test_unpack_fields_valid_explicit_tuple():
     assert nodes[3].input_types == {dummy.__name__: (Tuple[int, str, int], 
DependencyType.REQUIRED)}
 
 
-def test_unpack_fields_valid_explicit_tuple_subset():
+def test_unpack_fields_transform_on_explicit_tuple_subset():
     def dummy() -> Tuple[int, str, int]:
         """dummy doc"""
         return 1, "2", 3
@@ -533,7 +727,7 @@ def test_unpack_fields_valid_explicit_tuple_subset():
     assert nodes[1].input_types == {dummy.__name__: (Tuple[int, str, int], 
DependencyType.REQUIRED)}
 
 
-def test_unpack_fields_valid_indeterminate_tuple():
+def test_unpack_fields_transform_on_indeterminate_tuple():
     def dummy() -> Tuple[int, ...]:
         """dummy doc"""
         return 1, 2, 3


Reply via email to