Yes, TPU VMs dramatically improve the Cloud TPU user experience. You now have direct access to the VM on each TPU host whether you are using JAX, PyTorch, or TensorFlow, which provides a lot more flexibility and control and can often improve performance.
I struggle to understand precisely what you mean by ‘user experience’ and ‘can often improve performance’.
Previously, when using PyTorch, for example, there was no real support for crucial TPU features related to data loading. As a result, using a TPU instead of a GPU in that setup was often not worth it, precisely because of that issue. Your answer suggests this might be different now: are TF, JAX and PyTorch now on par at every stage?
In the previous Cloud TPU architecture, PyTorch and JAX users had to create a separate CPU VM for every remote TPU host and arrange for these CPU hosts to communicate indirectly with the TPU hosts via gRPC. This was cumbersome and made debugging difficult.
With TPU VMs, none of this is necessary. You can SSH directly into each TPU host machine and install arbitrary software on a VM there to handle data loading and other tasks with much greater flexibility.
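As a rough illustration of what host-local data loading can look like, here is a minimal, framework-agnostic Python sketch of a prefetching input pipeline. A background thread reads and preprocesses upcoming batches into a bounded queue while the training loop works on the current one. The `prefetch` and `load_batches` names are hypothetical. This is not the PyTorch / XLA or tf.data API, and in practice you would more likely use the framework's own loader (e.g. PyTorch's `DataLoader` with worker processes).

```python
import queue
import threading
import time
from typing import Iterable, Iterator, List

_DONE = object()  # sentinel marking the end of the stream


def prefetch(source: Iterable, buffer_size: int = 4) -> Iterator:
    """Yield items from `source`, producing them on a background thread.

    Hypothetical helper: the bounded queue lets loading of the next
    `buffer_size` items overlap with work on the current item.
    """
    q: queue.Queue = queue.Queue(maxsize=buffer_size)
    errors: List[BaseException] = []

    def producer() -> None:
        try:
            for item in source:
                q.put(item)  # blocks while the buffer is full
        except BaseException as exc:  # hand any failure to the consumer
            errors.append(exc)
        finally:
            q.put(_DONE)  # always signal end of stream

    # Daemon thread: it will not keep the process alive if the consumer
    # stops iterating early (it simply stays blocked on q.put).
    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = q.get()
        if item is _DONE:
            if errors:
                raise errors[0]
            return
        yield item


def load_batches(n: int) -> Iterator[List[int]]:
    """Hypothetical stand-in for reading and decoding data on the host CPU."""
    for i in range(n):
        time.sleep(0.01)  # simulate I/O and preprocessing cost
        yield [i] * 8


if __name__ == "__main__":
    for batch in prefetch(load_batches(5), buffer_size=2):
        # A real training step would send `batch` to the accelerator here.
        print(batch[0])
```

The point of running this on the TPU host itself is that the producer and the training loop share a machine, so handing a batch over is a local operation rather than a network hop to a separate CPU VM.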
The blog post provides an example of training cost improvement using PyTorch / XLA on TPU VMs in the "Local execution of input pipeline" section. Hopefully we will be able to provide more tutorials on using PyTorch / XLA with TPU VMs soon.
With TPU VMs, workloads that require lots of CPU-TPU communication can now do that communication locally instead of going over the network, which can improve performance.
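One way to see why locality matters is a simple latency-bandwidth (alpha-beta) cost model, in which each transfer pays a fixed per-message latency plus its size divided by bandwidth. With many small CPU-TPU exchanges per step, the latency term dominates, so reducing it (local instead of over the network) has an outsized effect. The function name and all parameter values below are hypothetical placeholders chosen only to show the shape of the model; they are not measurements of any Cloud TPU setup.

```python
def transfer_time_s(n_transfers: int, bytes_per_transfer: float,
                    latency_s: float, bandwidth_bytes_per_s: float) -> float:
    """Alpha-beta model: each transfer costs latency + size / bandwidth."""
    return n_transfers * (latency_s + bytes_per_transfer / bandwidth_bytes_per_s)


# Placeholder numbers for illustration only, NOT measured values.
N, SIZE = 1_000, 1_000  # many small (1 kB) exchanges per training step
# Same bandwidth in both cases, so the gap comes only from per-transfer latency.
remote = transfer_time_s(N, SIZE, latency_s=1e-4, bandwidth_bytes_per_s=1e9)
local = transfer_time_s(N, SIZE, latency_s=1e-5, bandwidth_bytes_per_s=1e9)
print(f"remote: {remote * 1e3:.2f} ms, local: {local * 1e3:.2f} ms")
```

With a few large transfers instead, the bandwidth term dominates and the latency saving matters much less, which is one reason the benefit depends on the workload.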
https://news.ycombinator.com/item?id=24721229