Fix combined dataloader bug #163
Conversation
The new CombinedDataLoader class that combines the old CombinedIterableDataset class and the combine_dataloaders function seems useful. I have some questions regarding its logic. Explanations and unit tests are missing, but we can add those once it works the way we want.
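(For reference, a minimal sketch of the pattern under discussion. The class name follows the PR, but the body, including the eager collection of batches and the use of random.shuffle, is an illustrative guess rather than the actual implementation.)

import random

class CombinedDataLoader:
    def __init__(self, dataloaders, shuffle):
        self.dataloaders = dataloaders
        # drain every dataloader once, keeping whole batches intact
        self.batches = [batch for dl in dataloaders for batch in dl]
        if shuffle:
            # reorder whole batches; samples are never mixed across batches
            random.shuffle(self.batches)
        self.current_index = 0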
    Combines multiple dataloaders into a single iterable dataset.

    This is useful for combining multiple dataloaders into a single
    dataloader. The new dataloader can be shuffled or not.

    :param dataloaders: list of dataloaders to combine
    :param shuffle: whether to shuffle the combined dataloader

    :return: combined dataloader
    """
I think the docstring was useful.
My bad
    def __len__(self):
        return sum(len(dl) for dl in self.dataloaders)
Doesn't this change the behavior from before? Isn't it weird that asking for the len of a combined dataloader gives you the summed length of all its elements? Usually the length of a list of arrays is the number of arrays, not the sum of the lengths of the arrays...
Yes, it's just a matter of definitions. In this case, I think it should return the number of batches that the dataloader returns in one epoch, to match the behavior of torch dataloaders. The number of batches that it returns is the sum of the numbers of batches in the individual dataloaders.
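(To make that definition concrete, a hedged example using the __len__ from the diff above; the dataset sizes and batch size here are made up for illustration.)

import torch
from torch.utils.data import DataLoader, TensorDataset

dl_a = DataLoader(TensorDataset(torch.arange(8)), batch_size=4)   # 2 batches
dl_b = DataLoader(TensorDataset(torch.arange(12)), batch_size=4)  # 3 batches

combined = CombinedDataLoader([dl_a, dl_b], shuffle=False)
# length counts batches per epoch, matching torch DataLoader semantics
assert len(combined) == len(dl_a) + len(dl_b) == 5  # not 20 (samples), not 2 (loaders)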
I am not sure. But then we should at least add a comment about this.
Added the comment
I agree that a sum is the behavior we want (the length of a torch dataset is always the total number of elements you can index) if we decide to hide the fact that this is a combined dataset, i.e. if the user should use it as

for batch in dataset:
    # do stuff

However, my understanding was that the batches would not mix samples from different datasets, is this right? If so, what happens when combining a dataset of size 7 and one of size 11 and trying to use a batch size of 10?
This is a combined dataloader, not a combined dataset. It takes many dataloaders and returns their individual batches (shuffled or not, depending on what the caller wants).
Right, and then the length of the combined dataloader is how many batches it will produce. Makes sense.
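(A hedged sketch of the size-7 / size-11 question from above, with batch size 10. Since each dataloader batches its own dataset, the combined loader yields ceil(7/10) + ceil(11/10) = 1 + 2 = 3 batches and never mixes the two datasets.)

import torch
from torch.utils.data import DataLoader, TensorDataset

dl_small = DataLoader(TensorDataset(torch.arange(7)), batch_size=10)   # 1 batch of 7
dl_large = DataLoader(TensorDataset(torch.arange(11)), batch_size=10)  # batches of 10 and 1

combined = CombinedDataLoader([dl_small, dl_large], shuffle=False)
for (batch,) in combined:
    print(batch.shape)
# torch.Size([7]), torch.Size([10]), torch.Size([1])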
    def __iter__(self):
        return self
Is this working? Shouldn't you return the next dataloader here instead of the full instance?
It does work. It comes from ChatGPT, but intuitively it makes sense: by iterating, you effectively call iterable = iter(dataloader) and then next(iterable) a bunch of times. __next__ is defined in the class, so it makes sense to return self.
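(Spelled out, the iterator protocol in question; the __next__ body below is a sketch that assumes the batches were collected into a list, as in the earlier sketch.)

class CombinedDataLoader:
    # ... __init__ as sketched above ...

    def __iter__(self):
        # `for batch in loader` calls iter(loader) once, then next(loader)
        # until StopIteration; since __next__ is defined on this class,
        # returning self is the standard iterator idiom
        return self

    def __next__(self):
        if self.current_index >= len(self.batches):
            raise StopIteration
        batch = self.batches[self.current_index]
        self.current_index += 1
        return batch

One caveat of returning self: the instance is then a single-pass iterator, so a second for loop over the same object yields nothing unless __iter__ also resets current_index; that is something the requested test could cover.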
Okay, then please add a test for the iterator.
It's there already from the old function + class.
The tests are the same as those that were there before, but I've added an additional case that catches the bug. I'll add the docstring.
Are you able to fix the tests? Once this is done, I'll take another look at the code.
Yes, at some point. It's just the regression tests, which always break (in this case due to the different batches being created by the new dataloader).
42cc4c7 to 626094c
Good, just two minor doc comments.
I still don't understand why the regression tests are this fragile...
"""Creates the combined dataloader. | ||
|
||
:param dataloaders: list of dataloaders to combine | ||
:param shuffle: whether to shuffle the combined dataloader (this does not | ||
act on the individual batches, but it shuffles the order in which | ||
they are returned) | ||
|
||
:return: the combined dataloader | ||
""" |
There should be no docstring for the init. Everything should be in the class docstring.
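(A sketch of that suggestion, reusing the wording from the docstring above.)

class CombinedDataLoader:
    """Combines multiple dataloaders into a single dataloader.

    :param dataloaders: list of dataloaders to combine
    :param shuffle: whether to shuffle the combined dataloader (this does
        not act on the individual batches, but shuffles the order in
        which they are returned)
    """

    def __init__(self, dataloaders, shuffle):
        # no docstring here; the parameters are documented on the class
        ...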
        # this returns the total number of batches in all dataloaders
        # (as opposed to the total number of samples or the number of
        # individual dataloaders)
Maybe this should be visible mostly to users -> move it to the docstring.
Once I am reading the code, I don't need this information, because the lines below already tell me this.
Before this change, the combined dataloader was returning repeated samples within the same epoch.
📚 Documentation preview 📚: https://metatensor-models--163.org.readthedocs.build/en/163/