Reduce query indices
pomonam committed Mar 12, 2024
1 parent 3655ea7 commit 7c59820
Showing 7 changed files with 71 additions and 150 deletions.
3 changes: 1 addition & 2 deletions examples/uci/README.md
@@ -4,7 +4,7 @@ This directory contains scripts designed for training a regression model and con

## Training

To initiate the training of a regression model using the Concrete dataset, execute the following command:
To train a regression model on the Concrete dataset, run the following command:
```bash
python train.py --dataset_name concrete \
--dataset_dir ./data \
@@ -16,7 +16,6 @@ python train.py --dataset_name concrete \
--num_train_epochs 20 \
--seed 1004
```
Alternatively, you can download the model checkpoint.

# Influence Analysis

87 changes: 9 additions & 78 deletions examples/uci/analyze.py
@@ -2,13 +2,12 @@
import logging
import math
import os
from typing import Dict, Tuple
from typing import Tuple

import torch
import torch.nn.functional as F
from analyzer import Analyzer, prepare_model
from arguments import FactorArguments, ScoreArguments
from module.utils import wrap_tracked_modules
from task import Task
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function
@@ -96,14 +95,12 @@ def compute_measurement(

def main():
args = parse_args()

logging.basicConfig(level=logging.INFO)

train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir)
train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", dataset_dir=args.dataset_dir)
eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", dataset_dir=args.dataset_dir)

model = construct_regression_mlp()

checkpoint_path = os.path.join(args.checkpoint_dir, "model.pth")
if not os.path.isfile(checkpoint_path):
raise ValueError(f"No checkpoint found at {checkpoint_path}.")
@@ -120,91 +117,25 @@ def main():
)
factor_args = FactorArguments(
strategy=args.factor_strategy,
covariance_data_partition_size=5,
covariance_module_partition_size=4,
)
# with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
# with record_function("covariance"):
# analyzer.fit_covariance_matrices(
# factors_name=args.factor_strategy,
# dataset=train_dataset,
# factor_args=factor_args,
# per_device_batch_size=args.batch_size,
# overwrite_output_dir=True,
# )
#
# print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
# cov_factors = analyzer.fit_covariance_matrices(
# factors_name=args.factor_strategy,
# dataset=train_dataset,
# factor_args=factor_args,
# per_device_batch_size=args.batch_size,
# overwrite_output_dir=True,
# )
# print(cov_factors)

with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
with record_function("eigen"):
res = analyzer.perform_eigendecomposition(
factors_name=args.factor_strategy,
factor_args=factor_args,
overwrite_output_dir=True,
)
# print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
# print(res)
res = analyzer.fit_lambda_matrices(
analyzer.fit_all_factors(
factors_name=args.factor_strategy,
dataset=train_dataset,
# factor_args=factor_args,
per_device_batch_size=None,
factor_args=factor_args,
overwrite_output_dir=True,
)
# print(res)
#
score_args = ScoreArguments(data_partition_size=2, module_partition_size=2)
analyzer.compute_pairwise_scores(
scores_name="hello",

scores = analyzer.compute_pairwise_scores(
scores_name="pairwise",
factors_name=args.factor_strategy,
query_dataset=eval_dataset,
train_dataset=train_dataset,
per_device_query_batch_size=16,
per_device_train_batch_size=8,
score_args=score_args,
overwrite_output_dir=True,
)
# scores = analyzer.load_pairwise_scores(scores_name="hello")
# print(scores)
#
# analyzer.compute_self_scores(
# scores_name="hello",
# factors_name=args.factor_strategy,
# # query_dataset=eval_dataset,
# train_dataset=train_dataset,
# # per_device_query_batch_size=16,
# per_device_train_batch_size=8,
# overwrite_output_dir=True,
# )
# # scores = analyzer.load_self_scores(scores_name="hello")
# # print(scores)

# analyzer.fit_all_factors(
# factor_name=args.factor_strategy,
# dataset=train_dataset,
# factor_args=factor_args,
# per_device_batch_size=None,
# overwrite_output_dir=True,
# )
#
# score_name = "full_pairwise"
# analyzer.compute_pairwise_scores(
# score_name=score_name,
# query_dataset=eval_dataset,
# per_device_query_batch_size=len(eval_dataset),
# train_dataset=train_dataset,
# per_device_train_batch_size=len(train_dataset),
# )
# scores = analyzer.load_pairwise_scores(score_name=score_name)
# print(scores.shape)
logging.info(f"Scores: {scores}")


if __name__ == "__main__":
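The rewritten `analyze.py` logs the scores returned by `compute_pairwise_scores` directly. The removed commented-out code also shows that saved scores can be reloaded later with `load_pairwise_scores`; a minimal sketch of that pattern, assuming an `analyzer` constructed as in the script above and the `scores_name="pairwise"` used there:

```python
# Sketch: reload the pairwise scores saved by the updated analyze.py instead of
# recomputing them. `analyzer` is assumed to be constructed exactly as in the
# script above; load_pairwise_scores mirrors the commented-out call removed here.
import logging


def reload_scores(analyzer) -> None:
    scores = analyzer.load_pairwise_scores(scores_name="pairwise")
    # Pairwise scores relate each query (evaluation) example to each training
    # example; the exact container layout depends on the analyzer configuration.
    logging.info(f"Scores: {scores}")
```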
4 changes: 2 additions & 2 deletions examples/uci/pipeline.py
@@ -38,12 +38,12 @@ def get_regression_dataset(
data_name: str,
split: str,
indices: List[int] = None,
data_path: str = "data/",
dataset_dir: str = "data/",
) -> Dataset:
assert split in ["train", "eval_train", "valid"]

# Load the dataset from the `.data` file.
data = np.loadtxt(os.path.join(data_path, data_name + ".data"), delimiter=None)
data = np.loadtxt(os.path.join(dataset_dir, data_name + ".data"), delimiter=None)
data = data.astype(np.float32)

# Shuffle the dataset.
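`get_regression_dataset` reads each UCI dataset from a `<name>.data` file under the renamed `dataset_dir` argument, using `np.loadtxt` with `delimiter=None`, which splits rows on any whitespace. A small self-contained sketch of that parsing behavior; the file contents are invented, not taken from the real Concrete dataset:

```python
# Sketch: what np.loadtxt(..., delimiter=None) does with a whitespace-delimited
# .data file like the ones get_regression_dataset reads. Values are invented.
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as dataset_dir:
    path = os.path.join(dataset_dir, "concrete.data")
    with open(path, "w") as f:
        f.write("540.0 0.0 162.0 79.99\n")
        f.write("332.5 142.5 228.0 40.27\n")

    data = np.loadtxt(path, delimiter=None).astype(np.float32)
    print(data.shape)  # (2, 4): two rows, whitespace-separated columns
```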
101 changes: 42 additions & 59 deletions examples/uci/train.py
@@ -1,11 +1,12 @@
import argparse
import logging
import os
from torch.utils import data

import torch
import torch.nn.functional as F
from torch import nn
from accelerate.utils import set_seed
from torch import nn
from torch.utils import data
from tqdm import tqdm

from examples.uci.pipeline import construct_regression_mlp, get_regression_dataset
@@ -82,7 +83,13 @@ def parse_args():
return args


def train(dataset: data.Dataset, batch_size: int, num_train_epochs: int, learning_rate: float, weight_decay: float) -> nn.Module:
def train(
dataset: data.Dataset,
batch_size: int,
num_train_epochs: int,
learning_rate: float,
weight_decay: float,
) -> nn.Module:
train_dataloader = data.DataLoader(
dataset=dataset,
batch_size=batch_size,
@@ -110,6 +117,25 @@ def train(dataset: data.Dataset, batch_size: int, num_train_epochs: int, learnin
return model


def evaluate(model: nn.Module, dataset: data.Dataset, batch_size: int) -> float:
dataloader = data.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
)

model.eval()
total_loss = 0
for batch in dataloader:
with torch.no_grad():
inputs, targets = batch
outputs = model(inputs)
loss = F.mse_loss(outputs, targets, reduction="sum")
total_loss += loss.detach().float()

return total_loss.item() / len(dataloader.dataset)


def main():
args = parse_args()
@@ -120,68 +146,25 @@ def main():
if args.seed is not None:
set_seed(args.seed)

train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
train_dataloader = data.DataLoader(
train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", dataset_dir=args.dataset_dir)

model = train(
dataset=train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
drop_last=True,
num_train_epochs=args.num_train_epochs,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
)
model = construct_regression_mlp()
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

logger.info("Start training the model.")
model.train()
for epoch in range(args.num_train_epochs):
total_loss = 0
with tqdm(train_dataloader, unit="batch") as tepoch:
for batch in tepoch:
tepoch.set_description(f"Epoch {epoch}")
inputs, targets = batch
outputs = model(inputs)
loss = F.mse_loss(outputs, targets)
total_loss += loss.detach().float()
loss.backward()
optimizer.step()
optimizer.zero_grad()
tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader))

logger.info("Start evaluating the model.")
model.eval()
train_eval_dataset = get_regression_dataset(
data_name=args.dataset_name, split="eval_train", data_path=args.dataset_dir
eval_train_dataset = get_regression_dataset(
data_name=args.dataset_name, split="eval_train", dataset_dir=args.dataset_dir
)
train_eval_dataloader = DataLoader(
dataset=train_eval_dataset,
batch_size=args.eval_batch_size,
shuffle=False,
drop_last=False,
)
eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir)
eval_dataloader = DataLoader(
dataset=eval_dataset,
batch_size=args.eval_batch_size,
shuffle=False,
drop_last=False,
)

total_loss = 0
for batch in train_eval_dataloader:
with torch.no_grad():
inputs, targets = batch
outputs = model(inputs)
loss = F.mse_loss(outputs, targets, reduction="sum")
total_loss += loss.detach().float()
logger.info(f"Train loss {total_loss.item() / len(train_eval_dataloader.dataset)}")
train_loss = evaluate(model=model, dataset=eval_train_dataset, batch_size=args.eval_batch_size)
logger.info(f"Train loss: {train_loss}")

total_loss = 0
for batch in eval_dataloader:
with torch.no_grad():
inputs, targets = batch
outputs = model(inputs)
loss = F.mse_loss(outputs, targets, reduction="sum")
total_loss += loss.detach().float()
logger.info(f"Evaluation loss {total_loss.item() / len(eval_dataloader.dataset)}")
eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", dataset_dir=args.dataset_dir)
eval_loss = evaluate(model=model, dataset=eval_dataset, batch_size=args.eval_batch_size)
logger.info(f"Evaluation loss: {eval_loss}")

if args.checkpoint_dir is not None:
torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, "model.pth"))
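The new `evaluate` helper accumulates losses with `reduction="sum"` and divides by the dataset size, so the reported value is the per-example mean loss even when the final batch is smaller than the others. A toy sketch (numbers are made up) showing that this matches the mean computed over the whole dataset at once:

```python
# Sketch illustrating why evaluate() sums per-example losses and divides by the
# dataset size: with uneven batch sizes, averaging per-batch means would weight
# the smaller final batch too heavily. Toy numbers, unrelated to the UCI data.
import torch
import torch.nn.functional as F

outputs = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
targets = torch.tensor([[1.5], [2.0], [2.0], [4.0], [0.0]])

# Per-example mean over the whole dataset at once.
full_mean = F.mse_loss(outputs, targets).item()

# Accumulate with reduction="sum" over batches of size 2 (last batch has 1 example).
total = 0.0
for start in range(0, len(outputs), 2):
    total += F.mse_loss(
        outputs[start:start + 2], targets[start:start + 2], reduction="sum"
    ).item()

print(total / len(outputs), full_mean)  # identical values
```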
13 changes: 12 additions & 1 deletion kronfluence/analyzer.py
@@ -1,6 +1,9 @@
from typing import Optional

from accelerate.utils import extract_model_from_parallel
from kronfluence.module.constants import FACTOR_TYPE

from factor.config import FactorConfig
from safetensors.torch import save_file
from torch import nn
from torch.utils import data
@@ -119,7 +122,7 @@ def fit_all_factors(
dataloader_kwargs: Optional[DataLoaderKwargs] = None,
factor_args: Optional[FactorArguments] = None,
overwrite_output_dir: bool = False,
) -> None:
) -> Optional[FACTOR_TYPE]:
"""Computes all necessary factors for the given factor strategy. As an example, EK-FAC
requires (1) computing covariance matrices, (2) performing Eigendecomposition, and
(3) computing Lambda (corrected-eigenvalues) matrices.
@@ -161,3 +164,11 @@ def fit_all_factors(
factor_args=factor_args,
overwrite_output_dir=overwrite_output_dir,
)

if factor_args is None:
factor_args = FactorArguments()
strategy = factor_args.strategy
factor_config = FactorConfig.CONFIGS[strategy]
return self._load_all_required_factors(
factors_name=factors_name, strategy=strategy, factor_config=factor_config
)
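As the docstring notes, fitting all factors for EK-FAC means computing covariance matrices, performing an eigendecomposition, and computing Lambda (corrected-eigenvalue) matrices; `fit_all_factors` wraps these stages and, after this commit, also returns the loaded factors. A sketch of the equivalent explicit calls, using the method names visible in the removed `analyze.py` code; the `"ekfac"` strategy string and the surrounding setup are assumptions:

```python
# Sketch: the three EK-FAC stages that fit_all_factors wraps, written as explicit
# calls. Method names and keyword arguments follow the removed analyze.py code;
# the analyzer, dataset, and strategy string are assumed to be set up as there.
from arguments import FactorArguments


def fit_ekfac_factors(analyzer, train_dataset, factors_name: str, batch_size=None):
    factor_args = FactorArguments(strategy="ekfac")

    # (1) Fit activation/gradient covariance matrices over the training data.
    analyzer.fit_covariance_matrices(
        factors_name=factors_name,
        dataset=train_dataset,
        factor_args=factor_args,
        per_device_batch_size=batch_size,
        overwrite_output_dir=True,
    )
    # (2) Eigendecompose the fitted covariance matrices.
    analyzer.perform_eigendecomposition(
        factors_name=factors_name,
        factor_args=factor_args,
        overwrite_output_dir=True,
    )
    # (3) Fit the Lambda (corrected-eigenvalue) matrices.
    analyzer.fit_lambda_matrices(
        factors_name=factors_name,
        dataset=train_dataset,
        factor_args=factor_args,
        per_device_batch_size=batch_size,
        overwrite_output_dir=True,
    )
```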
2 changes: 1 addition & 1 deletion tests/gpu_tests/cpu_test.py
@@ -16,7 +16,7 @@
construct_mnist_mlp,
get_mnist_dataset,
)
from tests.gpu_tests.prepare_tests import TRAIN_INDICES, QUERY_INDICES
from tests.gpu_tests.prepare_tests import QUERY_INDICES, TRAIN_INDICES
from tests.utils import check_tensor_dict_equivalence

logging.basicConfig(level=logging.DEBUG)
11 changes: 4 additions & 7 deletions tests/gpu_tests/prepare_tests.py
@@ -11,10 +11,9 @@
get_mnist_dataset,
)


# Pick difficult cases where the dataset is not perfectly divisible by batch size.
TRAIN_INDICES = 59_999
QUERY_INDICES = 50
TRAIN_INDICES = 5_003
QUERY_INDICES = 51


def train() -> None:
@@ -82,6 +81,8 @@ def run_analysis() -> None:

train_dataset = get_mnist_dataset(split="train", data_path="data")
eval_dataset = get_mnist_dataset(split="valid", data_path="data")
train_dataset = Subset(train_dataset, indices=list(range(TRAIN_INDICES)))
eval_dataset = Subset(eval_dataset, indices=list(range(QUERY_INDICES)))

task = ClassificationTask()
model = model.double()
@@ -99,7 +100,6 @@ def run_analysis() -> None:
gradient_covariance_dtype=torch.float64,
lambda_dtype=torch.float64,
lambda_iterative_aggregate=False,
lambda_max_examples=1_000
)
analyzer.fit_all_factors(
factors_name="single_gpu",
@@ -119,8 +119,6 @@ def run_analysis() -> None:
factors_name="single_gpu",
query_dataset=eval_dataset,
train_dataset=train_dataset,
train_indices=list(range(TRAIN_INDICES)),
query_indices=list(range(QUERY_INDICES)),
per_device_query_batch_size=12,
per_device_train_batch_size=512,
score_args=score_args,
@@ -130,7 +128,6 @@ def run_analysis() -> None:
scores_name="single_gpu",
factors_name="single_gpu",
train_dataset=train_dataset,
train_indices=list(range(TRAIN_INDICES)),
per_device_train_batch_size=512,
score_args=score_args,
overwrite_output_dir=True,
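The change above drops the `train_indices`/`query_indices` arguments from the score computations and instead wraps the datasets in `torch.utils.data.Subset` up front, with 5,003 train and 51 query examples so that neither count divides evenly by the batch sizes used in the tests (512 and 12). A minimal sketch of that pattern; `TensorDataset` here is only a stand-in for the MNIST datasets:

```python
# Sketch of the pattern adopted in prepare_tests.py: slice the datasets with
# torch.utils.data.Subset up front instead of passing train_indices /
# query_indices to the score computations. TensorDataset is a stand-in for
# the MNIST datasets used in the tests.
import torch
from torch.utils.data import Subset, TensorDataset

TRAIN_INDICES = 5_003
QUERY_INDICES = 51

full_train = TensorDataset(torch.randn(60_000, 4), torch.randn(60_000, 1))
full_eval = TensorDataset(torch.randn(10_000, 4), torch.randn(10_000, 1))

train_dataset = Subset(full_train, indices=list(range(TRAIN_INDICES)))
eval_dataset = Subset(full_eval, indices=list(range(QUERY_INDICES)))

# Neither subset divides evenly by the batch sizes used in the tests,
# which is exactly the "difficult case" the constants are meant to exercise.
print(len(train_dataset) % 512, len(eval_dataset) % 12)  # 395, 3
```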
