Skip to content

Commit

Permalink
Fixed rollout_*_batcher.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kanz76 committed Feb 10, 2025
1 parent 58c2c13 commit db62743
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 91 deletions.
10 changes: 2 additions & 8 deletions applications/rollout_ens_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,8 @@
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer
from credit.parser import credit_main_parser, predict_data_check
from credit.datasets.era5_predict_batcher import (
BatchForecastLenDataLoader,
Predict_Dataset_Batcher
)

from credit.datasets.era5_multistep_batcher import Predict_Dataset_Batcher
from credit.datasets.load_dataset_and_dataloader import BatchForecastLenDataLoader

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -365,9 +362,6 @@ def predict(rank, world_size, conf, backend=None, p=None):
y_pred = None
gc.collect()

if distributed:
torch.distributed.barrier()

forecast_count += batch_size

if distributed:
Expand Down
9 changes: 3 additions & 6 deletions applications/rollout_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer
from credit.parser import credit_main_parser, predict_data_check
from credit.datasets.era5_predict_batcher import (
BatchForecastLenDataLoader,
Predict_Dataset_Batcher,
)
from credit.datasets.era5_multistep_batcher import Predict_Dataset_Batcher
from credit.datasets.load_dataset_and_dataloader import BatchForecastLenDataLoader
from credit.ensemble.bred_vector import generate_bred_vectors
from credit.ensemble.crps import calculate_crps_per_channel

Expand Down Expand Up @@ -549,8 +547,7 @@ def predict(rank, world_size, conf, backend=None, p=None):
y_pred = None
gc.collect()

if distributed:
torch.distributed.barrier()


forecast_count += batch_size

Expand Down
62 changes: 28 additions & 34 deletions applications/rollout_metrics_batcher.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,44 @@
# ---------- #
# System
import os
import gc
import sys
import yaml
import logging
import warnings
import multiprocessing as mp
from pathlib import Path
import os
import sys
import warnings
from argparse import ArgumentParser
from collections import defaultdict

# ---------- #
# Numerics
from datetime import datetime, timedelta
import pandas as pd
import xarray as xr
from pathlib import Path

import numpy as np
import pandas as pd

# ---------- #
import torch

import xarray as xr
import yaml
# ---------- #
# credit
from credit.models import load_model
from credit.seed import seed_everything
from credit.data import concat_and_reshape, reshape_only
from credit.datasets import setup_data_loading
from credit.transforms import load_transforms, Normalize_ERA5_and_Forcing
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.metrics import LatWeightedMetrics, LatWeightedMetricsClimatology
from credit.datasets.era5_multistep_batcher import Predict_Dataset_Batcher
from credit.datasets.load_dataset_and_dataloader import BatchForecastLenDataLoader
from credit.distributed import distributed_model_wrapper, get_rank_info, setup
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper, setup, get_rank_info
from credit.metrics import LatWeightedMetrics, LatWeightedMetricsClimatology

from credit.models import load_model
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer
from credit.parser import credit_main_parser, predict_data_check
from credit.datasets.era5_predict_batcher import (
BatchForecastLenDataLoader,
Predict_Dataset_Batcher,
)

from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.postblock import GlobalEnergyFixer, GlobalMassFixer, GlobalWaterFixer
from credit.seed import seed_everything
from credit.transforms import Normalize_ERA5_and_Forcing, load_transforms

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -146,6 +144,10 @@ def predict(rank, world_size, conf, backend=None, p=None):

# Load the forecasts we wish to compute
forecasts = load_forecasts(conf)
if len(forecasts) % world_size != 0:
raise ValueError(
f'Number of forecast inits ({len(forecasts)}) given by conf["predict"]["duration"] x len(conf["predict"]["start_hours"]) should be divisible by number of processes/GPUs ({world_size})'
)

