I really liked the suggestion that, if it takes off, the web should consider trying to expose something like the OpenXLA intermediate representation, which powers the new PyTorch 2.0, TensorFlow, JAX, and a bunch of other top-tier ML frameworks.
It's already very well optimized for a ton of hardware (CPUs, GPUs, ML chips). The Intermediate Representation might already be a web-safe-ish model, effectively self-sandboxing, which could make it safe to expose.
Making each webapp target & optimize ML for every possible device target sounds terrible.
The purpose of MLIR is that most of the optimization can be done at lower levels. Instead of everyone figuring out & deciding on their own how best to target & optimize for JS, WASM, WebGL, and/or WebGPU, you just use the industry-standard intermediate representation & let the browser figure out the tradeoffs. If there is onboard hardware (neural cores, say), it might just work!
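To make that concrete, here's a rough sketch in JAX (one of the frameworks that lowers through OpenXLA): the app only describes the computation, and the runtime picks whichever backend it finds, with no per-device tuning in the app itself.

    import jax
    import jax.numpy as jnp

    # The app only describes the math; it never picks a device.
    @jax.jit
    def predict(params, x):
        w, b = params
        return jnp.tanh(x @ w + b)

    params = (jnp.ones((8, 4)), jnp.zeros(4))
    x = jnp.ones((2, 8))

    # XLA compiles for whatever backend is present: CPU, GPU, TPU, ...
    print(jax.default_backend())      # e.g. "cpu" on a laptop, "gpu" with CUDA
    print(predict(params, x).shape)   # (2, 4)

A browser sitting at that level could make the same backend choice on the user's machine.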
Good to see WebML has OpenXLA on their radar... but I'm also a bit afraid, expecting some half-assed excuses for why of course we're going to build some brand-new other thing instead. The web & almost everyone else has such a bad NIH problem. WASI & the web file APIs being totally different is one example, where there's just no common cause, even though it'd make all the difference. And with ML, the cost of rolling your own tech versus re-using the work everyone else puts in feels near suicidal: the decision to make an API that will never be good, never perform anywhere near where it could.
> Making each webapp target & optimize ML for every possible device target sounds terrible.
Yes it does.
Did something I said imply that?
OpenXLA is an intermediate layer that frameworks like PyTorch or JAX can use. It has pluggable backends, so if there were a web-compatible backend (WebGL or WASM), everyone could use it, and any model built with a framework that goes through OpenXLA[1] would be compatible.
[1] Not 100% sure how low-level the OpenXLA intermediate representation is. I know it's not uncommon, when porting a brand-new primitive (e.g. a special kind of transformer) to a new architecture (e.g. CUDA -> Apple M1), that some operations aren't yet supported, so this might be similar.
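(If you're curious what that intermediate layer looks like, here's a rough sketch with JAX; the exact method names have moved around between versions, but the idea is that you lower to the IR first and only then compile for a backend:)

    import jax
    import jax.numpy as jnp

    def f(x, w):
        return jnp.dot(x, w) + 1.0

    x = jnp.ones((2, 8))
    w = jnp.ones((8, 4))

    # Lower to the XLA/StableHLO intermediate representation without
    # committing to any particular device yet.
    lowered = jax.jit(f).lower(x, w)
    print(lowered.as_text())   # MLIR-ish text: dot_general, add, constants...

    # A backend (CPU, GPU, or hypothetically a browser) compiles from here.
    compiled = lowered.compile()
    print(compiled(x, w))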
I support having web targets. It'd be a good offering.
But it feels upside down to me from what we all really should want, which is a safe way to let the web target any backend you have. WebGPU or WebGL or wasm are going to be OK targets, but with limited hardware support & tons of constraints that mean they won't perform as well as OpenXLA can.
Also, how will these targets get profiled? Do we ship the same WebGL to a 600W monster as to a Raspberry Pi?
There's a lot of really good reasons to want OpenXLA under the browser, rather than above/before it.
> WebGPU or WebGL or wasm are going to be OK targets, but with limited hardware support & tons of constraints that mean they won't perform as well as wasm.
I don't understand. "WebGPU or WebGL or wasm".. "won't perform as well as wasm".
I don't think a high-level representation is necessary for relatively straightforward FMA extensions (either outer products in the case of Apple AMX or matrix products in the case of CUDA/Intel AMX). WebGPU + tensor core support and WASM + AMX support would be simpler to implement, likely more future-proof, and wouldn't require maintaining a massive layer of abstraction.
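(To make the outer-product vs. matrix-product distinction concrete, here's a plain numpy sketch of the primitive those units accelerate: a matmul written as accumulated rank-1 outer-product updates, which is roughly what an Apple-AMX-style engine executes per step, while tensor-core/Intel-AMX-style engines chew on small tiles of the full product instead:)

    import numpy as np

    A = np.random.rand(64, 32)
    B = np.random.rand(32, 16)

    # A @ B as accumulated rank-1 updates: C += outer(column of A, row of B).
    # Outer-product hardware does one such update per step; matrix-product
    # hardware does a small tile of the full product per step. Either way
    # it's just wide fused multiply-adds over a register tile.
    C = np.zeros((64, 16))
    for k in range(A.shape[1]):
        C += np.outer(A[:, k], B[k, :])

    assert np.allclose(C, A @ B)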
The issue is, much of the performance of PyTorch, JAX, et al. comes from running a JIT that is tuned to the underlying hardware, and from high-level intrinsic operations that are either hand-tuned or have dedicated hardware support, especially ops that parallelize computation across multiple cores.
You'd probably end up representing these as external library function calls in WASM, but then the WASM JIT would have to be taught that these are magic functions to be treated specially. At that point you're just embedding HLO ops as library functions and an HLO translator into the WASM runtime, and I'm not sure that's any better.
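(The gap in question, sketched in plain Python/numpy: a runtime that only sees generic scalar code versus one that recognizes the whole op and hands it to a tuned kernel:)

    import time
    import numpy as np

    A = np.random.rand(128, 128)
    B = np.random.rand(128, 128)

    # What a generic JIT sees if the op is "just more code": scalar loops.
    def naive_matmul(A, B):
        n, k = A.shape
        _, m = B.shape
        C = np.zeros((n, m))
        for i in range(n):
            for j in range(m):
                for p in range(k):
                    C[i, j] += A[i, p] * B[p, j]
        return C

    t0 = time.perf_counter(); naive_matmul(A, B); t1 = time.perf_counter()
    t2 = time.perf_counter(); A @ B;              t3 = time.perf_counter()
    print(f"scalar loops: {t1 - t0:.3f}s   tuned kernel: {t3 - t2:.6f}s")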
By analogy, would it be better to eliminate fragment and vertex shaders and just use WASM for sending shaders to the GPU, or are the domain-specific language and its constraints beneficial to the GPU drivers?
Do we leave it to every web app to figure out how best to serve everyone, and have them bundle their own tuning optimizers into each app? Or do we bake in a higher-level abstraction that works for everyone and that the browser itself can help optimize?
There's some risk, & the browser APIs likely won't come with all the escape hatches the full tools have for manually jiggering with optimizations, but the idea of getting everyone to DIY seems like a promise of misfit: way too much code when you don't need it, not nearly enough tuning when you do. And there are other risks; the assurance that oh, we just need one or maybe two ops on the web & then everything will be fine forever doesn't wash with me. If we add new ops, the old code won't use them.
And what about hardware that doesn't have any presence on the web? Lots of cheap embedded cores have a couple TFLOPS of neural coprocessing, but neither WASM nor WebGPU can target that at the moment; it's much too simple a core for that kind of dynamic execution. The sea of weird, expansive hardware that OpenXLA helps one target (and target very well indeed) is its chief capability, and I can't imagine forgoing a middleman abstraction like it.
Check out https://mlc.ai/web-stable-diffusion, which builds on top of Apache TVM and brings models from PyTorch 2.0, ONNX, and other sources into the ML compilation flow.
> It's already very well optimized for a ton of hardware (CPUs, GPUs, ML chips). The Intermediate Representation might already be a web-safe-ish model, effectively self-sandboxing, which could make it safe to expose.
https://news.ycombinator.com/item?id=35078410