
Anybody using it in production? Is it, or its derivatives like Flax, worth using over PyTorch for anything?

edit: Made comparison more fair.




I’m a researcher, not using anything in production, but I find jax more usable as a general GPU-accelerated tensor math library. PyTorch is more specifically targeted at the neural network use case. It can be shoehorned into other use cases, but is clearly designed & documented for NN training & inference.
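To make the "general GPU-accelerated tensor math library" point concrete, here is a minimal sketch (my own illustration, not from the commenter): differentiating and compiling an ordinary scalar function, with no neural-network machinery involved. The Rosenbrock function is just a standard optimization test problem chosen for the example.

```python
import jax
import jax.numpy as jnp

# Rosenbrock function: a classic optimization test problem,
# nothing to do with neural networks.
def rosenbrock(x):
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2
                   + (1.0 - x[:-1]) ** 2)

# grad + jit compose like ordinary function transformations.
grad_fn = jax.jit(jax.grad(rosenbrock))

x = jnp.zeros(5)
print(rosenbrock(x))  # 4.0 at the origin
print(grad_fn(x))     # compiled gradient, same shape as x
```

The same code runs unchanged on CPU, GPU, or TPU, which is what makes JAX attractive outside the NN use case.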


Agreed. I used Jax about a year ago to estimate some diode parameters for a side project of mine.
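The commenter doesn't share their code, but a diode-parameter fit in JAX might look something like this sketch: synthetic data from the Shockley diode equation, with the saturation current and ideality factor recovered by plain gradient descent on a squared error. All parameter values here are made up for illustration; a real fit would use a proper optimizer (e.g. Adam or Gauss-Newton) rather than these few hand-rolled steps.

```python
import jax
import jax.numpy as jnp

VT = 0.02585  # thermal voltage at ~300 K, in volts

def diode_current(params, v):
    # Shockley diode equation: I = Is * (exp(V / (n * Vt)) - 1).
    # We fit log(Is) instead of Is for numerical stability.
    log_is, n = params
    return jnp.exp(log_is) * (jnp.exp(v / (n * VT)) - 1.0)

def loss(params, v, i_meas):
    return jnp.mean((diode_current(params, v) - i_meas) ** 2)

# Synthetic "measurements" from known parameters (Is = 1e-9 A, n = 1.8).
v = jnp.linspace(0.1, 0.6, 50)
i_meas = diode_current(jnp.array([jnp.log(1e-9), 1.8]), v)

grad_loss = jax.jit(jax.grad(loss))

params = jnp.array([jnp.log(1e-8), 1.5])  # initial guess
for _ in range(500):
    params = params - 0.05 * grad_loss(params, v, i_meas)
```

The point is that `jax.grad` gives you exact gradients of an arbitrary physics model for free, so the same workflow applies to any curve-fitting problem.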


Not a fair comparison IMO. JAX is a low-level library used to build ML frameworks, while PyTorch is a full-blown ML framework.

In terms of whether it's worth using: that depends on what you're doing. If you just want to start with ML training, probably not. If you have something already and you want to take it to the next level (e.g. influence how training and inference work), then it's a good choice. You might be interested in looking into Flax or Haiku instead of using vanilla JAX. These are closer to PyTorch.
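To illustrate the "influence how training and inference work" point: in vanilla JAX a training step is just a pure function you write yourself, params in, updated params out, so every piece of it is open to customization. A minimal sketch (linear regression, my own toy example):

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return x @ w + b

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
    grads = jax.grad(loss)(params, x, y)
    # Plain gradient descent, written out explicitly; this is the line
    # you would change to customize how training works.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Toy data: y = x @ [1, 2, 3]^T
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
y = x @ jnp.array([[1.0], [2.0], [3.0]])

params = (jnp.zeros((3, 1)), jnp.zeros(1))
for _ in range(200):
    params = train_step(params, x, y)
```

Frameworks like Flax and Haiku layer module/parameter management on top of exactly this pattern.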


If you like PyTorch then you might like Equinox, by the way. (https://github.com/patrick-kidger/equinox ; 1.4k GitHub stars now!) Basically designed to offer PyTorch-like syntax for working with JAX. The latter is excellent for the reasons the sibling replies have stated, but PyTorch absolutely got the usability story correct.


We recently switched to JAX to boost performance as we scale up our algorithmic core. The nice thing is that it presents only a minor jump in capabilities to get developers working with it, if they have prior exposure to numpy, of course. Quite nice :)
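The "minor jump from numpy" claim is easy to demonstrate with a sketch (my own example, not the commenter's code): `jax.numpy` is a near drop-in replacement, and `jax.jit` is one extra line to compile the function for CPU/GPU/TPU.

```python
import jax
import jax.numpy as jnp  # near drop-in replacement for numpy

def pairwise_dists(pts):
    # Exactly the code you'd write with numpy, just jnp instead of np.
    diffs = pts[:, None, :] - pts[None, :, :]
    return jnp.sqrt(jnp.sum(diffs ** 2, axis=-1))

fast = jax.jit(pairwise_dists)  # one line to get a compiled version

pts = jnp.arange(6.0).reshape(3, 2)
print(fast(pts))  # 3x3 symmetric distance matrix
```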


A number of large AI companies use it to train their large models; Midjourney, Stability, Anthropic, DeepMind, among others.