dataset = Predict_Dataset_Batcher(
varname_upper_air=data_config["varname_upper_air"],
Expand Down Expand Up @@ -229,8 +231,6 @@ def predict(rank, world_size, conf, backend=None, p=None):

# y_pred allocation and results tracking
results = []
save_datetimes = [0] * batch_size

# model inference loop
for k, batch in enumerate(data_loader):
batch_size = batch["datetime"].shape[0]
Expand All @@ -248,9 +248,6 @@ def predict(rank, world_size, conf, backend=None, p=None):
)
for i in range(batch_size)
]
save_datetimes[forecast_count : forecast_count + batch_size] = (
init_datetimes
)
if "x_surf" in batch:
x = (
concat_and_reshape(batch["x"], batch["x_surf"])
Expand Down Expand Up @@ -353,7 +350,7 @@ def predict(rank, world_size, conf, backend=None, p=None):
results.append((j, result)) # Store the batch index with the result

# Print to screen
print_str = f"Forecast: {forecast_count + 1 + j} "
print_str = f"{rank=:} Forecast: {forecast_count + 1 + j} "
print_str += f"Date: {utc_datetime[j].strftime('%Y-%m-%d %H:%M:%S')} "
print_str += f"Hour: {forecast_step * lead_time_periods} "
print(print_str)
Expand Down Expand Up @@ -401,13 +398,10 @@ def predict(rank, world_size, conf, backend=None, p=None):
y_pred = None
gc.collect()

if distributed:
torch.distributed.barrier()

forecast_count += batch_size

if distributed:
torch.distributed.barrier()
torch.distributed.destroy_process_group()

return 1

Expand Down Expand Up @@ -517,9 +511,9 @@ def predict(rank, world_size, conf, backend=None, p=None):
data_config = setup_data_loading(conf)

# create a save location for rollout
assert (
"save_forecast" in conf["predict"]
), "Please specify the output dir through conf['predict']['save_forecast']"
assert "save_forecast" in conf["predict"], (
"Please specify the output dir through conf['predict']['save_forecast']"
)

forecast_save_loc = conf["predict"]["save_forecast"]
os.makedirs(forecast_save_loc, exist_ok=True)
Expand Down
79 changes: 36 additions & 43 deletions applications/rollout_to_netcdf_batcher.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,43 @@
import os
import gc
import sys
import yaml
import logging
import multiprocessing as mp
import os
import sys
import warnings
from pathlib import Path
from argparse import ArgumentParser
import multiprocessing as mp

# ---------- #
# Numerics
from datetime import datetime, timedelta
import xarray as xr
from pathlib import Path

import numpy as np

# ---------- #
import torch

import xarray as xr
import yaml
# ---------- #
# credit
from credit.models import load_model
from credit.seed import seed_everything
from credit.distributed import get_rank_info
from credit.datasets import setup_data_loading
from credit.datasets.era5_predict_batcher import (
BatchForecastLenDataLoader,
Predict_Dataset_Batcher,
)

from credit.data import (
concat_and_reshape,
reshape_only,
)

from credit.transforms import load_transforms, Normalize_ERA5_and_Forcing
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.datasets import setup_data_loading
from credit.datasets.era5_multistep_batcher import Predict_Dataset_Batcher
from credit.datasets.load_dataset_and_dataloader import BatchForecastLenDataLoader
from credit.distributed import distributed_model_wrapper, get_rank_info, setup
from credit.forecast import load_forecasts
from credit.distributed import distributed_model_wrapper, setup

from credit.models import load_model
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
from credit.parser import credit_main_parser, predict_data_check
from credit.output import load_metadata, make_xarray, save_netcdf_increment
from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer
from credit.parser import credit_main_parser, predict_data_check
from credit.pbs import launch_script, launch_script_mpi
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
from credit.postblock import GlobalEnergyFixer, GlobalMassFixer, GlobalWaterFixer
from credit.seed import seed_everything
from credit.transforms import Normalize_ERA5_and_Forcing, load_transforms

logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -134,9 +130,9 @@ def predict(rank, world_size, conf, p):

# Load the forecasts we wish to compute
forecasts = load_forecasts(conf)
if len(forecasts) < batch_size:
logger.warning(
f"number of forecast init times {len(forecasts)} is less than batch_size {batch_size}, will result in under-utilization"
if len(forecasts) % world_size != 0:
raise ValueError(
f'Number of forecast inits ({len(forecasts)}) given by conf["predict"]["duration"] x len(conf["predict"]["start_hours"]) should be divisible by number of processes/GPUs ({world_size})'
)

dataset = Predict_Dataset_Batcher(
Expand Down Expand Up @@ -214,7 +210,8 @@ def predict(rank, world_size, conf, p):

# y_pred allocation and results tracking
results = []
save_datetimes = [0] * len(forecasts)
# save_datetimes = [0] * len(forecasts)
init_datetimes = []

# model inference loop
for batch in data_loader:
Expand All @@ -230,9 +227,8 @@ def predict(rank, world_size, conf, p):
)
for i in range(batch_size)
]
save_datetimes[forecast_count : forecast_count + batch_size] = (
init_datetimes
)
# save_datetimes[forecast_count:forecast_count + batch_size] = init_datetimes
# save_datetimes

if "x_surf" in batch:
x = (
Expand Down Expand Up @@ -340,17 +336,15 @@ def predict(rank, world_size, conf, p):
(
all_upper_air,
all_single_level,
save_datetimes[
forecast_count + j
], # Use correct index for current batch item
init_datetimes[j],
lead_time_periods * forecast_step,
meta_data,
conf,
),
)
results.append(result)

print_str = f"Forecast: {forecast_count + 1 + j} "
print_str = f"{rank=:} Forecast: {forecast_count + 1 + j} "
print_str += f"Date: {utc_datetimes[j].strftime('%Y-%m-%d %H:%M:%S')} "
print_str += f"Hour: {forecast_step * lead_time_periods} "
print(print_str)
Expand All @@ -360,14 +354,16 @@ def predict(rank, world_size, conf, p):

# y_diag is not drawn in predict batcher, if diag is specified in config, it will not be in the input to the model
if history_len == 1:
x = y_pred[:, :-varnum_diag, ...].detach()
# x = y_pred[:, :-varnum_diag, ...].detach()
x = y_pred.detach()
else:
if static_dim_size == 0:
x_detach = x[:, :, 1:, ...].detach()
else:
x_detach = x[:, :-static_dim_size, 1:, ...].detach()

x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2)
# x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2)
x = torch.cat([x_detach, y_pred.detach()], dim=2)

# Memory cleanup
torch.cuda.empty_cache()
Expand All @@ -381,13 +377,10 @@ def predict(rank, world_size, conf, p):
y_pred = None
gc.collect()

if distributed:
torch.distributed.barrier()

forecast_count += batch_size

if distributed:
torch.distributed.barrier()
torch.distributed.destroy_process_group()

return 1

Expand Down Expand Up @@ -491,9 +484,9 @@ def predict(rank, world_size, conf, p):
predict_data_check(conf, print_summary=False)

# create a save location for rollout
assert (
"save_forecast" in conf["predict"]
), "Please specify the output dir through conf['predict']['save_forecast']"
assert "save_forecast" in conf["predict"], (
"Please specify the output dir through conf['predict']['save_forecast']"
)

forecast_save_loc = conf["predict"]["save_forecast"]
os.makedirs(forecast_save_loc, exist_ok=True)
Expand Down

0 comments on commit db62743

Please sign in to comment.