Skip to content

Commit

Permalink
fix task index
Browse files Browse the repository at this point in the history
  • Loading branch information
Damien Sileo committed Apr 24, 2023
1 parent 083a9b3 commit e5908d9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
8 changes: 2 additions & 6 deletions src/tasknet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,8 @@ def __init__(self, tasks):
def __call__(
self, features: List[Union[InputDataClass, Dict]]
) -> Dict[str, torch.Tensor]:
#try:
task_index = features[0]["task"].flatten()[0].item()
#except:
# print("features:",features)
# task_index = features[-1]["task"].flatten()[0].item()


task_index = features[0]["task"].flatten()[0].item()
features = [{k:v for k,v in x.items() if k!='task'} for x in features]
collated = self.tasks[task_index].data_collator.__call__(features)
collated['task']=torch.tensor([task_index])
Expand Down
4 changes: 3 additions & 1 deletion src/tasknet/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def sample_dataset(dataset,n=10000, n_eval=1000, oversampling=None):

def get_len(outputs):
try:
return len(outputs[fc.first(outputs)])
batch_length=len(outputs[fc.first(outputs)])
assert batch_length
return batch_length
except:
return 1

Expand Down

0 comments on commit e5908d9

Please sign in to comment.