
Dice Loss resulting in unexpected logit outputs #900

Open
trchudley opened this issue Jul 31, 2024 · 2 comments
trchudley commented Jul 31, 2024

Hi, thanks for such a great package!

I'm currently training a simple U-Net on a binary segmentation problem (in this case, identifying water in satellite imagery). I am exploring different segmentation_models_pytorch loss functions to train the model. I have been using a focal loss function to initially build and test the model:

loss_function = smp.losses.FocalLoss(mode="binary", alpha=alpha, gamma=gamma)

After training, the model produces a relatively sensible output, both as raw logit values and as probability values once a sigmoid activation layer is applied:

[Image: output of the focal-loss-trained model (raw logits and sigmoid probabilities)]

(NB: this is just a test run, the final accuracy doesn't really matter at this point)
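For context, here is a minimal sketch of the training and inference flow I'm describing (the encoder, channel counts, alpha/gamma values, and random tensors are illustrative, not my exact setup):

```python
import torch
import segmentation_models_pytorch as smp

# Illustrative model: a simple U-Net with one output channel for binary segmentation
model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1)
loss_function = smp.losses.FocalLoss(mode="binary", alpha=0.25, gamma=2.0)  # illustrative alpha/gamma

# --- training step (sketch) ---
# images: (B, 3, H, W), masks: (B, 1, H, W) with values in {0, 1}
images = torch.randn(2, 3, 256, 256)
masks = torch.randint(0, 2, (2, 1, 256, 256)).float()

logits = model(images)               # raw logits, expected to be centred around 0
loss = loss_function(logits, masks)
loss.backward()

# --- inference (sketch) ---
with torch.no_grad():
    logits = model(images)
    probs = torch.sigmoid(logits)    # probabilities in (0, 1)
    preds = (probs > 0.5).float()    # binary water mask
```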

I have been looking to explore and switch to dice loss. My understanding is that, using the from_logits argument, I could simply swap the FocalLoss class for DiceLoss as follows:

loss_function = smp.losses.DiceLoss(mode="binary", from_logits=True)

Training the model using this DiceLoss class results in the following when applied to the same image:

[Image: output of the dice-loss-trained model (raw logits and sigmoid probabilities)]

Looking at the logit output, this is great - the new dice-loss-trained model appears to qualitatively perform even better than the focal-loss-trained model! However, the raw outputs are no longer scaled around zero. Instead, the raw outputs are all positive, ranging between approximately ~400 and ~9000 (depending on which image the model is applied to). As a result, applying a sigmoid activation does not produce a sensible probability distribution between zero and one - instead, the apparent probabilities are all 1, because the logit distribution is entirely positive.
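To illustrate the saturation (using made-up values in the same range I'm seeing, not actual model outputs):

```python
import torch

# Made-up logits in the range produced by the dice-loss-trained model
logits = torch.tensor([400.0, 2500.0, 9000.0])
print(torch.sigmoid(logits))  # tensor([1., 1., 1.]) -- sigmoid saturates for large positive inputs

# Logits centred around zero, as produced by the focal-loss-trained model
logits = torch.tensor([-3.0, 0.0, 3.0])
print(torch.sigmoid(logits))  # tensor([0.0474, 0.5000, 0.9526]) -- a usable probability range
```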

I've examined the source code and I can't see anything that would result in such a difference. Am I missing something that prevents DiceLoss from being a drop-in replacement for FocalLoss when creating probabilistic model predictions?


qubvel commented Jul 31, 2024

Hi @trchudley, it should be a drop-in replacement. If from_logits=True is specified with binary mode, a sigmoid function is applied inside the loss function. It's an interesting observation; you can see in the example notebook for binary segmentation that the DiceLoss function works fine and the predicted logits are distributed around 0.
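As a quick sanity check (random tensors, purely to show that both losses consume the same raw-logit input):

```python
import torch
import segmentation_models_pytorch as smp

logits = torch.randn(2, 1, 64, 64, requires_grad=True)        # raw model outputs (not probabilities)
targets = torch.randint(0, 2, (2, 1, 64, 64)).float()

focal = smp.losses.FocalLoss(mode="binary")
dice = smp.losses.DiceLoss(mode="binary", from_logits=True)    # sigmoid is applied internally

print(focal(logits, targets))  # both accept the same raw logits
print(dice(logits, targets))
```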

@trchudley

Thanks @qubvel. Yes, my understanding from looking at the code was that FocalLoss accepts logit inputs, and DiceLoss can accept logits with from_logits set to True, so both should accept the same (logit) input.

Thanks for linking the notebook - I will have a detailed look and see whether there's anything I might be missing. I'll get back to you...
