-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
131 lines (106 loc) · 4.87 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""The main script for training the model. Take in the raw data and output a trained model."""
import os
import warnings
from contextlib import nullcontext
from pathlib import Path
import hydra
import wandb
from epochalyst.logging.section_separator import print_section_separator
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from omegaconf import DictConfig
from src.config.train_config import TrainConfig
from src.setup.setup_data import setup_train_x_data, setup_train_y_data
from src.setup.setup_pipeline import setup_pipeline
from src.setup.setup_runtime_args import setup_train_args
from src.setup.setup_wandb import setup_wandb
from src.typing.typing import XData
from src.utils.lock import Lock
from src.utils.logger import logger
from src.utils.set_torch_seed import set_torch_seed
# Silence noisy library UserWarnings for the duration of the run.
warnings.filterwarnings("ignore", category=UserWarning)
# Makes hydra give full error messages
os.environ["HYDRA_FULL_ERROR"] = "1"
# Required by cuBLAS for deterministic behavior when torch deterministic algorithms are enabled.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Set up the config store, necessary for type checking of config yaml
cs = ConfigStore.instance()
cs.store(name="base_train", node=TrainConfig)
@hydra.main(version_base=None, config_path="conf", config_name="train")
def run_train(cfg: DictConfig) -> None:
    """Train a model pipeline with a train-test split. Entry point for Hydra which loads the config file.

    :param cfg: The config object. Created with Hydra.
    """
    # Serialize concurrent runs with a lock unless the config explicitly allows
    # multiple simultaneous instances, in which case a no-op context is used.
    guard = nullcontext if cfg.allow_multiple_instances else Lock
    with guard():
        run_train_cfg(cfg)
def run_train_cfg(cfg: DictConfig) -> None:
    """Train a model pipeline with a train-test split.

    :param cfg: The config object. Created with Hydra.
    :raise ValueError: If test size is 0 and n_splits is not 0.
    """
    print_section_separator("Q4 - BirdCLEF - Training")
    import coloredlogs
    coloredlogs.install()

    # Set seed for reproducibility across torch ops.
    set_torch_seed()

    # Get the Hydra-managed output directory for this run.
    output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)

    if cfg.wandb.enabled:
        setup_wandb(cfg, "train", output_dir)

    # Preload the pipeline
    print_section_separator("Setup pipeline")
    model_pipeline = setup_pipeline(cfg)

    # Cache arguments for x_sys
    processed_data_path = Path(cfg.processed_path)
    processed_data_path.mkdir(parents=True, exist_ok=True)
    cache_args = {
        "output_data_type": "numpy_array",
        "storage_type": ".pkl",
        "storage_path": f"{processed_data_path}",
    }

    # Read the data if required and split it in X, y
    x_cache_exists = model_pipeline.get_x_cache_exists(cache_args)
    # y_cache_exists = model_pipeline.get_y_cache_exists(cache_args)
    X: XData | None = None
    if not x_cache_exists:
        # If no cache exists, we need to load the raw X data.
        X = setup_train_x_data(cfg.raw_path, cfg.years, cfg.max_recordings_per_species)

    # y is always loaded: the splitter below only needs y.
    y = setup_train_y_data(cfg.raw_path, cfg.years, cfg.max_recordings_per_species)

    if cfg.test_size == 0:
        if cfg.splitter.n_splits != 0:
            raise ValueError("Test size is 0, but n_splits is not 0. Also please set n_splits to 0 if you want to run train full.")
        logger.info("Training full.")
        # Train on every recording of every configured year; no held-out test set.
        train_indices: dict[str, list[int]] = {year: list(range(len(X[f"bird_{year}"]))) for year in cfg.years}  # type: ignore[index]
        test_indices: dict[str, list[int] | dict[str, list[int]]] = {year: [] for year in cfg.years}
        fold = -1
    else:
        logger.info("Using splitter to split data into train and test sets.")
        # Take only the first fold produced by the splitter.
        train_indices, test_indices = next(instantiate(cfg.splitter).split(y))
        fold = 0

    logger.info(f"Train/Test size: {[len(year_indices) for year_indices in train_indices.values()]}/{[len(year_indices) for year_indices in test_indices.values()]}")

    print_section_separator("Train model pipeline")
    train_args = setup_train_args(
        pipeline=model_pipeline,
        cache_args=cache_args,
        train_indices=train_indices,
        test_indices=test_indices,
        save_model=True,
        fold=fold,
        save_model_preds=True,
    )
    predictions, y_new = model_pipeline.train(X, y, **train_args)
    if not y:
        # Fall back to the labels the pipeline produced when y was not loaded up-front.
        y = y_new

    # Score only when at least one year has held-out test indices.
    if sum(len(test_indices[year]) for year in test_indices) > 0:
        print_section_separator("Scoring")
        scorer = instantiate(cfg.scorer)
        score = scorer(y, predictions, test_indices=test_indices, years=cfg.years, output_dir=output_dir)
        logger.info(f"Score: {score}")
        if wandb.run:
            # A dict score holds one entry per year; log each, plus the 2024 score
            # under the plain "Score" key. A scalar score is logged directly.
            if isinstance(score, dict):
                for year in score:
                    wandb.log({f"Score_{year}_0": score[year]})
                if "2024" in score:
                    wandb.log({"Score": score["2024"]})
            else:
                wandb.log({"Score": score})

    # NOTE(review): finish() is a no-op when no run is active, so it is safe unconditionally.
    wandb.finish()
# Script entry point: hand control to the Hydra-decorated trainer.
if __name__ == "__main__":
    run_train()