
Wouldn't you use jax.disable_jit() for that?



You can do that too! That will disable every JIT, though. In practice you might want to disable just one.
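A minimal sketch of the difference, assuming a hypothetical jitted `library_fn` called from user code. `jax.disable_jit()` turns off every JIT in its scope, including the library's. Calling the un-jitted top-level function directly turns off only the outer one.

```python
import jax
import jax.numpy as jnp

# Hypothetical library function, jitted as part of its public API.
@jax.jit
def library_fn(x):
    return jnp.sin(x) ** 2

# Hypothetical user code, kept un-jitted so it can be stepped through.
def my_computation(x):
    y = library_fn(x)
    return jnp.sum(y * x)

my_computation_jit = jax.jit(my_computation)

x = jnp.arange(3.0)

# Option 1: disable every JIT in scope, including the one on library_fn.
with jax.disable_jit():
    out_all_eager = my_computation_jit(x)

# Option 2: skip only the top-level JIT; library_fn stays compiled.
out_top_eager = my_computation(x)
```

Both give the same numbers; the difference is only which parts run eagerly.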

To add some colour to my answer: when writing a library, it's typical to put a JIT on everything in the public API. This means users get the benefits of JIT compilation even when they're just hacking around in the REPL, and it mitigates the new-user footgun of forgetting to JIT things themselves.

Meanwhile, good practice is always to JIT your whole computation.

Combined, this means that it's fairly common to end up with jit (at the top level) -> grad (of your operation) -> jit (of some library call).
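A rough sketch of that nesting, assuming a hypothetical library function `library_call` that the library has already jitted, and a user-defined `loss` that calls it:

```python
import jax
import jax.numpy as jnp

# Hypothetical library function: the library JITs its public API.
@jax.jit
def library_call(x):
    return jnp.tanh(x)

# User code: a loss that calls into the library.
def loss(x):
    return jnp.sum(library_call(x) ** 2)

# jit (top level) -> grad (of your operation) -> jit (of the library call).
grad_fn = jax.jit(jax.grad(loss))

x = jnp.array([0.0, 0.5, 1.0])
g = grad_fn(x)
```

The inner `jax.jit` is simply inlined into the outer compiled computation, so nesting costs nothing at runtime.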

When debugging your code, the JIT'd library call is _probably_ not the culprit. So when stepping through, you only want to disable the top-level JIT, while still taking advantage of JIT compilation where you can. Dropping the top-level JIT leaves you with a composition of the form grad(jit(...)).
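A small sketch of why that matters, again assuming a hypothetical jitted `library_call`. With only `jax.grad` at the top level, the user code sees concrete values, so things like Python `if` statements (or a debugger breakpoint) work, while the library call stays compiled. The same `loss` under a top-level `jax.jit` would fail, because the `if` would be evaluated on an abstract tracer.

```python
import jax
import jax.numpy as jnp

# Hypothetical library function, still compiled while we debug.
@jax.jit
def library_call(x):
    return jnp.tanh(x)

def loss(x):
    # Python control flow on a value: fine under grad alone,
    # but would raise a tracer error under a top-level jax.jit.
    if x > 0:
        return library_call(x) ** 2
    return -library_call(x)

# Debugging: drop only the top-level jit, giving grad(jit(...)).
debug_grad = jax.grad(loss)

g_pos = debug_grad(0.5)
g_neg = debug_grad(-0.5)
```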

TL;DR: even if the use case doesn't come up super frequently, it's more user-friendly to support grad(jit(...)) than it is to just crash.



