11:  @functools.partial(jax.checkpoint, prevent_cse=False)

I think checkpointing (also called rematerialization) limits the
memory used by gradient backpropagation during model training.  I
think it means the function's intermediate values are recomputed
during the backward pass when needed, so only its arguments are
stored instead of every intermediate result.
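
To check that reading, here is a minimal sketch, not the code from
the file under discussion.  The names block, block_remat, x and w
are made up for illustration.  It wraps a toy layer with the same
decorator and compares the gradient against the unwrapped version;
if the understanding above is right, only memory use and recompute
cost should differ, not the numbers.

```python
import functools

import jax
import jax.numpy as jnp


def block(x, w):
    # Hypothetical toy layer, for illustration only.
    return jnp.tanh(x @ w)


# Same decorator as on the line above: the backward pass recomputes
# the intermediates of this function from its inputs instead of
# keeping them from the forward pass.
@functools.partial(jax.checkpoint, prevent_cse=False)
def block_remat(x, w):
    return block(x, w)


x = jnp.ones((4, 3))
w = jnp.full((3, 3), 0.1)


def loss_plain(w):
    return jnp.sum(block(x, w) ** 2)


def loss_remat(w):
    return jnp.sum(block_remat(x, w) ** 2)


g_plain = jax.grad(loss_plain)(w)
g_remat = jax.grad(loss_remat)(w)

# Checkpointing should change memory use, not the result.
same = bool(jnp.allclose(g_plain, g_remat))
```

As far as I know, prevent_cse=False is meant for cases such as
inside jax.lax.scan, where the compiler cannot merge the
recomputation back into the forward pass anyway, so the extra
guard that jax.checkpoint normally adds is not needed.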
