Previous Episode: Expect tests
Next Episode: Random number generators

What is vmap? How is it implemented? How does our implementation compare to JAX's? What is a good way of understanding what vmap does? What's up with random numbers? What are the issues with the vmap that PyTorch currently ships?
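One common way to understand vmap: it behaves as if you ran a per-example function in a loop over the batch and stacked the results, but it runs the whole batch at once. The toy sketch below, in plain Python, illustrates both halves. It is an illustration under stated assumptions, not PyTorch's implementation. `vmap_loop` is the loop-based reference semantics. `vmap` wraps the batch once and lets each op apply its own batching rule, loosely in the spirit of Autodidax and PyTorch's BatchedTensor. All names here (`Batched`, `scale`, `add`, `total`, `vmap_loop`, `vmap`) are hypothetical, and "tensors" are just Python lists.

```python
# Toy sketch only: hypothetical names, not PyTorch's or JAX's API.
from typing import Callable, List, Union

Vector = List[float]


class Batched:
    """Stand-in for a batched tensor: one vector per example, with the
    batch dimension stored in front (physical dim 0)."""

    def __init__(self, rows: List[Vector]):
        self.rows = rows


Value = Union[Vector, Batched]


def scale(x: Value, c: float) -> Value:
    if isinstance(x, Batched):
        # Batching rule: scale every example's row.
        return Batched([[c * v for v in row] for row in x.rows])
    return [c * v for v in x]


def add(x: Value, y: Value) -> Value:
    if isinstance(x, Batched) or isinstance(y, Batched):
        # Batching rule: an unbatched operand is broadcast across the batch.
        n = len(x.rows) if isinstance(x, Batched) else len(y.rows)
        xs = x.rows if isinstance(x, Batched) else [x] * n
        ys = y.rows if isinstance(y, Batched) else [y] * n
        return Batched([[a + b for a, b in zip(rx, ry)] for rx, ry in zip(xs, ys)])
    return [a + b for a, b in zip(x, y)]


def total(x: Value) -> Value:
    if isinstance(x, Batched):
        # Batching rule: the user's logical dim 0 is physical dim 1 here,
        # so reduce within each row, not across rows.
        return Batched([[sum(row)] for row in x.rows])
    return [sum(x)]


def vmap_loop(f: Callable[[Value], Value]) -> Callable[[List[Vector]], List[Vector]]:
    """Reference semantics: run f on each example separately and stack."""
    return lambda xs: [f(x) for x in xs]


def vmap(f: Callable[[Value], Value]) -> Callable[[List[Vector]], List[Vector]]:
    """Batched execution: wrap the input once; each op's rule does the rest."""
    def batched(xs: List[Vector]) -> List[Vector]:
        out = f(Batched(xs))
        # If f never touched its input, the result is unbatched: broadcast it.
        return out.rows if isinstance(out, Batched) else [out] * len(xs)
    return batched


BIAS = [1.0, 0.0, -1.0]


def f(x: Value) -> Value:
    # Written as if x were a single example vector.
    return total(add(scale(x, 2.0), BIAS))


xs = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]
print(vmap(f)(xs))       # [[12.0], [2.0]]
print(vmap_loop(f)(xs))  # [[12.0], [2.0]]
```

The `total` rule is where the logical/physical distinction shows up: the user asks to reduce over dim 0 of a single example, which is physical dim 1 of the batched storage. That index translation is the kind of bookkeeping PyTorch's logical-physical transformation helpers are for.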

Further reading.

- Tracking issue for vmap support https://github.com/pytorch/pytorch/issues/42368
- BatchedTensor source code https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/BatchedTensorImpl.h, logical-physical transformation helper code https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/VmapTransforms.h (well documented, worth a read)
- functorch, the better, more JAX-y implementation of vmap https://github.com/facebookresearch/functorch
- Autodidax https://jax.readthedocs.io/en/latest/autodidax.html, which contains a super simple vmap implementation that is a good model for the internal implementation that PyTorch has