I appreciate your input on bfloat. I've always been under the impression that precision matters a lot for avoiding local minima/maxima when the landscape of the error function is jagged, but I suppose there's a good argument to be made that just about any floating point format can work if the data, learning rate, network structure, etc. are molded to match it. Maybe it's just my perspective, or maybe there genuinely isn't enough discourse treating FP format as an equally or more important design factor, rather than just something with an effect on compute and memory requirements.
Using FP64 could help guard against vanishing gradients and general information loss in deep networks, but that's probably comparable to using an atomic bomb to power a wind turbine. It certainly works, but is it the best way to go about it?
I personally think the use of mixed precision in deep networks will become more common as time goes on. I'm doubtful that every part of a network really benefits from having large amounts of precision.
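To make concrete what I mean by mixed precision, here's a rough sketch using PyTorch's AMP utilities (the model, sizes, and training loop are just placeholders I made up, and it assumes a CUDA device):

```python
import torch
import torch.nn as nn

# Toy mixed-precision loop: matmuls run in fp16 inside autocast, while the
# master weights and the loss-scaling state stay in fp32.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10)).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(64, 256, device="cuda")
y = torch.randint(0, 10, (64,), device="cuda")

for _ in range(10):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # forward pass in reduced precision where safe
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()        # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)               # unscales grads; skips the step on inf/NaN
    scaler.update()
```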
Well, if I could guide the focus a bit: it's not so much the precision of the floating point values as the structure of information flow and expressivity in the network. Gradients are going to die basically regardless of precision; extra bits maybe buy you a few steps, but if you're at the point of using precision to stave off dead gradients, you're several orders of magnitude less efficient than a decent architectural fix would be.
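Here's a toy illustration of what I mean (depth, width, and activations are arbitrary choices on my part): a deep plain tanh stack kills the gradient at the first layer whether you run it in float32 or float64.

```python
import torch
import torch.nn as nn

# Gradient at the first layer of a deep plain (no skip connections) tanh MLP.
def first_layer_grad_norm(dtype, depth=100, width=64):
    torch.manual_seed(0)
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), nn.Tanh()]
    net = nn.Sequential(*layers).to(dtype)
    x = torch.randn(8, width, dtype=dtype)
    net(x).sum().backward()
    return net[0].weight.grad.norm().item()

print(first_layer_grad_norm(torch.float32))  # vanishingly small
print(first_layer_grad_norm(torch.float64))  # still vanishingly small -- extra bits don't rescue it
```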
My personal belief, based on experience, is that training in pure FP8 is maybe possible with some hacks, but that the point where we actually need mixed precision to stabilize things probably kicks in somewhere around 3 to 6/7 bits (a wide range, sorry). I could be wrong though; maybe there's some really cool discrete training method out there that I'm not aware of.
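For reference, the kind of hack I have in mind is fake quantization with a straight-through estimator, i.e. simulating low-bit weights in the forward pass while keeping fp32 master weights. This is just a sketch of that trick, not a real FP8 training recipe:

```python
import torch

# Simulate k-bit integer weights in the forward pass; gradients pass straight
# through the rounding as if it were the identity (straight-through estimator).
class FakeQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, bits):
        qmax = 2 ** (bits - 1) - 1
        scale = w.abs().max().clamp(min=1e-8) / qmax
        return torch.round(w / scale).clamp(-qmax, qmax) * scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None  # no gradient w.r.t. the bit width

w = torch.randn(16, 16, requires_grad=True)   # fp32 master weights
x = torch.randn(4, 16)
out = x @ FakeQuant.apply(w, 4).t()           # forward sees the 4-bit view of w
out.sum().backward()                          # w.grad is populated despite the rounding
```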
A good way to prevent information loss in neural networks is to minimize all of your subpath lengths. You also want a really short shortest path for information from your first to your final layer. That will do a lot.
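My go-to way to get that short shortest path is residual blocks; here's a minimal sketch (widths and depth are arbitrary), where every block sits one hop off an identity path from input to output:

```python
import torch
import torch.nn as nn

# Residual block: the identity skip keeps the shortest input-to-output path
# short, so information and gradients don't have to survive every weight layer.
class ResidualBlock(nn.Module):
    def __init__(self, width):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(width, width), nn.ReLU(), nn.Linear(width, width)
        )

    def forward(self, x):
        return x + self.body(x)  # the skip is the short path

net = nn.Sequential(*[ResidualBlock(64) for _ in range(50)])
out = net(torch.randn(2, 64))    # 50 blocks deep, but the skip path is only 50 additions long
```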
Also, as far as things being jagged -- remember that floating point only loses a lot of absolute precision on large numbers, and large numbers should be really coarse anyway. If we need large values to be perfectly precise, we're likely overfitting; small and detailed is where we can afford high precision. Think of it as a beneficial tradeoff, a bit like position and momentum only being knowable to some exchangeable extent in quantum mechanics: if we impose the same kind of tradeoff on our precision, we get some nice benefits in the end.
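If you want to see the "large numbers are coarse" part for yourself, Python's math.ulp reports the gap to the next representable float64 value; the same relative-spacing behavior holds for fp16/bf16, just with far fewer bits:

```python
import math

# The gap between adjacent representable floats (the ULP) grows with magnitude,
# so absolute precision is concentrated near zero.
for v in [1.0, 1e3, 1e6, 1e9, 1e12]:
    print(f"{v:>12g}  ulp = {math.ulp(v):.3e}")
```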
Hope that helps sort of expound on the subject a bit more, feel free to let me know if you have any questions and much love! <3 :))) :D :)