Skip to content

Commit

Permalink
Log additional per-GPU information in model metadata files and GPU utilization on tensorboard. (#3712)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinxzhao authored Oct 11, 2023
1 parent 7c3a549 commit 6d74d21
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 10 deletions.
36 changes: 26 additions & 10 deletions ludwig/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2141,6 +2141,31 @@ def kfold_cross_validate(
return kfold_cv_stats, kfold_split_indices


def _get_compute_description(backend) -> Dict:
"""Returns the compute description for the backend."""
compute_description = {"num_nodes": backend.num_nodes}

if torch.cuda.is_available():
# Assumption: All nodes are of the same instance type.
# TODO: fix for Ray where workers may be of different skus
compute_description.update(
{
"gpus_per_node": torch.cuda.device_count(),
"arch_list": torch.cuda.get_arch_list(),
"gencode_flags": torch.cuda.get_gencode_flags(),
"devices": {},
}
)
for i in range(torch.cuda.device_count()):
compute_description["devices"][i] = {
"gpu_type": torch.cuda.get_device_name(i),
"device_capability": torch.cuda.get_device_capability(i),
"device_properties": str(torch.cuda.get_device_properties(i)),
}

return compute_description


@PublicAPI
def get_experiment_description(
config,
Expand Down Expand Up @@ -2184,15 +2209,6 @@ def get_experiment_description(

description["config"] = config
description["torch_version"] = torch.__version__

gpu_info = {}
if torch.cuda.is_available():
# Assumption: All nodes are of the same instance type.
# TODO: fix for Ray where workers may be of different skus
gpu_info = {"gpu_type": torch.cuda.get_device_name(0), "gpus_per_node": torch.cuda.device_count()}

compute_description = {"num_nodes": backend.num_nodes, **gpu_info}

description["compute"] = compute_description
description["compute"] = _get_compute_description(backend)

return description
8 changes: 8 additions & 0 deletions ludwig/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,14 @@ def write_step_summary(cls, train_summary_writer, combined_loss, all_losses, ste
/ (1000**3),
global_step=step,
)

# Utilization.
# https://pytorch.org/docs/stable/generated/torch.cuda.utilization.html#torch.cuda.utilization
train_summary_writer.add_scalar(
f"cuda/device{i}/utilization",
torch.cuda.device(i).utilization(),
global_step=step,
)
train_summary_writer.flush()

def is_cpu_training(self):
Expand Down

0 comments on commit 6d74d21

Please sign in to comment.