Pytrees (jax.readthedocs.io)
132 points by f_devd on May 22, 2023 | 14 comments



For those curious what the big deal is here: PyTrees make it wildly easier to take derivatives with respect to parameters that have a complex structure. This makes it much easier to organize code for non-trivial models.

As an example: if you want to implement logistic regression in JAX, you need to optimize the weights. This is easy enough, since they can be modeled as a single value: a matrix of weights. If you want to model a 2-layer MLP, you now need (at least) 2 weight matrices. You could treat these as two separate parameters to your function (which makes the derivative more complicated to manage), or you could concatenate the weights and split them back up, etc. Annoying, but manageable.

When you get to something like a diffusion model, you now need to manage parameters for a variety of different, quite complex models. It really helps if you can keep track of all these parameters in whatever data structure you like, and still trivially call "grad" on that structure to get the derivative of your model with respect to all of its parameters.

Pytrees make this incredibly simple, and they are a major quality-of-life improvement for automatic differentiation.
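A minimal sketch of the 2-layer MLP case, assuming a nested dict of parameters (the layer sizes, the key names "layer1"/"layer2"/"w"/"b", and the squared-error loss are arbitrary illustrative choices). `jax.grad` differentiates with respect to the whole dict at once and returns gradients with exactly the same structure, so nothing has to be concatenated or split by hand.

```python
import jax
import jax.numpy as jnp

def init_params(key, d_in=3, d_hidden=8, d_out=1):
    k1, k2 = jax.random.split(key)
    # The parameters are an ordinary nested dict: a pytree whose leaves are arrays.
    return {
        "layer1": {"w": jax.random.normal(k1, (d_in, d_hidden)) * 0.1, "b": jnp.zeros(d_hidden)},
        "layer2": {"w": jax.random.normal(k2, (d_hidden, d_out)) * 0.1, "b": jnp.zeros(d_out)},
    }

def mlp(params, x):
    h = jnp.tanh(x @ params["layer1"]["w"] + params["layer1"]["b"])
    return h @ params["layer2"]["w"] + params["layer2"]["b"]

def loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

params = init_params(jax.random.PRNGKey(0))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))

# One call differentiates with respect to every leaf of the nested dict.
grads = jax.grad(loss)(params, x, y)

# grads mirrors params: same keys, same nesting, same leaf shapes.
print(jax.tree_util.tree_map(jnp.shape, grads))
```

Adding a third layer or a whole sub-model only changes the dict; the `jax.grad` call stays the same.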


There is also the standalone library "tree" from DeepMind: https://github.com/deepmind/tree

It provides similar functionality and does not depend on JAX, TF, or anything else.


JAX's use of pytrees is great! They implemented a lot of useful utility functions, notably `tree_map`, that make working with these objects easy and intuitive. I recommend looking at their example neural network library "stax".
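As a rough illustration (layer sizes, the learning rate, and the plain SGD step are arbitrary choices), `jax.example_libraries.stax` hands back its parameters as a pytree (a list with one entry per layer), and `jax.tree_util.tree_map` can walk the parameters and their gradients side by side to apply an update.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax

# Each stax layer is an (init_fun, apply_fun) pair; serial composes them.
init_fun, apply_fun = stax.serial(stax.Dense(16), stax.Relu, stax.Dense(1))
out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 4))
# params is a list: (W, b) for each Dense layer, () for Relu (no parameters).

def loss(params, x, y):
    return jnp.mean((apply_fun(params, x) - y) ** 2)

x = jnp.ones((8, 4))
y = jnp.zeros((8, 1))
grads = jax.grad(loss)(params, x, y)

# tree_map over two pytrees with identical structure: one plain SGD step.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

print(out_shape)
print([leaf.shape for leaf in jax.tree_util.tree_leaves(new_params)])
```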


One curious thing I discovered a few months ago: you can sort of hack higher-order functions into JAX by defining “Pytree closures”, which introspect on normal closures and pull the JAX tracer data out of the closure environment (putting it back in when tracing is required). And this works! You can pass these Pytree closures in and out of JIT boundaries, etc.

I believe JAX has a utility for this somewhere, can’t quite remember what this is called.
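For comparison, JAX ships `jax.tree_util.Partial`, a `functools.partial` that is itself a pytree: the wrapped function is treated as static metadata and the bound arguments are leaves, so it can be returned from and passed into jitted functions. This is a built-in alternative, not the closure-introspection trick described above; `scale`, `make_scaler`, and `apply` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

def scale(factor, x):
    return factor * x

@jax.jit
def make_scaler(factor):
    # Return a "function" out of jit: `factor` becomes a pytree leaf, `scale` stays static.
    return Partial(scale, factor)

@jax.jit
def apply(f, x):
    # The Partial crosses back into jit as an ordinary pytree argument.
    return f(x)

scale3 = make_scaler(3.0)
print(apply(scale3, jnp.arange(3.0)))
print(jax.tree_util.tree_leaves(scale3))  # only the bound argument is a leaf
```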

I typically think of JAX as quite restrictive — but I think the reality is that the only real limit on expressivity is that you can’t dynamically allocate inside of unbounded control flow (e.g. creating new allocations inside of a while loop).
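A sketch of the usual workaround for that limit (the Collatz sequence and `max_len` are arbitrary choices): because the `jax.lax.while_loop` carry must keep a fixed shape and dtype, you preallocate a buffer of some maximum size outside the loop and write into it, rather than appending to an array that grows.

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames="max_len")
def collatz_trace(n, max_len=64):
    # Fixed-size buffer allocated once, outside the loop; its shape is static.
    buf = jnp.zeros(max_len, dtype=jnp.int32)
    init = (jnp.int32(0), jnp.asarray(n, dtype=jnp.int32), buf)

    def cond(carry):
        i, x, _ = carry
        # Stop when the sequence reaches 1, or when the buffer is full.
        return (x != 1) & (i < max_len)

    def body(carry):
        i, x, buf = carry
        buf = buf.at[i].set(x)  # write into a slot instead of appending
        x = jnp.where(x % 2 == 0, x // 2, 3 * x + 1)
        return i + 1, x, buf

    length, _, buf = jax.lax.while_loop(cond, body, init)
    return length, buf

length, buf = collatz_trace(6)
print(int(length), buf[: int(length)])
```

Slicing with the runtime `length` happens outside `jit` here; inside a jitted function you would keep the full buffer plus the length (or a mask).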


You're thinking of `jax.closure_convert`. :)

(Although technically that works by tracing and extracting all constants from the jaxpr, rather than introspecting the function's closure cells -- it sounds like your trick is the latter.)
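A small sketch of `jax.closure_convert` (the functions `outer` and `inner` are made up for illustration): it traces a function at example arguments and returns a version that takes the hoisted closed-over values as extra trailing arguments, along with those values. It is mainly intended for higher-order functions with `custom_vjp` rules, which need such values passed explicitly.

```python
import jax
import jax.numpy as jnp

def outer(scale, x):
    # `inner` closes over `scale`; under jax.grad(outer), `scale` is a tracer.
    def inner(y):
        return jnp.sin(scale * y)

    # Trace `inner` at the example argument x and hoist closed-over values
    # that may need differentiating into explicit arguments.
    converted, consts = jax.closure_convert(inner, x)
    return converted(x, *consts)

x = jnp.float32(0.5)
print(outer(2.0, x))            # sin(2.0 * 0.5)
print(jax.grad(outer)(2.0, x))  # d/dscale sin(scale * x) = x * cos(scale * x)
```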

When you discuss dynamic allocation, I'm guessing you're mainly referring to not being able to backprop through `jax.lax.while_loop`. If so, you might find `equinox.internal.while_loop` interesting, which is an unbounded while loop that you can backprop through! The secret sauce is to use a treeverse-style checkpointing scheme.

https://github.com/patrick-kidger/equinox/blob/f95a8ba13fb35...
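To illustrate the limitation being discussed (toy functions, trip count of 5): `jax.lax.while_loop` supports forward-mode differentiation but raises an error under reverse mode, whereas a `jax.lax.scan` with a static length can be backpropagated through. The checkpointed `equinox.internal.while_loop` is not shown here.

```python
import jax
import jax.numpy as jnp

def f_while(x):
    # Apply y <- sin(y) five times; the trip count is decided by `cond` at runtime.
    def cond(carry):
        i, _ = carry
        return i < 5

    def body(carry):
        i, y = carry
        return i + 1, jnp.sin(y)

    return jax.lax.while_loop(cond, body, (0, x))[1]

def f_scan(x):
    # Same computation with a static trip count, which reverse mode can handle.
    def step(y, _):
        return jnp.sin(y), None

    y, _ = jax.lax.scan(step, x, None, length=5)
    return y

x = 0.3
print(jax.jacfwd(f_while)(x))  # forward mode through while_loop works
print(jax.grad(f_scan)(x))     # reverse mode through scan works

try:
    jax.grad(f_while)(x)
except ValueError as err:      # reverse mode through while_loop is rejected
    print("grad(f_while) failed:", type(err).__name__)
```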


Shameless advert -- Equinox is a neural network library for JAX based entirely around pytrees:

https://github.com/patrick-kidger/equinox

(Now on 1.1k stars so it's achieved some popularity!)

This makes model-building elegant (IMO), without any new abstractions to learn. Quite a PyTorch-like experience overall.
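Equinox's own `eqx.Module` is not reproduced here. As a rough plain-JAX illustration of the underlying idea (a model object that *is* a pytree), here is a hypothetical `Linear` class registered with `jax.tree_util.register_pytree_node_class`; `jax.grad` then returns the gradients as another `Linear`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Linear:
    """Hypothetical minimal module: its array fields are the pytree leaves."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # (children, aux_data): the arrays are children; there is no static metadata.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

def loss(model, x, y):
    return jnp.sum((model(x) - y) ** 2)

model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.0, 1.0])

grads = jax.grad(loss)(model, x, y)  # grads is itself a Linear
new_model = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, model, grads)
print(type(grads).__name__, grads.weight.shape, grads.bias.shape)
```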


You might like diffrax too. ;)


diffrax is absolutely magical. I had to integrate a lot of ODEs during my PhD, so I spent quite some time choosing and tuning the scipy solvers for my problems, and I thought that I came close to the fastest I could do in Python. Recently, out of curiosity I rewrote a stiff system that I was studying to solve it with diffrax, and was astonished when I saw it being solved 150x faster.


This is awesome to hear!

And if you're doing implicit methods, then you may be interested to hear that today's release of Diffrax now includes IMEX solvers! Sil3, KenCarp3, KenCarp4, KenCarp5.


I did see that on my GitHub feed this morning, it was the good news of the day :) Thank you for creating and maintaining this library! I also really enjoyed your blogpost "How to succeed in a machine learning PhD" which is a trove of interesting things to learn.


That's insane. May I ask what scipy solver you used and what you switched to in diffrax? What kind of ODE were you working with?


It is a system of four coupled ODEs from biology. I was using LSODA with relatively tight rtol and atol. Now in diffrax I use Tsit5, with adaptive step size control.

Something really annoying about this system is that for a non-negligible portion of the parameter space, an adaptive step solver would fail and give up at some point during the integration because the step size converges to zero. This was preventing me from doing large parameter searches. Now, because diffrax makes it easy to specify the step size controller, I first try to solve the system with a PID step size controller, and if this fails, I rerun it with small fixed steps, which is slower but always reaches the end. This guarantees that I will always get a complete solution, and that it will still be fast in 95% of cases, which is a huge improvement.


Okay, interesting. For such a small system I can imagine the jit compilation helps a lot.

Using Numba to JIT-compile the right-hand side (RHS) might speed things up in the non-diffrax version.

Does the LSODA solver not take a minimum time step parameter? Setting that to your small fixed step might also help.
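For reference, a sketch of what that looks like with SciPy (the system is the same hypothetical linear chain, not the commenter's model, and the tolerances are arbitrary). `solve_ivp` forwards `min_step` to its LSODA solver as a lower bound on the step size; hitting it may end the integration with a failure status rather than forcing a step, so `sol.status` is worth checking. Decorating the RHS with `numba.njit` is a separate, optional speed-up not shown here.

```python
import numpy as np
from scipy.integrate import solve_ivp

def rhs(t, y, k1, k2, k3):
    # Hypothetical stand-in system: linear chain A -> B -> C -> D.
    a, b, c, _ = y
    return np.array([-k1 * a, k1 * a - k2 * b, k2 * b - k3 * c, k3 * c])

sol = solve_ivp(
    rhs, (0.0, 10.0), [1.0, 0.0, 0.0, 0.0],
    method="LSODA", args=(1.0, 0.5, 0.2),
    rtol=1e-6, atol=1e-9,
    min_step=1e-7,  # lower bound on the step size, passed through to LSODA
)
print(sol.status, sol.message, sol.y[:, -1].sum())
```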


Incidentally `diffrax.PIDController` also has a `dtmin` argument that could probably be used here instead of re-running things. :)



