Training Deep Networks with Data Parallelism in Jax (mishalaskin.com)
122 points by sebg on Feb 24, 2023 | 37 comments



JAX is such a beautiful system. There are many deep PL ideas in the design of JAX; one could spend years thinking about them. It's wonderfully fun to design new interpreters that implement some semantics and then stage them out, automatically gaining access to JIT compilation and accelerators via JAX's other higher-order primitives/transformations.

I've become a big believer that PL research in ML that makes heavy use of program transformations would benefit from providing small JAX-based implementations. There's really no other system which lets you express interpreter-based transformations with the benefits that JAX provides (maybe `functorch` in a few months? I have some doubts about transformation composition with systems like torchdynamo, but I don't know much about it).

Edit: note this is coming from a long time Julia stan, take that for what it is worth :)


Since you're a Julia stan: Don't you think the program transformations that JAX is doing could be done much more easily in Julia, since Julia has macros? Aren't there things similar to JAX in the Julia ecosystem (i.e., a few different autodiff packages that do program transformation)?


No — macros are a non-intrusive transformation — they don’t allow you to transform callees of a function, unless you wrap the macro around the callee. People have tried this in Julia, and it’s horribly slow.

There’s another mechanism in Julia: generated functions. These allow method body specialization given type knowledge about the function’s signature. A user writes code that generates the method body once inference determines the signature (provided the inferred signature is tight enough), so the generated body can depend on the inferred types.

All of Julia’s program transformation based AD packages are based on the latter transformation — most of them do terrible things to the compiler, including massively blowing up the size of code before optimization.

The only package which is more promising is Diffractor — but I’m not convinced it is more than a research prototype at its current level of development. That may change. This was written by one of the compiler devs, and uses lower level hooks into the compiler, developed to support its transformation.

The big issue in general: Julia doesn’t let you write transformations on its typed IR from user space, unless you want to ignore Julia’s native execution engine. There are hooks that someone can work with, but they aren’t user-facing (except for the most advanced users), and they break pass composability with code generation using the native engine (this may have changed since I last looked at this!). I would know, because I’ve made several attempts at stuff like this, producing crappy, unstable packages :)

Separately: macros are one level of reflection -> code generation. JAX supports a different form: you can’t emit data that represents generic expressions, so it’s not quite like Lisp in that sense. It’s better to think of JAX as a “two-level” language system, where Python is the meta-level and a statically typed array language is the object level. JAX supports a stage-like operation which transforms a compatible subset of Python into the statically typed array language. But you can write interpreters using the full dynamism of Python: as long as the “active paths” (those under tracers) stay within that compatible subset, you can then stage out applications of the interpreters to Python functions, etc.
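To make the two-level picture concrete, here is a minimal sketch using `jax.make_jaxpr` (the functions `poly` and `branchy` are illustrative, not from the thread). The Python loop over a static `degree` runs at the meta level and disappears from the staged program, while a Python `if` that depends on a traced value falls outside the compatible subset and fails under `jax.jit`.

```python
import jax
import jax.numpy as jnp


def poly(x, degree):
    # `degree` is a plain Python int, so this loop runs at the meta level
    # (in Python, while tracing) and is unrolled into the staged program.
    acc = jnp.zeros_like(x)
    for k in range(degree + 1):
        acc = acc + x ** k
    return acc


# Stage out poly for degree=2: the printed jaxpr is a statically typed
# program over f32[3] arrays, with the Python loop already unrolled.
staged = jax.make_jaxpr(poly, static_argnums=1)(jnp.ones(3), 2)
print(staged)


def branchy(x):
    # Value-dependent Python control flow on a traced array is *not* in the
    # compatible subset: under jit, x is an abstract tracer with no value.
    if x.sum() > 0:
        return x
    return -x


try:
    jax.jit(branchy)(jnp.ones(3))
    staging_failed = False
except jax.errors.ConcretizationTypeError:
    staging_failed = True
print("staging failed:", staging_failed)
```

The usual fix for `branchy` is to move the branch to the object level with `jax.lax.cond` or `jnp.where`.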

JAX provides one solution to the composable transformation problem, and they’ve done it in an elegant way that's ~pretty~ easy to understand (cf. the Autodidax tutorial). With my current knowledge of things, I can’t effectively argue that Julia supports the same right now (caveat: things may have changed since I last had a look). This is an area where a lot seems to be going on in Julia, so I doubt it will remain this way forever.


Very informative, thanks.

> Julia doesn’t let you write transformations on its typed IR from user space, unless you want to ignore Julia’s native execution engine. There are hooks [...] but [...] they break pass composability with code generation using the native engine

Can you elaborate on that? What is pass composability? Are we talking about LLVM's IR or is there another specific to Julia?


I’m talking about passes which operate on either (a) CodeInfo, a high-level IR which is the data interface for Julia’s abstract-interpretation-based type inference, or (b) IRCode, which is what Julia’s optimizer uses. Note: again, I haven’t looked recently, but this is what I remember from a number of months ago.

When I discussed writing passes, I was referring to interacting with these two phases of the compiler (abstract interpreter and optimization). In practice, these two phases are interlinked.

Shuffling data between these two phases makes a lot of assumptions which are mostly opaque to users. Like I said, you can hack it, but it’s hard to learn what you need to know, and there’s no stable interface, nor a nice “this is how you write a transformation to operate on this IR” or “this is how you write a custom opt”.

In any case, I’m not totally convinced that it’s a good idea to expose this stuff to user libraries. Or, at least, it needs to be carefully thought about.

See some of the complaints about “magic” in this post for some of that. I’m just fascinated by this stuff for some weird reason.


I understand, thank you for the details.


I'm not the person you're responding to, but here's my take on it.

> What is pass composability?

Pass composability is something that comes up in julia a lot because our custom compiler passes are often done in user space and have all sorts of interesting applications. The idea is just that we want to have multiple program transformations occurring at once.

I.e., suppose I'm using some sort of program transformation to turn regular code into derivative code with automatic differentiation (AD), but suppose I *also* want to perform a program transformation to generate, say, GPU code, or one that replaces all my heap allocations with allocations onto a bump allocator, or something else. One has to take care that these different transformations can cooperate with each other. Hell, it even comes up with higher-order AD, where you have to stack two AD passes on top of each other.

One problem here is that layering passes on top of each other can cause a combinatorial explosion in the amount of generated code if things aren't being pruned or optimized between passes.
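In JAX, the same kind of stacking can be sketched by nesting its built-in transformations: two `jax.grad` passes for a second derivative, then `jax.vmap` and `jax.jit` on top. This is only an illustration of composition on the JAX side, assuming a toy function `f(x) = x * sin(x)`; it says nothing about how Julia's passes compose. The analytic second derivative is `2*cos(x) - x*sin(x)`, which the result should match.

```python
import jax
import jax.numpy as jnp


def f(x):
    return x * jnp.sin(x)


# Higher-order AD: two AD passes stacked on top of each other.
d2f = jax.grad(jax.grad(f))

# Compose further: vectorize over a batch, then compile the whole stack.
batched_d2f = jax.jit(jax.vmap(d2f))

xs = jnp.linspace(0.0, 1.0, 4)
print(batched_d2f(xs))
```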

_________________________________

> Are we talking about LLVM's IR or is there another specific to Julia?

The person you were talking to was referring to Julia's own untyped and typed IRs, respectively. Julia programs go through quite a few different representations before they end up getting run. The pipeline looks like this:

1) String: Just a regular, unparsed string of text.

2) Expressions: this is a user-facing representation of parsed code that our macros operate on. At this level, all that's really been done is parsing and a bit of canonicalization. There is no name or scope resolution done at this level, and everything is in terms of trees.

3) Untyped IR: This is a not-so-user-facing intermediate representation of julia code that is produced after an Expression tree gets linearized into SSA form. Name and scope resolution have been performed on it, but no type inference or optimization passes. Generated functions and various user-level compiler pass injection techniques are able to operate on this level of julia representation.

4) Typed IR: This is actually the same object as untyped IR, just with slots that used to be empty filled in. It has had type inference performed on it, and many of our custom julia optimization passes performed on it. The types here still correspond to julia level types. Ideally, we'd be doing user level pass injections on this level of IR where types are resolved, performing optimizations using those types to prune down the amount of code, and then performing the next program transformation, and so on.

5) LLVM IR: The next step after we're done with the typed IR is to translate it down to LLVM IR. This involves replacing julia types with LLVM types, and a bunch of other stuff. LLVM will then perform its own optimization passes (of our choice) on this IR. Some packages do program transformations on this level of code, for instance Enzyme.jl. One advantage of this is that the work can be easily shared with other LLVM backed languages.

6) Assembly code: The LLVM IR then gets compiled to assembly, which involves yet more optimization and translation passes.


Thank you, that's a very useful summary.


What you are saying is fascinating. Unfortunately, my google powers are limited today.

All I can see about program transformations in Jax is [1], and it appears to me there are four: grad, jit, vmap, and pmap. It seems you are implying that there are ways to create custom transformations, and that this has actual use cases.

Would you mind giving some more details? Or maybe some links? I can't help but be excited by your enthusiasm, it feels like Jax could be the ultimate programming language.

[1] https://github.com/google/jax#transformations


I would start here: https://jax.readthedocs.io/en/latest/notebooks/Writing_custo...
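For a flavor of what that notebook covers, here is a minimal sketch of a custom interpreter: it traces a function to a jaxpr with `jax.make_jaxpr`, then re-evaluates the equations one by one with `Primitive.bind`, substituting `cos` wherever the program used `sin`. The sin-to-cos swap is a made-up toy transformation, and walking `jaxpr.eqns`, `invars`, and literal `.val` fields relies on JAX internals that may change between versions. Because the interpreter only binds primitives, it can itself be wrapped in `jax.jit` or `jax.grad`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x):
    return lax.sin(x) * 2.0


def eval_jaxpr_swapping_sin(jaxpr, consts, *args):
    """Evaluate `jaxpr` equation by equation, running cos wherever it ran sin."""
    env = {}

    def read(var):
        # Literals (e.g. the constant 2.0) carry their value directly in `.val`.
        return var.val if hasattr(var, "val") else env[var]

    def write(var, val):
        env[var] = val

    for var, val in zip(jaxpr.constvars, consts):
        write(var, val)
    for var, val in zip(jaxpr.invars, args):
        write(var, val)

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive.name == "sin":
            outvals = [lax.cos(*invals)]  # the toy "transformation"
        else:
            out = eqn.primitive.bind(*invals, **eqn.params)
            outvals = out if eqn.primitive.multiple_results else [out]
        for var, val in zip(eqn.outvars, outvals):
            write(var, val)
    return [read(v) for v in jaxpr.outvars]


def swap_sin_for_cos(fun):
    """Return a version of `fun` whose traced program uses cos instead of sin."""
    def wrapped(*args):
        closed = jax.make_jaxpr(fun)(*args)
        return eval_jaxpr_swapping_sin(closed.jaxpr, closed.consts, *args)[0]
    return wrapped


x = jnp.linspace(0.0, 1.0, 5)
print(swap_sin_for_cos(f)(x))           # eager evaluation
print(jax.jit(swap_sin_for_cos(f))(x))  # the interpreter itself gets staged out
```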

Note that JAX enforces certain limitations, so you should be careful when considering JAX to be "the perfect language" - in general I don't think this is true. It's quite good at what it's designed for.


Thank you


The abstractions provided by JAX for parallelism are beautiful. JAX is an absolute master-class in programming-interface design and a lesson in the power of providing composable primitive operations and FP inspired design. An astounding amount of complexity is hidden from the user behind primitives like pmap, and the power is exposed in such a simple interface.
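As a rough illustration of how little code data parallelism takes with pmap, here is a sketch of one synchronous training step for a toy linear model (the model, learning rate, and axis name `"devices"` are assumptions for the example): each device computes gradients on its shard of the batch, and `jax.lax.pmean` averages them so all replicas apply the same update. It also runs on a single CPU device, where the device axis simply has size 1.

```python
import functools

import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a toy linear model on one device's shard.
    return jnp.mean((x @ w - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(w, x, y):
    grads = jax.grad(loss)(w, x, y)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return w - 0.1 * grads


n = jax.local_device_count()
w = jnp.zeros((n, 3))     # parameters replicated: one copy per device
x = jnp.ones((n, 8, 3))   # batch sharded: 8 examples per device
y = jnp.ones((n, 8))
w = train_step(w, x, y)
print(w)
```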


Thanks for the kind words! We've been doing a lot more work in this direction too, for both compiler-based automatic parallelization [0] and a work-in-progress pmap successor for 'manual' parallelism (per-device code and explicit collectives) [1] which composes seamlessly with the compiler-based stuff.

[0] https://jax.readthedocs.io/en/latest/notebooks/Distributed_a...

[1] https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...
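For the "manual" style described in [1], a minimal sketch might look like the following: per-device code sums its local block, and an explicit `jax.lax.psum` collective combines the partial sums. The import location of `shard_map` has moved between JAX releases, so the fallback import below is an assumption meant to cover both; the mesh axis name `"data"` is arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases expose it at top level
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


def local_then_global_sum(x_block):
    # Per-device code: x_block is only this device's slice of the input.
    block_sum = jnp.sum(x_block)
    # Explicit collective: add the per-device sums across the "data" axis.
    return jax.lax.psum(block_sum, axis_name="data")


global_sum = shard_map(
    local_then_global_sum, mesh=mesh, in_specs=(P("data"),), out_specs=P()
)

n = len(jax.devices())
x = jnp.arange(8.0 * n)
print(global_sum(x))
```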


That's true, and is a massive part of what I love about JAX, but they also form barriers in weird parts of your code, preventing standard introspection tools, which is the single thing I hate about JAX. The errors are amazingly opaque.


If you have any particular examples in mind, and time to share them on https://github.com/google/jax/issues, we'd love to try to improve them. Improving error messages is a priority.

About introspection tools, at least for runtime value debugging there is to some extent a fundamental challenge: since jax.jit stages computation out of Python (though jax.grad and jax.vmap don't), it means standard Python runtime value inspection tools, like printing and pdb, can't work under a jax.jit as the values aren't available as the Python code is executing. You can always remove the jax.jit while debugging (or use `with jax.disable_jit(): ...`), but that's not always convenient, and we need jax.jit for good performance.

We recently added some runtime value debugging tools which work even with jax.jit-staged-out code (even in automatically parallelized code!), though they're not the standard introspection tools: see `jax.debug.print` and `jax.debug.breakpoint` on https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/debugging/print_breakpo....
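A small sketch contrasting the approaches mentioned above (the function `step` is illustrative): a plain `print` inside a jitted function runs only at trace time and shows an abstract tracer, `jax.debug.print` is staged into the compiled program and prints runtime values, and `jax.disable_jit()` makes the same function run eagerly so ordinary tools see concrete values.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    y = jnp.sin(x)
    # Plain print runs only while tracing, so under jit it shows a tracer.
    print("print:", y)
    # jax.debug.print is staged into the program and prints real values each call.
    jax.debug.print("debug.print: {}", y)
    return 2.0 * y


out = step(jnp.arange(3.0))

# Without jit, everything runs eagerly, so print (and pdb) see concrete values.
with jax.disable_jit():
    eager_out = step(jnp.arange(3.0))
```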

If you were thinking about other kinds of introspection tooling, I'd love to hear about it!


> with jax.disable_jit(): ...

That's handy, and I hadn't seen it before, thanks.

It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928). I'm not sure of the exact solution, but the axis juggling and specifications were where I remember a lot of pain, and the docs (though extensive) were unclear. At times it feels like improvements are punted on in the hope that xmap eventually fixes everything (and xmap has been experimental for far longer than I expected).
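For readers unfamiliar with the axis specifications in question, here is a short sketch of how `jax.vmap`'s `in_axes` mirrors the argument pytree (the `params` dict and `predict` function are hypothetical): `None` broadcasts, `0` maps over the leading axis, and a single value can act as a prefix covering a whole subtree. A spec whose structure does not match the arguments is one way to trigger the kind of errors discussed here.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
xs = jnp.ones((5, 3))  # a batch of 5 inputs


def predict(params, x):
    return x @ params["w"] + params["b"]


# in_axes mirrors the argument tuple: None broadcasts the whole params dict
# (a tree prefix), 0 maps over the leading axis of xs.
batched = jax.vmap(predict, in_axes=(None, 0))
out = batched(params, xs)

# Per-leaf spec inside the dict: share "w" but give each example its own "b".
per_example_b = {"w": params["w"], "b": jnp.zeros((5, 2))}
per_leaf = jax.vmap(predict, in_axes=({"w": None, "b": 0}, 0))
out2 = per_leaf(per_example_b, xs)
print(out.shape, out2.shape)
```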

Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out. When it came to doing some complex jax.lax.ppermute, debugging felt more difficult than it needed to be.
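For context, `jax.lax.ppermute` sends each device's value to another device according to an explicit list of (source, destination) pairs. A minimal ring-shift sketch under `jax.pmap` (the axis name `"i"` and the ring permutation are assumptions for the example); with a single device the ring degenerates to the identity.

```python
import functools

import jax
import jax.numpy as jnp

n = jax.local_device_count()
ring = [(i, (i + 1) % n) for i in range(n)]  # device i sends to device i+1


@functools.partial(jax.pmap, axis_name="i")
def shift(x):
    return jax.lax.ppermute(x, axis_name="i", perm=ring)


x = jnp.arange(float(n))[:, None]  # one row per device
print(shift(x))  # each device now holds its left neighbour's row
```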

Next time I encounter something particularly opaque, I'll share on the github issue tracker.


Thanks for taking the time to explain these.

> It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928).

We've improved some of these pytree error messages but it seems that vmap one is still not great. Thanks for the ping on it.

> Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out.

That was indeed a longstanding issue in pmap's implementation. And since people came to expect jit to be "built in" to pmap, it wasn't easy to revise.

However, we recently (https://github.com/google/jax/pull/11854) made `jax.disable_jit()` work with pmap, in the sense that it makes pmap execute eagerly, so that you can print/pdb/etc to your heart's content. (The pmap successor, shard_map (https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...), is eager by default. Also it has uniformly good error messages from the start!)

> Next time I encounter something particularly opaque, I'll share on the github issue tracker.

Thank you for the constructive feedback!


Thanks! One last thing, since I have your ear. The function transformation aspects of jax seem to make their way into downstream libraries like haiku, resulting in a lot of "magic" that can be difficult to examine and debug. Are there any utils you made to make jax's own transformations more transparent, which you think might be helpful to third party transformations?

Higher order functions are difficult in general, and it would be fantastic to have core patterns or tools for breaking them open.


It sounds like you're concerned about how downstream libraries tend to wrap JAX transformations to handle their own thing? (E.g. `haiku.grad`.)

If so, then allow me to make my usual advert here for Equinox:

https://github.com/patrick-kidger/equinox

This actually works with JAX's native transformations. (There's no `equinox.vmap` for example.)

On higher-order functions more generally, Equinox offers a way to control these quite carefully, by making ubiquitous use of callables that are also pytrees. E.g. a neural network is both a callable in that it has a forward pass, and a pytree in that it records its parameters in its tree structure.
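The callable-pytree idea can be sketched in plain JAX without Equinox (this is not Equinox's API): a hypothetical `Linear` class registered with `jax.tree_util.register_pytree_node_class` is callable for its forward pass and flattens to its parameters, so `jax.grad` returns gradients with the same `Linear` structure and `jax.tree_util.tree_map` can apply an update.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical layer that is both a callable and a pytree."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        # Callable: the forward pass.
        return x @ self.w + self.b

    def tree_flatten(self):
        # Pytree: the parameters are the leaves; no static auxiliary data.
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    return jnp.mean((model(x) - y) ** 2)


model = Linear(jnp.zeros((3, 1)), jnp.zeros(1))
x, y = jnp.ones((4, 3)), jnp.ones((4, 1))

grads = jax.grad(loss)(model, x, y)  # a Linear holding dloss/dw and dloss/db
model = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, model, grads)
print(model.w.ravel(), model.b)
```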


As a matter of fact, you’re preaching to the choir! Equinox is my go-to library for jax NN work!


You're right that downstream libraries have often tended to introduce magic (some more than others), and moreover one library's magic is typically incompatible with other libraries'. It's something that we're working on but we don't have much to show for it yet. Two avenues are:

1. as you say, exposing patterns and tools for library authors to implement transformations/higher-order primitives using JAX's machinery rather than requiring each library to introduce bespoke magic to do the same;

2. adding JAX core infrastructure which directly solves the common problems that libraries tend to solve independently (and with bespoke magic).


Thanks for the info! And I want to be clear that I make all these comments and questions out of a love for the tool. It's one of my favorite tools that I wish I could use at work. The design and community engagement are both fantastic.


Agreed. Though keep in mind they built on a lot of failed attempts at doing the same to get here.


A nice, simple walkthrough in this post, but it would be nice if it were updated to show how to do this with sharding and the new jax.Array type introduced not too long ago.

https://github.com/google/jax/pull/11233/files
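As a hedged sketch of what such an update might look like, here is a data-parallel step written with `jax.sharding` and plain `jax.jit` instead of pmap: the batch is sharded along a `"data"` mesh axis, parameters are replicated, and the compiler is left to insert the cross-device reduction. The mesh layout, toy linear model, and learning rate are assumptions for illustration, not taken from the post.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D device mesh; works with however many devices are visible (even one CPU).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data"))  # split the leading (batch) axis
replicated = NamedSharding(mesh, P())            # full copy on every device


def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


@jax.jit
def train_step(w, x, y):
    # No explicit collectives: the compiler partitions the computation from
    # the input shardings and inserts the reduction for the batch mean.
    return w - 0.1 * jax.grad(loss)(w, x, y)


n = len(jax.devices())
w = jax.device_put(jnp.zeros(3), replicated)
x = jax.device_put(jnp.ones((8 * n, 3)), batch_sharding)
y = jax.device_put(jnp.ones(8 * n), batch_sharding)
w = train_step(w, x, y)
print(w, w.sharding)
```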


What are the advantages of this NumPy / Python rewriting approach vs using a VHLL directly targeting the vector accelerator HW, like Futhark?

As a first impression the tower of abstraction (Python - NumPy - functional subset JIT) looks a bit too multilayered for a sustainable foundation.


It seems like the ecosystem is still dominated by PyTorch, is Jax supposed to be a competitor? Any signs of Jax taking over PyTorch anytime soon? Is it perhaps too early for its time? Or is there a critical flaw in the underlying design?


I think it's better to think of JAX as a more general framework for differentiable programming and PyTorch more focused specifically on deep learning/neural networks.

The beauty of JAX is that basic usage is basically a single function: `grad`.

You just write whatever Python function you want and can get the derivative/gradient of it trivially. It gets a bit trickier when you need more sophisticated numeric tools like numpy/scipy, but in those cases it's just about swapping out with a JAX version of those.
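A minimal sketch of that workflow, assuming an arbitrary example function: ordinary Python (here, a loop) with `jax.numpy` substituted for `numpy`, differentiated with a single call to `jax.grad`.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Ordinary Python control flow (a loop), with jax.numpy in place of numpy.
    total = 0.0
    for power in (1, 2):
        total = total + jnp.sum(jnp.tanh(x) ** power)
    return total


df = jax.grad(f)
x = jnp.array([0.0, 0.5, 1.0])
print(df(x))
```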

In this sense JAX is the spiritual successor to Autograd. However, the really amazing thing about JAX is that not only do you get autodiff basically for free, you also get very good performance, and GPU parallelism without needing to think about it at all.

PyTorch is an awesome library, but it's largely focused on building neural networks specifically. JAX should be thought of as a tool that basically any Python programmer can just throw in whenever they come across a problem that benefits from differentiable code (which is a lot of cases once you start thinking about differentiation as a first-class feature).


> I think it's better to think of JAX as a more general framework for differentiable programming and PyTorch more focused specifically on deep learning/neural networks.

I don't get the point of this distinction - JAX was developed specifically for ML. What else is it being used for right now?


I have used it for a numerical optimisation library that I wrote as part of a class.

It’s probably also useful in population and metaheuristic scenarios where the optimisation objective can be described mathematically, allowing you to make use of GPGPUs, and if possible first and second order derivatives.
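As an illustration of the first- and second-order use case, here is a sketch of Newton's method using `jax.grad` and `jax.hessian` on a made-up, strictly convex objective whose minimum is at `log(2)` in every coordinate; the objective and iteration count are assumptions for the example.

```python
import jax
import jax.numpy as jnp


def objective(x):
    # Smooth, strictly convex test function; minimum at x_i = log(2).
    return jnp.sum(jnp.exp(x) - 2.0 * x)


grad_f = jax.grad(objective)     # first-order information
hess_f = jax.hessian(objective)  # second-order information


@jax.jit
def newton_step(x):
    # Solve H dx = g rather than forming the inverse Hessian explicitly.
    return x - jnp.linalg.solve(hess_f(x), grad_f(x))


x = jnp.zeros(3)
for _ in range(10):
    x = newton_step(x)
print(x)
```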


My impression is that people are experimenting more with automatic differentiation for more traditional scientific computing applications.

Although, I've mainly heard about Julia in that context, not Jax.


I think it's young - and perhaps JAX itself is not so specialized to a specific task (but there's plenty of libraries for deep learning focused tooling, although not as mature as PyTorch). It has often been said in other threads on JAX, but it feels like a very different type of library from other AD systems -- the focus on concisely exposing/allowing users to express composable transformations seems novel! (But I may be mistaken)

But in general, I would suspect youth.


So I've been looking into ONNX to speedup inference, is there some killer feature I should look at JAX for?


Generally from what I've seen the biggest inference speedup win with ONNX is to get the model to ONNX then to TRT (TensorRT) - assuming Nvidia hardware. Once in "TRT land" you can play with fp32, fp16, int8 (with calibration), etc. That said (again generally) ONNX does tend to perform better when compared to native (TF savedmodel, pytorch torchscript, whatever). With TensorFlow and Pytorch there are also ways to export/compile directly to TRT but from an ops standpoint I don't prefer this.

Certain inference serving solutions like Nvidia Triton Inference Server will even take an ONNX model and then do TRT compilation (with cache!) on the actual inference hardware dynamically at model load time. This is really nice because you can deploy a standard ONNX model across instances and varying GPU hardware and always get TRT optimized and compatible with Compute Capability, etc. Really handy and basically comes down to a few lines of config in the model configuration.

I'm not terribly familiar with JAX but I have to imagine there's ONNX export or straight to TRT export somewhere.


About a year ago I was tasked with comparing ONNX runtime implementations of certain components like convolution with some of our own in-house implementations. There was just no comparison. ONNX runtime has some pretty crazy fast implementations. Lots of magic in there. Concluded that we weren't going to be able to beat those without a lot of effort and expertise that we didn't have in our team.


Jax is a great tool, but it’s really best for training and experimentation. The transformations outlined in this post (amongst others) make it easy to turn simple and straightforward code into high-performance parallel code. While this is changing, inference hasn’t historically been an area of emphasis for the project, so it wouldn’t be my first choice if that were your primary goal.


Yeah inference is a real weakness. I've used jax2tf to get tf saved models and tflite models, but it was a real journey...

There's some effort going into systems for saving and restoring the computation graph for Jax programs, which will help a lot. I'm surprised it didn't happen sooner, as it seems like quite a natural fit with the jax architecture.


Is Jax useful for Ray/Dask/Spark like distributed computing?



