I wouldn't completely discount it. For example, I wanted to speed up a certain loss function by rewriting it in Triton, but that meant I also had to hand-write the backward pass kernel in Triton, and for that I needed to know how to derive the gradient myself. Then again, with PyTorch 2.0's torch.compile it might not be necessary anymore, since TorchInductor can generate fused Triton kernels for both the forward and backward passes automatically.
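To make that concrete, here's a minimal sketch of what the manual route looks like: a hand-written Triton forward/backward pair wrapped in a torch.autograd.Function. The MSE-style loss, the kernel names, and the block size are illustrative assumptions (not the actual loss I was optimizing), and it assumes contiguous float32 tensors on a CUDA device:

```python
import torch
import triton
import triton.language as tl

BLOCK = 1024  # illustrative block size


@triton.jit
def mse_fwd_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Each program reduces one block of squared errors to a partial sum.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    y = tl.load(y_ptr + offs, mask=mask, other=0.0)
    diff = x - y
    # Accumulate the partial sums into a single scalar.
    tl.atomic_add(out_ptr, tl.sum(diff * diff, axis=0))


@triton.jit
def mse_bwd_kernel(x_ptr, y_ptr, gout_ptr, gx_ptr, n, BLOCK: tl.constexpr):
    # The hand-derived part: d/dx mean((x - y)^2) = 2 * (x - y) / n,
    # scaled by the upstream gradient.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    y = tl.load(y_ptr + offs, mask=mask, other=0.0)
    g = tl.load(gout_ptr)  # upstream gradient, a single scalar
    tl.store(gx_ptr + offs, 2.0 * (x - y) / n * g, mask=mask)


class TritonMSE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        n = x.numel()
        out = torch.zeros(1, device=x.device, dtype=x.dtype)
        grid = (triton.cdiv(n, BLOCK),)
        mse_fwd_kernel[grid](x, y, out, n, BLOCK=BLOCK)
        return out / n  # mean squared error, shape (1,)

    @staticmethod
    def backward(ctx, grad_out):
        x, y = ctx.saved_tensors
        n = x.numel()
        gx = torch.empty_like(x)
        grid = (triton.cdiv(n, BLOCK),)
        mse_bwd_kernel[grid](x, y, grad_out.contiguous(), gx, n, BLOCK=BLOCK)
        return gx, None  # no gradient for the target y in this sketch


# Usage: autograd now routes through the hand-written backward kernel.
x = torch.randn(4096, device="cuda", requires_grad=True)
y = torch.randn(4096, device="cuda")
loss = TritonMSE.apply(x, y)
loss.backward()
```

The torch.compile route skips all of that: something like torch.compile(lambda x, y: ((x - y) ** 2).mean()) lets TorchInductor emit the Triton kernels, forward and backward, for you.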