The issue of query loss not converging #5

Open

cityofmountain opened this issue Dec 2, 2024 · 2 comments

@cityofmountain

MD-DETR/engine.py, lines 95 to 115 in 125e771:

if self.args.use_prompts:
    with torch.no_grad():
        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels, train=False, task_id=self.task_id)
        if not self.args.local_query:
            query = outputs.last_hidden_state.mean(dim=1)
        else:
            query = outputs.last_hidden_state
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs" and k != "enc_outputs"}
        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.model.matcher(outputs_without_aux, labels)
        one_hot_proposals = torch.zeros((len(labels), 300)).to(self.device)
        for i, ind in enumerate(indices):
            for j in ind[0]:
                one_hot_proposals[i][j] = 1
        query_wt = self.model.model.prompts.query_tf(query.view(query.shape[0], -1))
        query_loss = F.cross_entropy(query_wt, one_hot_proposals)

I found that the query loss is computed inside the torch.no_grad() block, so it is never recorded in the computation graph: no gradients flow through query_tf, and as a result the query loss cannot converge.
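For reference, a minimal rearrangement (untested sketch, reusing the variable names from the snippet above) would keep the frozen forward pass under torch.no_grad() but move the query_tf head and the loss outside it, so that query_tf's parameters actually receive gradients:

```python
if self.args.use_prompts:
    with torch.no_grad():
        # Frozen forward pass, only used to obtain the queries and the matching.
        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask,
                             labels=labels, train=False, task_id=self.task_id)
        query = outputs.last_hidden_state if self.args.local_query else outputs.last_hidden_state.mean(dim=1)
        outputs_without_aux = {k: v for k, v in outputs.items()
                               if k not in ("auxiliary_outputs", "enc_outputs")}
        indices = self.model.matcher(outputs_without_aux, labels)
        one_hot_proposals = torch.zeros((len(labels), 300), device=self.device)
        for i, ind in enumerate(indices):
            one_hot_proposals[i, ind[0]] = 1

    # Outside no_grad: gradients now reach query_tf's weights. Note that `query`
    # itself is still detached (it was produced under no_grad), so this only
    # trains query_tf, not the backbone.
    query_wt = self.model.model.prompts.query_tf(query.view(query.shape[0], -1))
    query_loss = F.cross_entropy(query_wt, one_hot_proposals)
```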


@cityofmountain (Author)

We also found that the query_tf fully connected layer is quite large: nearly 90M parameters, more than 10 times the size of the prompts themselves, which is clearly unacceptable.
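A quick way to check the split (untested sketch; assumes it runs where self.model is available, e.g. inside the trainer, and uses the same attribute path as engine.py):

```python
def count_params(module):
    # Total number of parameters in a torch.nn.Module.
    return sum(p.numel() for p in module.parameters())

prompts = self.model.model.prompts                 # same attribute path as in engine.py
query_tf_params = count_params(prompts.query_tf)
other_prompt_params = count_params(prompts) - query_tf_params
print(f"query_tf: {query_tf_params / 1e6:.1f}M parameters, "
      f"rest of the prompt module: {other_prompt_params / 1e6:.1f}M parameters")
```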

@cityofmountain (Author) commented Dec 9, 2024

I tried modifying query_tf, which is used for the weighted averaging of the queries, but ran into some strange behaviour: this module is never trained, yet deleting it caused a drop in precision, and training it also led to a decrease in precision.

The original output precision was: task123 mAP@0.5 = 0.52, task4 mAP@0.5 = 0.52.

1. After directly deleting it, the precision on the old tasks dropped: task123 mAP@0.5 = 0.33, task4 mAP@0.5 = 0.5.

2. Modifying the training logic so that query_tf is recorded in the computation graph and receives gradients caused the precision of both new and old tasks to drop: task123 mAP@0.5 = 0.18, task4 mAP@0.5 = 0.43.

3. Detaching the input of query_tf and training only query_tf itself (see the sketch at the end of this comment) improved training on the new task, which suggests that this loss was interfering with the optimization of the other losses. However, more accurate weight prediction did not help preserve precision on the old tasks and actually had a negative impact: task123 mAP@0.5 = 0.29, task4 mAP@0.5 = 0.5.

4. Keeping query_tf but removing the code that computes query_loss left the published precision almost unchanged: task123 mAP@0.5 = 0.52, task4 mAP@0.5 = 0.48.

This is quite perplexing: a module that is never trained still causes a significant drop in precision when it is deleted.
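For clarity, variant 3 above amounts to something like the following (untested sketch, same variable names as in engine.py): the main forward pass keeps its gradients, but query_tf only ever sees a detached copy of the queries.

```python
# Variant 3 (sketch): cut the graph between the backbone and query_tf, so that
# query_loss updates query_tf alone and cannot interfere with the detection losses.
query_in = query.detach()
query_wt = self.model.model.prompts.query_tf(query_in.view(query_in.shape[0], -1))
query_loss = F.cross_entropy(query_wt, one_hot_proposals)
# query_loss is then added to the total loss; only query_tf receives gradients from it.
```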
