
Issue: Challenges in Using semantic_sam_trainer for Fine-Tuning Semantic Segmentation #847

Open
adityajbir opened this issue Feb 4, 2025 · 2 comments

@adityajbir

I successfully implemented the instance segmentation function provided in the Micro-SAM repository. However, while using the semantic_sam_trainer function to fine-tune the model for semantic segmentation on custom images, I encountered several issues. Below, I detail the problems, the fixes I made to the source code, and the remaining unresolved issue.

Issue 1: RuntimeError: "host_softmax" not implemented for 'Bool'

Problem: When computing the Dice Loss using a custom loss function, the following error occurred:

RuntimeError: "host_softmax" not implemented for 'Bool'

This happened because the prediction tensor was treated as a boolean type, but the torch.softmax function requires a floating-point input.

Fix: To address this, I explicitly converted the pred tensor to a floating-point type before applying torch.softmax.

if self.softmax:
    pred = torch.softmax(pred.float(), dim=1)
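The dtype problem and the fix can be reproduced in isolation; the tensor shape below is illustrative, but the behavior (torch.softmax rejecting boolean inputs) is general:

```python
import torch

# Simulated boolean prediction tensor of shape [B, C, H, W], as produced
# when masks are compared against class indices (shape chosen for the demo).
pred = torch.zeros(2, 3, 4, 4, dtype=torch.bool)
pred[:, 0] = True

# torch.softmax requires a floating-point dtype, so cast before calling it:
probs = torch.softmax(pred.float(), dim=1)

# After the cast, each pixel's class probabilities sum to 1.
print(torch.allclose(probs.sum(dim=1), torch.ones(2, 4, 4)))  # True
```

Calling `torch.softmax(pred, dim=1)` directly on the boolean tensor raises the reported RuntimeError.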

Issue 2: ValueError: Expected input and target of same shape

Problem: When comparing the input tensor against class indices in _one_hot_encoder, the output shape of the tensor was [B, H*num_classes, W] instead of the expected [B, num_classes, H, W]. This caused a mismatch in shapes during the loss computation.

Fix: I modified the _one_hot_encoder function to unsqueeze the tensor along the channel dimension (axis 1) to align the output with the expected shape.
Modified Code:
Before:

temp_prob = input_tensor == i
tensor_list.append(temp_prob)

After:

temp_prob = (input_tensor == i).unsqueeze(1)  # Shape: [B, 1, H, W]
tensor_list.append(temp_prob)

Final concatenation:

output_tensor = torch.cat(tensor_list, dim=1)  # Shape: [B, num_classes, H, W]
return output_tensor.float()
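Putting the pieces together, the corrected encoder can be sketched as a standalone function (the function name and shapes here are illustrative; in the source it is a method of the loss class):

```python
import torch

def one_hot_encoder(input_tensor: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Map [B, H, W] integer labels to [B, n_classes, H, W] one-hot floats."""
    tensor_list = []
    for i in range(n_classes):
        # unsqueeze(1) gives each class map its own channel: [B, 1, H, W]
        temp_prob = (input_tensor == i).unsqueeze(1)
        tensor_list.append(temp_prob)
    # Concatenating along dim=1 stacks the class channels: [B, n_classes, H, W]
    output_tensor = torch.cat(tensor_list, dim=1)
    return output_tensor.float()

labels = torch.randint(0, 4, (2, 8, 8))
encoded = one_hot_encoder(labels, n_classes=4)
print(encoded.shape)  # torch.Size([2, 4, 8, 8])
```

Without the `unsqueeze(1)`, each `[B, H, W]` boolean map is concatenated along the H axis instead, which produces the `[B, H*num_classes, W]` shape described above.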

Issue 3: RuntimeError: "host_softmax" not implemented for 'Bool' (Recurrence)

Problem: The softmax error reappeared due to the masks tensor being of boolean type. This issue arose inside the _compute_loss function.

Fix: I converted the masks tensor to a floating-point type within _compute_loss.
Modified Code:

masks = masks.float()

Issue 4: AssertionError: Class number out of range

Problem: While running the code with the assumption of 3 classes, the following error occurred:

Assertion `t >= 0 && t < n_classes` failed.

This indicates that one or more pixels in the target tensor had values outside the valid range [0, num_classes - 1]; the error occurred when the trainer was configured for 3 classes.

Debugging Steps:

Verified that the ground-truth masks in the dataset actually contain 4 classes.
Updated the code to handle 4 classes; however, this led to a shape mismatch error in the Dice loss computation.
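A quick sanity check for this kind of error is to list the unique label values and confirm they all fall inside [0, num_classes - 1] before training (the tensor below stands in for a loaded label mask):

```python
import torch

num_classes = 4
# Stand-in for a ground-truth label image loaded from disk.
masks = torch.randint(0, num_classes, (1, 488, 685))

# Every value must satisfy 0 <= label < num_classes, otherwise the
# "t >= 0 && t < n_classes" assertion fires inside the loss.
unique_labels = torch.unique(masks)
print(unique_labels.tolist())
assert unique_labels.min() >= 0 and unique_labels.max() < num_classes
```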

Issue 5: ValueError: Expected input and target of same shape

Problem: After resolving the previous issues, a ValueError was raised:

ValueError: Expected input and target of same shape, got: torch.Size([2, 3, 488, 685]), torch.Size([2, 4, 488, 685]).

This occurred because the input tensor had 3 channels, while the target tensor had 4 channels.
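The mismatch can be made concrete with a minimal sketch: the model's logits carry as many channels as the class count it was configured with, while the one-hot targets carry as many channels as the labels actually contain, so the class count passed to the trainer (e.g. a `num_classes`-style argument; the exact parameter name is an assumption here) must match the labels:

```python
import torch

pred = torch.randn(2, 3, 488, 685)    # logits from a model configured for 3 classes
target = torch.zeros(2, 4, 488, 685)  # one-hot targets built from 4-class masks

# Reproduces the reported shape mismatch along the channel dimension.
assert pred.shape[1] != target.shape[1]

# With the model rebuilt for 4 classes, the channel counts line up.
pred4 = torch.randn(2, 4, 488, 685)
assert pred4.shape == target.shape
```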

@anwai98 (Contributor) commented Feb 5, 2025

Hi @adityajbir,

Thanks for your interest in micro-sam.

Before I look into the details and try to reproduce the issues, could you elaborate on the problem statement?
i.e. what are your input images, what are the corresponding labels, and what is the expected outcome?

This would be a good starting point for us to discuss further details!

@adityajbir (Author) commented Feb 6, 2025

Hi @anwai98 ,

For the input images, I am using cell images saved as PNGs. For the masks, I use Labelme (annotation software) to label those input images; the annotations are saved as JSON files, which I then convert back into mask PNGs. The labels should have 4 classes: background and 3 cell types. As for the expected outcome, I am currently trying to get the training phase working, where I pass all the necessary parameters to the SemanticSamTrainer. Once training is finished, I will evaluate the model's performance against a test set and visualize the results.
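For reference, the JSON-to-mask step can be sketched without any Labelme-specific tooling by rasterizing each annotated polygon into a class-index image; the label names, class mapping, and example polygon below are illustrative, and a real annotation would be loaded with json.load from the Labelme file:

```python
import numpy as np
from PIL import Image, ImageDraw

# Illustrative mapping from Labelme label names to class indices
# (4 classes: background plus 3 cell types).
label_to_class = {"background": 0, "cell_a": 1, "cell_b": 2, "cell_c": 3}

# Stand-in for a parsed Labelme JSON annotation.
annotation = {
    "imageHeight": 64,
    "imageWidth": 64,
    "shapes": [
        {"label": "cell_a", "points": [[10, 10], [40, 10], [40, 40], [10, 40]]},
    ],
}

# Start from a background (class 0) image and fill each polygon with its class index.
mask = Image.new("L", (annotation["imageWidth"], annotation["imageHeight"]), 0)
draw = ImageDraw.Draw(mask)
for shape in annotation["shapes"]:
    polygon = [tuple(p) for p in shape["points"]]
    draw.polygon(polygon, fill=label_to_class[shape["label"]])

mask_array = np.array(mask)
print(sorted(np.unique(mask_array).tolist()))  # [0, 1]
```

The resulting array can be saved with `mask.save(...)` as the PNG label image; the key point is that pixel values are class indices in [0, num_classes - 1], matching what the trainer expects.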
