
I think JAX is cool, but I do find it slightly disingenuous when it claims to be "NumPy, but on the GPU" (as opposed to PyTorch), when there's actually a fundamental difference: it's functional. So if I have an array `x` and want to set index 0 to 10, I can't do:

  x[0] = 10
Instead I have to do:

  y = x.at[0].set(10)
Of course this has advantages, but you can't then go and claim that JAX is a drop-in replacement for NumPy, because this is such a fundamental change to how NumPy developers think (and in this regard, PyTorch is closer to NumPy than JAX is).
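To make the difference concrete without needing JAX installed, here is a pure-Python sketch of the `.at[...].set(...)` calling convention. `ImmutableArray`, `_AtHelper` and `_AtIndex` are hypothetical stand-ins that only mimic how the update looks from the caller's side; they are not JAX's real classes.

```python
class _AtIndex:
    """Hypothetical helper: remembers which index an update targets."""

    def __init__(self, data, index):
        self._data = data
        self._index = index

    def set(self, value):
        # Copy, modify the copy, wrap it: the original tuple is never touched
        updated = list(self._data)
        updated[self._index] = value
        return ImmutableArray(updated)


class _AtHelper:
    """Hypothetical helper returned by the `.at` property."""

    def __init__(self, data):
        self._data = data

    def __getitem__(self, index):
        return _AtIndex(self._data, index)


class ImmutableArray:
    """Hypothetical stand-in for an immutable array (not JAX's real class)."""

    def __init__(self, values):
        self._data = tuple(values)

    @property
    def at(self):
        return _AtHelper(self._data)

    def __getitem__(self, index):
        return self._data[index]

    def tolist(self):
        return list(self._data)


x = ImmutableArray([1, 2, 3])
y = x.at[0].set(10)   # functional update: returns a new array
print(x.tolist())     # [1, 2, 3]  (unchanged)
print(y.tolist())     # [10, 2, 3]

try:
    x[0] = 10         # no __setitem__, so in-place assignment is rejected
except TypeError as err:
    print("TypeError:", err)
```

The NumPy habit `x[0] = 10` fails, and the update has to be written as an expression whose result is bound to a name.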



Agree, though I wouldn’t call PyTorch close to a drop-in for NumPy either; there are quite a few mismatches in their APIs. CuPy is the drop-in: apart from some corner cases, you can use the same code for both. E.g. Thinc’s ops work with both NumPy and CuPy:
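One common way to get that "same code for both" property is to pass the array module in as a parameter. The sketch below is a hypothetical `softmax` helper written against an `xp` argument, not Thinc's actual implementation. It is exercised with NumPy here; the assumption is that passing `xp=cupy` together with CuPy arrays would run the same body on the GPU, since it only uses functions CuPy mirrors. CuPy also offers `cupy.get_array_module(...)` for picking the module from the arrays themselves.

```python
import numpy as np


def softmax(x, xp=np):
    """Numerically stable softmax over the last axis.

    `xp` is the array module: numpy for CPU arrays, or (assumed) cupy for
    GPU arrays. Only functions available in both libraries are used.
    """
    shifted = x - xp.max(x, axis=-1, keepdims=True)  # subtract row max for stability
    exps = xp.exp(shifted)
    return exps / xp.sum(exps, axis=-1, keepdims=True)


scores = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
probs = softmax(scores, xp=np)
print(probs.sum(axis=-1))  # each row sums to 1 (up to rounding)
```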

https://github.com/explosion/thinc/blob/master/thinc/backend...

Though I guess the question is why one would still use NumPy when there are good libraries that cover both CPU and GPU. Maybe for interop with other libraries, but DLPack works pretty well for exchanging arrays between them.
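For reference, a DLPack exchange goes through a `from_dlpack` function on the receiving side (NumPy, PyTorch and CuPy each provide one). The sketch below stays within NumPy so it runs without a GPU, assuming NumPy 1.22 or newer; what it shows is that the imported array views the same memory instead of copying it.

```python
import numpy as np

a = np.arange(6, dtype=np.float32).reshape(2, 3)

# Import through the DLPack protocol (a implements __dlpack__ / __dlpack_device__).
# Other libraries expose their own from_dlpack (e.g. torch.from_dlpack);
# those are not exercised here.
b = np.from_dlpack(a)

a[0, 0] = 100.0                # write through the original array...
print(b[0, 0])                 # ...and the DLPack view sees it: 100.0
print(np.shares_memory(a, b))  # True: zero-copy exchange
```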


Why is that? Why doesn't JAX just do something like:

    class JaxWrapper:
        def __init__(self, arr):
            self.arr = arr
        def __setitem__(self, key, val):
            # Python discards __setitem__'s return value, so rebind instead
            self.arr = self.arr.at[key].set(val)
        ...
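One subtlety with a wrapper like this: Python ignores whatever `__setitem__` returns, so for `x[0] = 10` to have any visible effect the wrapper has to rebind its wrapped array internally. The sketch below does that, using a tuple as a stand-in for an immutable JAX array (`MutableFacade` is a hypothetical name, not a JAX API). It shows the cost: the wrapper object is shared mutable state again, so every name bound to it sees the update, which is the aliasing behaviour the functional design avoids.

```python
class MutableFacade:
    """Hypothetical wrapper that fakes in-place updates over an immutable tuple."""

    def __init__(self, values):
        self._data = tuple(values)

    def __setitem__(self, index, value):
        # Build a new tuple and rebind; a return value here would be ignored
        updated = list(self._data)
        updated[index] = value
        self._data = tuple(updated)

    def __getitem__(self, index):
        return self._data[index]


w = MutableFacade([1, 2, 3])
alias = w           # second name for the same wrapper object
before = w._data    # reference to the old immutable tuple

w[0] = 10
print(alias[0])     # 10: the alias observes the "in-place" update
print(before)       # (1, 2, 3): the underlying immutable value never changed
```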


On the other hand, if I wanted some scientific NumPy code to run on the GPU, I think rewriting it in JAX would probably be a better choice than PyTorch.


In my experience, the answer comes down to "does your code use classes liberally?"

If not, and you're just passing things between functions, go ahead with JAX! But converting larger codebases that use classes is significantly easier with PyTorch, even if the method names differ.


I'm going to disagree here! Classes and functional programming can go together very well; just don't expect in-place mutation (i.e. OO-style programming).
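As a plain-Python illustration of classes without in-place mutation (standard library only, not Equinox's API), a frozen dataclass can expose methods that return updated copies via `dataclasses.replace`. `Particle` and `step` are hypothetical names.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class Particle:
    """Hypothetical immutable state object."""

    position: float
    velocity: float

    def step(self, dt: float) -> "Particle":
        # Pure method: returns a new Particle instead of mutating self
        return replace(self, position=self.position + self.velocity * dt)


p0 = Particle(position=0.0, velocity=2.0)
p1 = p0.step(0.5)
print(p0)  # Particle(position=0.0, velocity=2.0)
print(p1)  # Particle(position=1.0, velocity=2.0)

try:
    p0.position = 5.0  # frozen dataclasses reject attribute assignment
except FrozenInstanceError:
    print("p0 is immutable")
```

The class still groups data and behaviour; the only thing given up is mutating `self`.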

You might like Equinox (https://github.com/patrick-kidger/equinox ; 1.4k GitHub stars), which deliberately offers a very PyTorch-like feel for JAX.

Regarding speed, I would strongly recommend JAX over PyTorch for scientific computing. The XLA compiler seems to be much more effective for such workloads.


Sorry for the potentially VERY ignorant question; I only know functional programming at an average-Joe level.

Why can't you do the first in functional programming? (Not in this specific case, where it's just how the library works, but in general.)

And even if there's a good reason you can't in functional programming (again, in general), what stops us from adding syntactic sugar that turns the first form into the second, to make programmers' lives easier?


There are two different things people mean when they call something functional programming:

- higher order functions (lambdas, currying, closures, etc.)

- pure functions, immutability by default, side effects are pushed to the top level and marked clearly

The first aspect has already been adopted by most OOP languages (even C++ has lambdas and closures).

The second aspect is what makes functional programming useful on GPUs: a GPU gets its power from running code fragments in parallel on thousands of cores, which works best when those fragments don't interact with each other. So pure functional code maps easily onto a GPU, while imperative code does not.

You can introduce side effects into functional programming, but then it is no more useful for GPUs (or other parallel programming) than imperative/OOP code.
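A small sequential stand-in for that argument: below, visiting elements forwards or backwards plays the role of the scheduling freedom a parallel device has. The pure function gives the same answer either way; the one that updates shared state does not, which is why its calls cannot simply be handed to independent cores. All function names are illustrative.

```python
def pure_step(x):
    # Depends only on its argument, so call order cannot matter
    return 2 * x + 1


calls = 0


def impure_step(x):
    # Reads and writes shared state, so the result depends on call order
    global calls
    calls += 1
    return x + calls


def apply_in_order(f, xs, order):
    """Apply f to every element, visiting indices in the given order."""
    out = [None] * len(xs)
    for i in order:
        out[i] = f(xs[i])
    return out


xs = [10, 20, 30]
forward = [0, 1, 2]
backward = [2, 1, 0]

pure_fwd = apply_in_order(pure_step, xs, forward)   # [21, 41, 61]
pure_bwd = apply_in_order(pure_step, xs, backward)  # [21, 41, 61]

calls = 0
impure_fwd = apply_in_order(impure_step, xs, forward)   # [11, 22, 33]
calls = 0
impure_bwd = apply_in_order(impure_step, xs, backward)  # [13, 22, 31]
```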


The fundamental reason why many functional languages won't allow you to do the first is that they use immutable data structures.

We could indeed introduce syntactic sugar (`y = (x[0] := 10)`, maybe), but you'll still need to introduce a new variable to hold the modified list.


It's my understanding that, at least in Python, you can't change an immutable object, but you can assign new data to the same variable and thereby overwrite it, right? So even if JAX makes its array type immutable, you can still reuse `x` to hold the modified array.
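Right: immutability applies to the object, not to the name. A quick check with Python tuples, which are immutable; in JAX the corresponding idiom is `x = x.at[0].set(10)`.

```python
x = (1, 2, 3)
alias = x              # another name for the same immutable tuple

try:
    x[0] = 10          # mutating the object itself is not allowed
except TypeError:
    print("tuples do not support item assignment")

x = (10,) + x[1:]      # build a new tuple and rebind the name x to it
print(x)               # (10, 2, 3)
print(alias)           # (1, 2, 3): the original object is unchanged
```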


Doesn't `x[...] = ...` just call a method on the object in Python?

E.g. `x[0] = 10` is the same as `x.__setitem__(0, 10)`, so there shouldn't be any technical limitation to supporting `x[0] = 10` (says the guy who never even imported jax).


You could do `y = x.__setitem__(0, 10)`, but you cannot assign the result of `x[0] = 10` to a new variable. If `__setitem__` were overridden to return a new array, it could not tell these two cases apart, so it could not raise an error in the second one, where the new array is silently discarded.
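A hypothetical class makes this concrete: if `__setitem__` returned a new array instead of mutating, the explicit call would work, but the statement form would silently throw the result away, and `__setitem__` has no way of knowing which of the two situations it is in. `FunctionalArray` is an illustrative name only.

```python
class FunctionalArray:
    """Hypothetical immutable array whose __setitem__ returns a new copy."""

    def __init__(self, values):
        self._data = tuple(values)

    def __setitem__(self, index, value):
        # Returns a new array instead of mutating self
        updated = list(self._data)
        updated[index] = value
        return FunctionalArray(updated)

    def tolist(self):
        return list(self._data)


x = FunctionalArray([1, 2, 3])

y = x.__setitem__(0, 10)  # explicit call: the new array is returned and kept
x[0] = 10                 # statement form: Python throws the returned array away
z = x[0] = 10             # chained assignment: z is bound to 10, not to an array

print(y.tolist())  # [10, 2, 3]
print(x.tolist())  # [1, 2, 3]: x was never modified, and no error was raised
print(z)           # 10
```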


Yes, that makes perfect sense.

I somehow completely missed the assignment part of the second example.

Thank you for the clarification.


Conditionals can also be tricky (comparisons such as greater-than, if/else) and often need rewriting.
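One common rewrite, sketched with NumPy so it runs without JAX: replace a Python `if`/`else` on array values with an elementwise select. `jax.numpy.where` takes the same `(condition, x, y)` arguments, and `jax.lax.cond` is the other option when a genuine branch is needed; under `jax.jit`, a Python `if` on a traced value raises an error. The function names here are hypothetical.

```python
import numpy as np


def piecewise_branchy(x: float) -> float:
    # Python control flow on the value: fine for a plain float, but this is
    # the pattern that fails when x is a traced array under jax.jit
    if x > 0:
        return x ** 2
    return -x


def piecewise_select(x):
    # Branch-free rewrite: compute both sides, then pick elementwise
    return np.where(x > 0, x ** 2, -x)


xs = np.array([-2.0, -0.5, 0.0, 1.5, 3.0])
expected = [piecewise_branchy(float(v)) for v in xs]
print(piecewise_select(xs))
```

Note that `where` evaluates both sides for every element, so a side that is invalid for some inputs (say, a `log` of negative values) can still produce warnings or NaNs in the positions that end up unselected.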



