Skip to content

Commit

Permalink
Attempt to handle data sources that require the `dataset_name` constructor argument
Browse files (browse the repository at this point in the history)
  • Loading branch information
jpgard committed Jan 20, 2024
1 parent 7dc2115 commit 435f35e
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions tableshift/core/tabular_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,9 +13,9 @@
import ray.data
import torch
from pandas import DataFrame, Series
from tableshift.third_party.domainbed import InfiniteDataLoader
from torch.utils.data import DataLoader

from tableshift.third_party.domainbed import InfiniteDataLoader
from .features import Preprocessor, PreprocessorConfig, is_categorical
from .grouper import Grouper
from .metrics import metrics_by_group
Expand Down Expand Up @@ -201,10 +201,17 @@ def __init__(self, name: str, config: DatasetConfig,
# Dataset-specific info: features, data source, preprocessing.

self.task_config = get_task_config(self.name) if task_config is None else task_config
self.data_source = self.task_config.data_source_cls(
cache_dir=self.config.cache_dir,
download=self.config.download,
**kwargs)
try:
self.data_source = self.task_config.data_source_cls(
cache_dir=self.config.cache_dir,
download=self.config.download,
**kwargs)
except TypeError:
kwargs.update({"dataset_name": self.name})
self.data_source = self.task_config.data_source_cls(
cache_dir=self.config.cache_dir,
download=self.config.download,
**kwargs)

self.preprocessor = Preprocessor(
config=self.preprocessor_config,
Expand Down

0 comments on commit 435f35e

Please sign in to comment.