Add spam dataloader #37

Open
wants to merge 7 commits into
base: main

Prateek0xeo

This PR refactors the spam_dataloader function and adds unit tests to verify that it behaves as expected.

Results:

  • All test cases pass, confirming the expected behavior of the dataloader under various scenarios.
  • The dataloader is compatible with the structure of the Spam dataset and with torch requirements.

Member

@dfalbel dfalbel left a comment

Thanks for the PR @Prateek0xeo !
I added some review comments.

R/spam-dataloader.R (two outdated review threads, resolved)
Author

Prateek0xeo commented Jan 15, 2025

@dfalbel Thank you for the feedback!

  1. As requested, I have removed library(torch) from the PR.

  2. The dataset definition is now separate from the dataloader creation, so users decide how to create the dataloader. Because spam_dataset() returns a torch::dataset object, the user has full control over:

     • Batch size.
     • Whether to shuffle the data.
     • Number of parallel workers.

> ds <- spam_dataset(download = TRUE)
trying URL 'https://hastie.su.domains/ElemStatLearn/datasets/spam.data'
Content type 'unknown' length 698341 bytes (681 KB)
downloaded 681 KB

> loader <- torch::dataloader(
+   dataset = ds,
+   batch_size = 32,
+   shuffle = TRUE,
+   num_workers = 4
+ )
> batch <- dataloader_make_iter(loader) %>% dataloader_next()
> dim(batch$x)
[1] 32 57
> length(batch$y)
[1] 32
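
For readers who want to see the dataset/dataloader split in isolation, here is a minimal sketch. It is not the code from this PR: `toy_spam_dataset`, `features`, and `labels` are hypothetical names, and the data is random rather than the downloaded spam file. It only assumes the torch package is installed and uses 57 feature columns to mirror the batch shape shown above.

```r
# Hypothetical stand-in for the spam data: 100 rows, 57 features, binary label.
set.seed(1)
features <- matrix(rnorm(100 * 57), nrow = 100, ncol = 57)
labels <- sample(0:1, 100, replace = TRUE)

# A dataset only describes how to fetch one item and how many items exist.
toy_spam_dataset <- torch::dataset(
  name = "toy_spam_dataset",
  initialize = function(x, y) {
    self$x <- torch::torch_tensor(x, dtype = torch::torch_float())
    self$y <- torch::torch_tensor(y, dtype = torch::torch_long())
  },
  .getitem = function(i) {
    list(x = self$x[i, ], y = self$y[i])
  },
  .length = function() {
    self$x$size(1)
  }
)

ds <- toy_spam_dataset(features, labels)

# Batching and shuffling are chosen by the caller, not by the dataset.
loader <- torch::dataloader(ds, batch_size = 32, shuffle = TRUE)
batch <- torch::dataloader_next(torch::dataloader_make_iter(loader))
dim(batch$x)
```

Because the dataset knows nothing about batching, the same `ds` object can be wrapped in several dataloaders with different settings, for example a shuffled one for training and an unshuffled one for evaluation.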

If additional modifications are needed, please let me know!

Member

@dfalbel dfalbel left a comment

Hi @Prateek0xeo

Thanks for updating the PR! I added a couple more comments.
Can you also add the dataset to the Readme table?

@@ -0,0 +1,16 @@
if (requireNamespace("testthat", quietly = TRUE)) {
Member

You don't need this require statement here. devtools::test() will make sure testthat is loaded.


test_that("spam_dataloader works as expected", {

loader <- spam_dataloader(download = TRUE)
Member

You probably need to update the test cases, as spam_dataloader is now called spam_dataset.

Comment on lines 59 to 60
x = torch_tensor(x, dtype = torch_float()),
y = torch_tensor(y, dtype = torch_long())
Member

This is likely to work in most scenarios because torch would already be loaded, but in any case we should namespace the calls with the torch:: prefix:

Suggested change
- x = torch_tensor(x, dtype = torch_float()),
- y = torch_tensor(y, dtype = torch_long())
+ x = torch::torch_tensor(x, dtype = torch_float()),
+ y = torch::torch_tensor(y, dtype = torch_long())
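
As a side note on the suggestion above, the dtype helpers `torch_float()` and `torch_long()` are still unprefixed there, so they would only resolve if torch is attached or imported by the package. A minimal sketch of the fully qualified form, runnable without `library(torch)`, might look like the following. The small matrix `x`, the vector `y`, and the name `item` are made up for illustration and are not taken from the PR.

```r
# Illustrative inputs only: 2 rows, 3 feature columns, integer labels.
x <- matrix(runif(6), nrow = 2, ncol = 3)
y <- c(0L, 1L)

# Every torch call, including the dtype helpers, carries the torch:: prefix,
# so nothing here depends on the package being attached to the search path.
item <- list(
  x = torch::torch_tensor(x, dtype = torch::torch_float()),
  y = torch::torch_tensor(y, dtype = torch::torch_long())
)

dim(item$x)
```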

Author

@dfalbel I have updated the PR in accordance with the latest requests.
