both tests passing, now to normalise the torch and jax implementations a bit. might be a little hard to rewrite the torch scan() shim to match the approach that works with jax.
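The torch scan() shim itself isn't shown here. The sketch below is a hypothetical version that mirrors the contract of `jax.lax.scan`: a step function `f(carry, x) -> (new_carry, y)`, with the per-step outputs stacked along a new leading axis. This would let the torch and jax code paths share the same step-function shape. Assumptions: only a single tensor `y` per step is supported (jax.lax.scan also handles pytrees), and an empty `xs` is not handled because `torch.stack` rejects an empty list.

```python
import torch


def scan(f, init, xs, length=None):
    """Hypothetical torch shim mirroring jax.lax.scan's (carry, x) -> (carry, y) contract.

    Only a single tensor y per step is supported (no pytrees).
    """
    if xs is None:
        # Like jax.lax.scan, allow running `length` steps with no per-step input.
        if length is None:
            raise ValueError("length must be given when xs is None")
        xs = [None] * length
    carry = init
    ys = []
    # Iterating over a tensor yields slices along its leading axis.
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    # Stack per-step outputs along a new leading axis, as jax.lax.scan does.
    return carry, torch.stack(ys)


def cumsum_step(carry, x):
    # Running sum: the new carry is also emitted as the per-step output.
    carry = carry + x
    return carry, carry


final, ys = scan(cumsum_step, torch.tensor(0.0), torch.tensor([1.0, 2.0, 3.0]))
# final is tensor(6.), ys is tensor([1., 3., 6.])
```

Under these assumptions, extra per-step outputs (for example, weights in the return_weights work) would need to be packed into `y` or have the shim extended to stack tuples.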
commit b60bc067e16c717fc6632d862f1de275007aa47e (HEAD -> return_weights, origin/return_weights)
Date: Fri Jan 28 05:43:46 2022 +0000
jax return_weights working, commented draft statements left in source, implementations not normalised
