
How did you strip out the weights for the other classes? What does that even mean?



So this will depend a bit on the architecture (I was working with CenterNet). In my setup, the final feature map for each class is produced by a network "head" that performs a set of convolutions on the same shared, slightly deeper feature maps. Each convolutional "head" is responsible for object detections for a single class, so in the case of COCO, with 80 classes, you have 80 such heads. Now say I want to create a new CenterNet model that predicts (let's say) the following 3 classes: "person", "teapot" and "donkey". Only the "person" class exists in COCO, and I already have a wonderfully robust person detector if I can reuse the person-detecting weights from the COCO model.
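
To make that concrete, here's a rough sketch of the shape of such a model (illustrative only, not the actual CenterNet code; the class name and channel counts are made up):

    import torch
    import torch.nn as nn

    class MultiHeadDetector(nn.Module):
        def __init__(self, num_classes, feat_channels=64):
            super().__init__()
            # Stand-in for the real (much deeper) backbone.
            self.backbone = nn.Sequential(
                nn.Conv2d(3, feat_channels, 3, padding=1),
                nn.ReLU(),
            )
            # One small convolutional head per class, all reading the same
            # shared feature maps.
            self.heads = nn.ModuleList(
                [nn.Conv2d(feat_channels, 1, kernel_size=1)
                 for _ in range(num_classes)]
            )

        def forward(self, x):
            feats = self.backbone(x)
            return [head(feats) for head in self.heads]  # one heatmap per class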

So what I can do is instantiate a CenterNet model of identical architecture, except with only 3 heads for my 3 classes instead of 80 heads for the 80 COCO classes. Now, when I try to load the COCO weights, there will be a mismatch, and typically you end up with the heads being left at their default initialization while the rest of the backbone gets the COCO weights. This is the traditional way of doing transfer learning on related problems: you still start from a much better set of weights for your entire network backbone than random weights, which helps with training on related tasks.
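
Continuing the sketch above, the traditional version looks roughly like this (the checkpoint path is hypothetical):

    # Keep only the backbone weights from the COCO checkpoint; the new,
    # smaller set of heads stays at its default initialization.
    coco_state = torch.load("centernet_coco.pth")       # hypothetical path
    model = MultiHeadDetector(num_classes=3)
    backbone_only = {k: v for k, v in coco_state.items()
                     if k.startswith("backbone.")}
    model.load_state_dict(backbone_only, strict=False)  # heads keep random init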

However, we can go a step further: load up the state_dict from the COCO weights file, figure out which weights belong to the "person" head, and assign them to, let's say, the 1st of your 3 heads in the new architecture. You can go further still: since the "donkey" class is quite similar to the "horse" class in COCO, you can also transfer the weights of the COCO "horse" head to your 2nd head. Now you have a network with 2 of the 3 heads already set up as robust person and horse detectors, which are much better poised to be fine-tuned on your application-specific examples of people and donkeys. You end up with a model that is much more robust on those 2 classes despite having only (let's say) a couple hundred labeled images for your specific application.
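
The splice itself is then just dict surgery on the checkpoint, roughly (head indices below are illustrative; check your own label map):

    # With 0-indexed contiguous COCO labels, person is 0 and horse is 17.
    PERSON, HORSE = 0, 17
    spliced = dict(backbone_only)
    for p in ("weight", "bias"):
        spliced[f"heads.0.{p}"] = coco_state[f"heads.{PERSON}.{p}"]  # 1st head: person
        spliced[f"heads.1.{p}"] = coco_state[f"heads.{HORSE}.{p}"]   # 2nd head: horse -> donkey
    # The 3rd head ("teapot") keeps its random init; strict=False tolerates that.
    model.load_state_dict(spliced, strict=False)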

Hope all of this made sense. The nice thing is that in PyTorch everything is pretty straightforward: weights are just dicts, so it's super easy to introspect them, splice them, etc.
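
For instance, dumping everything in a checkpoint takes a couple of lines:

    # A state_dict is just a dict of name -> tensor.
    for name, tensor in torch.load("centernet_coco.pth").items():
        print(name, tuple(tensor.shape))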


Very nice write-up! I'd be very curious if you have any links/resources related to your solution, or that helped you come up with it.


Thanks. I don't really have anything to link to; it's just something I decided to attempt because it made sense to try. Person detection is notoriously hard to get right on a small dataset, but when trained on something like COCO, modern person detectors feel like magic: they are amazingly good at detecting people even in the weirdest of poses. So I wanted to leverage the robustness of a COCO-trained person detector and fine-tune it for my problem, while still having other new classes in my network, and this seemed like the way to go about it.


This is brilliant



