In Defense of Pure 16-Bit Floating-Point Neural Networks (arxiv.org)
92 points by belter on May 22, 2023 | 45 comments



The precision used should match the requirements of the dataset, the training process, and the available compute. There are practical uses for 16-bit FP training.

"Our findings demonstrate that pure 16-bit floating-point neural networks can achieve similar or even better performance than their mixed-precision and 32-bit counterparts." This is a very deceptive statement. Take 100 initialization states and train a FP16 vs a FP32 network, and you'll find FP32 will have an accuracy advantage. It's certainly possible to conclude this if a small sample of networks are trained. This paper goes on to state, "Lowering the precision of real numbers used for the neural network’s weights to fixed-point, as shown in [11], leads to a significant decrease in accuracy.", while later concluding, "we have shown that pure 16-bit networks can perform on par with, if not better than, mixed-precision and 32-bit networks in various image classification tasks." The results certainly do, but that doesn't really give an accurate evaluation of what's really going on here. A FP64 network can fall into a local minima and be outperformed by a PF16 network, but is it correct to say the FP16 network is better. I'm getting a lot of mixed signals.

I feel like, "significant implications" is quite a stretch.

A few concerns: besides Figure 3, the other results do not provide side-by-side test vs. validation accuracy to demonstrate the networks are not overfit, and the only mention of normalization is the custom batch normalization operation.

This may be more of a rant about the current state of ML, but in a perfect world we wouldn't use GPUs (or would enforce deterministic calculations), results would be replicable, we'd train hundreds if not thousands of networks before drawing conclusions, we'd better understand how to visualize network accuracy and overfitting, and all datasets would be free of bias and accurately generalize the problem being modelled. We can dream.


This is generally incorrect. FP16 (or bfloat16) usually matches FP32 with almost no sweat, and any additional noise tends to have a positive regularization effect.

I train directly in _pure_ fp16/bf16 with no issues, and the benefits greatly outweigh the tradeoffs. On smaller networks, I use no gradient clipping whatsoever.
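
For anyone curious what "pure" means here, a minimal sketch (made-up layer sizes, nothing repo-specific): cast everything to bf16 once and train normally, with no autocast and no loss scaler.

    import torch
    import torch.nn as nn

    # Minimal sketch of pure-bf16 training: weights, activations and gradients
    # all live in bfloat16; no autocast, no GradScaler.
    model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
    model = model.to(torch.bfloat16)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    x = torch.randn(64, 784, dtype=torch.bfloat16)   # stand-in batch
    y = torch.randint(0, 10, (64,))
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    opt.step()
    opt.zero_grad()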

FP32 has almost no uses outside of bizarrely intricate simulation kinds of things, in which case FP64 is still generally important.


I appreciate your input on bfloat. I've always been under the impression that precision matters a lot when attempting to avoid local minima/maxima if the landscape of the error function is jagged, but I suppose there's a good argument to be made that any floating-point format can be used if the data, learning rate, network structure, etc. are molded to match. Perhaps it's my perspective, or maybe there actually isn't enough discourse on FP format being an equally or more important factor to consider than just its effect on compute and memory requirements.

The use of FP64 could aid against vanishing gradients and just general information loss in deep networks, but that's probably comparable to using an atomic bomb to power a wind turbine. It certainly works, but is it the best way to go about it?

I personally think the use of mixed precision in deep networks will become more common as time goes on. I'm doubtful that all of a network really benefits from having large amounts of precision.


Well, if I could guide a bit in terms of focus, it's not necessarily the precision of the floating-point values as much as the structure of information flow and expressivity in the network. Gradients are going to die basically regardless of precision; you're maybe saving yourself a few steps, but if you're at the point of using precision to stave off dead gradients, it's several orders of magnitude less efficient than a decent solution.

My personal belief, based on experience, is that training in pure FP8 is maybe possible with some hacks, but that our limit for needing mixed precision to stabilize things might come into play around 3-6/7 bits or so (a wide range, sorry). I could be wrong though, maybe there is some really cool discrete training method out there that I'm not aware of.

A good way to prevent information loss in neural networks is to minimize all of your subpath lengths. You also want a really short shortest path for information from your first to your final layer. That will do a lot.
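
In code terms, that mostly means skip/residual connections, something like this sketch:

    import torch.nn as nn

    class ResidualBlock(nn.Module):
        # The identity branch keeps a short path from input to output,
        # so information (and gradient) doesn't have to survive every layer.
        def __init__(self, dim):
            super().__init__()
            self.body = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(),
                                      nn.Linear(dim, dim))

        def forward(self, x):
            return x + self.body(x)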

Also, as far as things being jagged -- remember that floating point only loses a lot of absolute precision on large numbers, which should be really coarse anyways. Having large, perfectly precise numbers means we are likely overfitting. Small and detailed means we can afford high precision. Think of it as a beneficial tradeoff, like trading off knowledge of position and momentum to some exchangeable extent in quantum mechanics. If we impose that on our precision, we get some nice benefits in the end.
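
You can see the magnitude/precision tradeoff directly by checking the gap between adjacent float16 values at a few magnitudes (a quick numpy check):

    import numpy as np

    # Spacing between adjacent float16 values grows with magnitude.
    for v in [0.01, 1.0, 100.0, 10000.0]:
        print(v, np.spacing(np.float16(v)))
    # roughly: 0.01 -> ~7.6e-06, 1.0 -> ~0.001, 100.0 -> 0.0625, 10000.0 -> 8.0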

Hope that helps sort of expound on the subject a bit more, feel free to let me know if you have any questions and much love! <3 :))) :D :)


From what I can tell, the architecture is more important anyway, and having more but smaller parameters gives the model more chances to figure out the optimal architecture on its own.


My best understanding is that architecture is predetermined, which determines the number of parameters up front?

I do think, however, that shallower bit depths will require somewhat deeper networks to compensate over time. Sorta makes sense when you think about it a bit. :) <3 :DDDD :)


> This theoretical exploration offers perspective that is distinct from the literature which attributes the success of low-precision neural networks to its regularization effect

They only tease in the abstract what their "perspective" is, defining it only by what it is not. I can't see it in the conclusion either. Unfortunately I'm not up for reading through their formalisation to try and understand the point.

I have a denoising autoencoder that manages to ridiculously overfit a complex data set despite bottlenecking on a tiny set of neurons, which I attribute to it managing to exploit all the bits of precision within the bottleneck to effectively make it hold far more information than you would naively think. So I'm sceptical if they are saying this is not a real effect.


This is sort of a party trick I do in a number of places, diagnosing NNs from scarce environmental data, so let me give it a whirl. It sounds like your network has far too many degrees of freedom. Reducing it via residuals or some other method will likely help with that.

Additionally: some L2 weight decay, switching to SGD+OneCycle, and don't forget to BatchNorm before every activation as well.
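
In PyTorch terms, roughly (a sketch with made-up sizes, not a drop-in fix):

    import torch
    import torch.nn as nn

    # Sketch of the recipe above: BatchNorm before each activation,
    # L2 weight decay, SGD + OneCycle. Sizes/steps are placeholders.
    model = nn.Sequential(
        nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(),
        nn.Linear(64, 8),   nn.BatchNorm1d(8),  nn.ReLU(),   # bottleneck
        nn.Linear(8, 128),
    )
    opt = torch.optim.SGD(model.parameters(), lr=0.1,
                          momentum=0.9, weight_decay=1e-4)   # L2
    sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=0.1,
                                                total_steps=10_000)
    # ...inside the training loop: opt.step(); sched.step()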

If this is a newer-style attention-Unet ala StyleGAN then that would be a confusing result as transformers seem to be pretty okay with not immediately collapsing to that kind of thing if I understand correctly.

Barring all of that, swapped labels can be a surprising reason for complex data to overfit as it forces the network into a memorization-only mode with very little chance for generalization.

Let me know if I got it correct/close for you. :) :D <3 :))))


> Remark. It is important to acknowledge that previous research has highlighted the ability of neural networks to tolerate noise and even act as a regularization method. However, our proposed analysis differs in two ways. Firstly, it is the first comparison of its kind that focuses on pure 16-bit neural networks. Secondly, our analysis offers a unique quantification approach that distinguishes it from previous explanations related to regularization. Our approach examines floating-point errors and formalizes the concept of tolerance, allowing for its quantification as a measure of comparison with floating-point errors.

> Explanation of the Lemma.... in the worst case.... But as long as... the two classifiers M16 and M32 must have the same classification result on x.

So my rough understanding is that they're saying "Hey, it's not/not only because of regularization; it's because fp16 is really as good as fp32 for what we need it to do (but more efficient)?"


Hasn't it been known for quite a few years now that these CNN networks (ResNet, VGG, etc.) train well with FP16? The problem is that attention layers with softmax can have a dynamic range higher than FP16 can handle, hence you have to go to BF16. I'm lost on what the novelty is in this paper.


It looks like they're formalizing the behavior of pure 16-bit training, which is different from the mixed-precision pipelines I'm aware of.


You can control the range with a temperature; it works pretty well! https://github.com/tysam-code/hlb-CIFAR10/blob/3bb104ce16d16...
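
The idea, very roughly (a generic sketch, not the exact code from that repo):

    import torch

    # Divide logits by a temperature (fixed or learned) before softmax so the
    # exponentiated values stay comfortably inside fp16's range.
    def tempered_softmax(logits, temperature=4.0):
        return torch.softmax(logits / temperature, dim=-1)

    x = torch.randn(2, 8) * 50     # large logits that would be risky at fp16
    print(tempered_softmax(x))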


Yes, it is well known in the industry that both FP16 and int16 have advantages. I don't see anything really new in the paper either.

It's like a lot of arXiv papers these days that only serve as an "Instagram for researchers".


Why stop at 16-bit? I'd be curious to see a study that tries every number of bits from 32 down. I see https://en.wikipedia.org/wiki/Minifloat says there is an 8-bit float which uses 4 bits for the exponent and 3 bits for the significand. Maybe there is a sweet spot between 8 and 16 bits for a good-enough accuracy tradeoff. Of course the hardware for that isn't standard, but maybe low-bit float hardware would be useful.
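
As a toy illustration, here is one possible IEEE-style 1-4-3 layout (bias 7, top exponent reserved for inf/NaN; minifloat conventions vary, so treat this as an assumption):

    def decode_minifloat(byte):
        # Hypothetical 8-bit float: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
        sign = -1.0 if (byte >> 7) & 1 else 1.0
        exp = (byte >> 3) & 0xF
        frac = byte & 0x7
        if exp == 0:                                  # subnormals
            return sign * frac / 8.0 * 2.0 ** -6
        if exp == 0xF:                                # reserved for inf/NaN here
            return sign * float("inf") if frac == 0 else float("nan")
        return sign * (1 + frac / 8.0) * 2.0 ** (exp - 7)

    finite = sorted({decode_minifloat(b) for b in range(256) if (b >> 3) & 0xF != 0xF})
    print(len(finite), finite[-1])   # 239 distinct finite values, largest = 240.0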


Microsoft Research published results with 1-bit gradients.

https://www.microsoft.com/en-us/research/publication/1-bit-s...

int4 (fixed point) has already been popular for inference (https://developer.nvidia.com/blog/int4-for-ai-inference/) and int3 has seen some use for LLaMA-at-home.


Posits seem to be better for 8-bit or even 6-bit. There is only one not-a-number value, NaR (not a real). This means for 6-bit posits you have 63 points in the number space.

https://en.wikipedia.org/wiki/Unum_(number_format)#Unum_III


The way posits focus on numbers near 1.0 is probably going to have a bigger effect. A 6 bit float with 4 exponent bits is the best competitor to a 6 bit posit, and it would only have four non-finite numbers.


What would it mean though? If the information is embedded in the network, and we find some performance figure of merit, wouldn't it probably be about the same performance when normalized to power utilization?

Maybe it's about optimizing every clock cycle?


Why would you expect performance to be constant when normalized to power utilisation?

If your 16 bit floats perform about the same as 32 bit floats in an absolute sense, then they will probably perform even better when normalised for power utilisation.

And if 16-bit floats work, 15-bit floats might perform well too, for all we know. That's what the original commenter was getting at, I think.


Yes, that is what I was getting at. Floating-point hardware with fewer bits has less complexity, and so has a smaller transistor die area and lower power consumption.


If I recall correctly, area/power is proportional to the square of the length of the mantissa.
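
If so, a quick back-of-the-envelope comparison (counting the hidden bit as part of the significand):

    # Assumed rule of thumb: multiplier area ~ (significand bits)^2.
    fp32_sig, fp16_sig, bf16_sig = 24, 11, 8
    print(round((fp32_sig / fp16_sig) ** 2, 1))   # ~4.8x: fp32 vs fp16
    print(round((fp32_sig / bf16_sig) ** 2, 1))   # 9.0x: fp32 vs bf16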


I'm guessing that "about the same" will be hard to measure, and that at some point, thermodynamics will dictate the maximum performance per power output (assuming fixed transistor architecture).


Thermodynamics will dictate performance in some bit-operations-per-joule sense.

The more important performance metric is not the number of bit-operations, but the quality of the neural network output.

The hypothesis is that fewer bits in your numbers give you the same or nearly the same output quality, but with drastically fewer bit operations performed, and thus fewer joules spent.


The authors should have referred to this Allen Institute paper in Nature from 2021, which got good results with just one-bit neurons! They also published earlier work in this area going back to 2016 or 2017, but this more recent, refereed paper was what I quickly found in a web search: https://www.nature.com/articles/s41586-021-03705-x

Also, if you wanted, you could get more resolution by just using the mantissa, not that any hardware supports that these days. I love the 1-bit work, but I suspect the future is a four- or eight-bit mantissa between 0 and 1. Not sure you'd even need a GPU at that point, just a vector machine with a small lookup table in L1 cache.


If you just want the mantissa, why not use integers?


Yes, saturated fixed point or integer math. And I might have been hasty in saying you can’t do this with modern commodity hardware (here’s a paper: https://proceedings.mlsys.org/paper/2022/file/14bfa6bb14875e...)


"...Reducing the number of bits needed to encode the weights and activations of neural networks is highly desirable as it speeds up their training and inference time while reducing memory consumption...Our findings demonstrate that pure 16-bit floating-point neural networks can achieve similar or even better performance than their mixed-precision and 32-bit counterparts. We believe the results presented in this paper will have significant implications for machine learning practitioners, offering an opportunity to reconsider using pure 16-bit networks in various applications..."


As a practitioner specializing in extremely fast-training neural networks, seeing a paper in 2023 treat fp32 as the gold standard over pure non-mixed fp16/bf16 is a bit shocking to me and feels dated/distracting from the discussion. They make good points, but unless I am hopelessly misinformed, it's been pretty well established at this point in a number of circles that fp32 is overkill for the majority of uses for many modern-day practitioners. Loads of networks train directly in bfloat16 as the standard -- a lot of the modern LLMs among them. Mixed precision is very much no longer needed, not even with fp16 if you're willing to tolerate some range hacks. If you don't want the range hacks, just use bfloat16 directly. The complexity is not worth it and adds very little, and the dynamic loss scaler a lot of people use is just begging for more issues.

Both of the main repos that I've published in terms of speed benchmarks train directly in pure fp16 and bf16 respectively without any fp32 frippery; if you want to see an example of both paradigms working successfully, feel free to take a look (I'll note that bf16 is simpler on the whole for a few reasons, generally seamless): https://github.com/tysam-code/hlb-CIFAR10 [for fp16] and https://github.com/tysam-code/hlb-gpt [for bf16]

Personally, from my experience, I think fp16/bf16 is honestly a bit too expressive for what we need; fp8 seems to do just fine and I think will be quite alright with some accommodations, just as with pure fp16. The what and the how of that is a story for a different day (and at this point, the max pooling operation is basically one of the slowest operations).

You'll have to excuse my frustration a bit; it's just jarring to see a street sign from way in the past fly forward in the wind to hit you in the face before tumbling on its merry way. Additionally, the general discussion in the comment section doesn't seem to touch on what seems to be a pretty clearly established consensus in certain research circles. It's not really much of a debate anymore; it works, and we're off to bigger and better problems that I think we should talk about. I guess in one sense that does justify the paper's utility, but it's also a bit frustrating because it sets the conversation a few notches back from where I personally feel it actually is at the moment.

We've got to move out of the past; this fp32 business, to me personally, is like writing a ReLU-activated VGG network in Keras on TensorFlow. Phew.

And while we're at it, if I shall throw my frumpy-grumpy hat right back into the ring, this is an information-theoretic problem! Not enough discussion of Shannon and co. Let's please fix that too. See my other rants for x-references to that, should you be so-inclined to punish yourself in that manner.


I'm not an expert at all on this stuff, but it seems that there are a lot of opinions floating around here on a topic that should be pretty easy to analyze with statistics, which is supposedly something AI researchers should be very good at.

Basically what you want to know is the range and distribution of values. And then come up with efficient ways to store and encode those.

If you can go from billions of representable values (32-bit) to around tens of thousands (16-bit) without too much penalty, that suggests 32-bit is probably overkill. Also, why use floats at all? Integer multiplication is cheap. And are all values equal in importance? Is it an even distribution of values, or are some ranges of values more important than others?
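
The kind of check I mean is cheap to do, something like this sketch (random weights as a stand-in for a real trained layer):

    import numpy as np

    w = np.random.randn(256, 256).astype(np.float32) * 0.02   # stand-in weight matrix
    print("min/max:", w.min(), w.max())
    print("abs percentiles (50/99/99.9):", np.percentile(np.abs(w), [50, 99, 99.9]))
    # Rough histogram of magnitudes on a log2 scale, i.e. which exponents actually occur.
    hist, edges = np.histogram(np.log2(np.abs(w) + 1e-12), bins=10)
    print(list(zip(edges[:-1].round(1), hist)))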

To me it seems that the topology of the neural networks would be a factor here. The reason for having more bits is having large numbers of incoming or outgoing connections. With only a few connections it probably matters less. But if you have thousands, noise/rounding errors might have a bigger impact. That's just my intuition for this. Again, not an expert.

My point here is that this seems a hotly debated topic but people aren't using a lot of the type of statistical arguments I would expect for that.


> Also are all values equal in importance?

Back to the Shannon question at hand (slightly answered in my next answer).

> Also why use floats at all? Integer multiplication is cheap.

Gaussianity, and they cost about the same where we're using them in current GPGPUs/tensor cores (though if Horace He steps in and corrects me on some detail of this, etc., I'll gladly defer).

> are some ranges of values more important than others?

See above; also, range is a good way to stay away from NaNs without the overhead of NaN-checking steps. Think of it as a savings account of capacity for a rainy day.

> The reason for having more bits is having large numbers of incoming or outgoing connections.

This is good intuition, though the network survives on surprisingly little precision. I had a similar feeling until one magical moment with hlb-CIFAR10 where I had to keep kicking up the quantization regularization for it to do well (for one of the older versions, at least).

> My point here is that this seems a hotly debated topic but people aren't using a lot of the type of statistical arguments I would expect for that.

I agree to a degree, though in my modality of thought I would replace statistics with information theory, since that directly informs us of a few things we should expect during network training, as you noted in your second-to-last paragraph with noise/rounding errors/etc. Which I think is good stuff.

However, the empirical numbers do show pretty clearly that it works well, so I'm not too sure where the need for hot debate is. RWKV is one example of a scaled model that uses it. You're sort of shooting yourself in the foot by not using it these days, with GPU memory being the way it is. A flat 2x memory savings (for model weights) is huge, even if it's just for memory transfers, I think. Lots of networks are memory-bound these days, unfortunately.

I think you have good NN-related intuition. I feel like you would find it fun to play around with (if you haven't already). Many thanks for sharing, I greatly appreciated your response. It made me think a bit, and that especially is something I value. So thank you very much for that. <3 :) :thumbsup: :thumbsup:


> The reason for having more bits is having large numbers of incoming or outgoing connections.

I am having trouble getting my head around this statement; could you please explain it more? The idea is not intuitive to me. Any example would be much appreciated.

My current thought process is this: how will having more dynamic range for a single weight/parameter help with more incoming and outgoing connections? Maybe I am approaching the statement the wrong way.

Thank you. :)


Question: Does fp16 provide more accuracy than mixed precision? If so, is there any reason for this to be happening?

Looking at the discussion, everybody agrees that it is already well known that fp32 is overkill and fp16 (or bf16) is already the industry standard (for most cases at least). But opinions on mixed-precision floating point seem to be missing. Has anybody seen benchmarks (other than the paper) that indicate mixed-precision FP performs worse than fp16 and fp32?


Would posits accomplish this even better, if hardware support wasn’t an issue?


A (16, 1) posit gives a maximum of 2 extra fractional bits of precision over float16, and (16, 2) gives you 1 extra bit. And while the numerical distribution a posit assumes is slightly better, it's still not well geared to the distribution of values seen in NNs, so it's not going to get you that much more. You hardly need anything above 10.0, because people like hyperparameters that look like 0.01 or whatever, even though with typical floating point you could multiply all the values by 10^-3 or 10^3 (or in float32, even 10^10 or 10^-10) and things would work just as well.
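
For reference, the fraction-bit arithmetic behind that claim (assuming the standard posit layout: 1 sign bit, a regime of at least 2 bits, then es exponent bits, with the rest as fraction):

    # Best case (near 1.0): the regime takes its minimum 2 bits.
    def max_posit_fraction_bits(nbits, es):
        return nbits - 1 - 2 - es            # sign + regime + exponent removed

    float16_fraction_bits = 10
    print(max_posit_fraction_bits(16, 1) - float16_fraction_bits)   # 2 extra bits
    print(max_posit_fraction_bits(16, 2) - float16_fraction_bits)   # 1 extra bit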

A posit is also more expensive in hardware, because the adders, multipliers and shifters are bigger (provisioned for the maximum fractional precision).

The trick is still in the scaling. I've done full training in (16, 1) posit (like, everything in (16, 1), no float32 or shadow floating-point values, like this paper) for some of these convnets. It doesn't work well out of the box without scaling tricks, and with the tricks it's ~the same as float16, I find. It simply doesn't add that much more precision in such reduced space (that 1 or 2 extra bits at best).

What they benchmarked on in this paper is ancient history too; not sure how relevant those models are to modern practice these days.


The big advantage of Posit16 over Float16 is much less overflow potential. Float16s have 1 less bit of precision than Posit16, but they also max out around 65,000, which can cause a lot of overflows/NaNs. BFloat16 has a really big range (up to 10^38), but only a 7-bit mantissa. Posit16 gives you higher accuracy than BFloat16 over the entire range of Float16, while being almost as resistant to overflow as BFloat16. Yes, you lose a lot of accuracy for the big posits, but in an NN context, often for the giant values all you care about is that they're large and finite. The hardware is a little more expensive, but it has a lot fewer special cases and you can share a lot more operations with int hardware (although none of the expensive stuff).
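
A quick PyTorch check of the float16 vs. bfloat16 ranges being described (posits are left out, since there's no native posit type):

    import torch

    print(torch.finfo(torch.float16).max)            # 65504.0
    print(torch.finfo(torch.bfloat16).max)           # ~3.39e+38
    print(torch.tensor(70000.0).to(torch.float16))   # inf -- overflows
    print(torch.tensor(70000.0).to(torch.bfloat16))  # 70144. -- kept, but coarsely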


Had not known about this, thanks for sharing. This does certainly seem very interesting and applicable to the problem at hand. Keeping large values instead of NaNs is important, and we generally don't need the precision, in my experience.

Plus, when they're that big, they're sort of a wrecking hammer to whatever weights they touch anyways, so might as well save the precision for the cleanup steps afterwards where it really counts (at whatever number of bits works best of course) :D :)))) :D :))))


For people who don't know posits (like me), https://www.cs.cornell.edu/courses/cs6120/2019fa/blog/posits... provides a good introduction, if somewhat tongue-in-cheek.

See discussion at https://news.ycombinator.com/item?id=30856623


This paper feels behind by a few years, I think. But bfloat16 and fp16 are both natively supported in hardware.

We're down to fp8 now with NVIDIA's latest hardware. This conversation is wayyyyy back from where it is in a few other places. Even FP8 shouldn't be a huge issue (at least for mixed precision at first); it's things like the 4-bit datatypes and such where things really and truly get spicy, IMO.


Somewhat, because the few bit patterns they save on things like NaN give them a noticeable boost at those low precisions. But the best way to go is to explicitly fit the distribution of values you are expecting: precisely what Meta did when they introduced their own 16-bit format.


What is the name of the Meta custom 16 bit format?


Here is the relevant announcement (I believe they made it available in PyTorch as bfloat16):

https://engineering.fb.com/2018/11/08/ai-research/floating-p...


FP8 or even a Frankenstein bf8 should be even better


https://github.com/NVIDIA/TransformerEngine - mixed precision FP8 is already here and provides similar accuracy to FP16/BF16


Yeah, and for some reason it's limited to hardware support on the H100 only. Even the cost of doing it in software is outweighed by the speed and storage gains from it.


Yeah, I'd expect the most value to come out of a 2-bit mantissa, 3-bit exponent, 1-bit sign setting.

You could actually probably try it on its own: a lookup table for binary ops over 6-bit arguments is just 6/8 of 4 KB.
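
A toy version, assuming a hypothetical 1-3-2 layout (sign, 3 exponent bits with bias 3, 2 mantissa bits, no inf/NaN reserved) and a byte-per-entry table for simplicity:

    def decode6(b):
        # Hypothetical 6-bit float: 1 sign, 3 exponent (bias 3), 2 mantissa bits.
        sign = -1.0 if (b >> 5) & 1 else 1.0
        exp, frac = (b >> 2) & 0x7, b & 0x3
        if exp == 0:
            return sign * frac / 4.0 * 2.0 ** -2      # subnormals
        return sign * (1 + frac / 4.0) * 2.0 ** (exp - 3)

    def encode6(x):
        # Nearest representable value by brute force (saturates at +-28).
        return min(range(64), key=lambda b: abs(decode6(b) - x))

    # Full multiplication table: 64*64 entries, one byte each = 4 KiB
    # (packed at 6 bits per entry it would be 3 KiB, the "6/8 of 4 KB" above).
    mul_table = bytes(encode6(decode6(a) * decode6(b))
                      for a in range(64) for b in range(64))
    print(len(mul_table))   # 4096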



