Commit: Add linting for all files

pomonam committed Mar 12, 2024
1 parent 99caf70 commit fccb8b0
Showing 56 changed files with 504 additions and 1,238 deletions.
Binary file removed .assets/kronfluence.png
4 changes: 4 additions & 0 deletions .assets/kronfluence.svg
(New SVG logo; file contents not rendered.)
26 changes: 6 additions & 20 deletions .github/workflows/linting.yml
```diff
@@ -46,23 +46,9 @@ jobs:
       run: |
         isort --profile black kronfluence
-  black:
-    runs-on: ubuntu-latest
-
-    steps:
-      - name: Checkout Repository
-        uses: actions/checkout@v2
-
-      - name: Set up Python 3.9
-        uses: actions/setup-python@v2
-        with:
-          python-version: 3.9
-
-      - name: Install black
-        run: |
-          pip install --upgrade pip
-          pip install black==24.1.1
-      - name: Run black
-        run: |
-          black --check kronfluence
+jobs:
+  actionlint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: reviewdog/action-actionlint@v1
```
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
````diff
@@ -4,7 +4,8 @@ We welcome contributions to the `kronfluence` project. Whether it's bug fixes, f
 
 ## Setting Up Development Environment
 
-To contribute to `kronfluence`, you will need to set up a development environment on your machine. This setup includes all the dependencies required for linting, testing, and documentation.
+To contribute to `kronfluence`, you will need to set up a development environment on your machine.
+This setup includes all the dependencies required for linting and testing.
 
 ```bash
 git clone https://github.com/pomonam/kronfluence.git
````
5 changes: 5 additions & 0 deletions DOCUMENTATION.md
```diff
@@ -1,7 +1,12 @@
 # Kronfluence: Technical Documentation & FAQs
 
 (To be added.)
 
 ## Supported Modules
 
+Kronfluence only supports influence computation on supported `nn.Module`. The following modules are supported:
+1. `nn.Linear` and `nn.Conv2d`
+
 ## Supported Strategies
+
+- Identity, diagonal, KFAC, EKFAC
```
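The new documentation section is terse; concretely, it means influence is tracked only through `nn.Linear` and `nn.Conv2d` submodules. Below is a minimal sketch of a helper for listing which submodules of a model fall under that list; the helper is hypothetical and not part of the library:

```python
import torch.nn as nn

# Module types named in the new DOCUMENTATION.md section above.
SUPPORTED = (nn.Linear, nn.Conv2d)

def list_supported_modules(model: nn.Module) -> list:
    """Return names of submodules Kronfluence could track, per the list above."""
    return [name for name, module in model.named_modules() if isinstance(module, SUPPORTED)]

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
print(list_supported_modules(model))  # ['0', '2'] -- the Conv2d and the Linear
```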
2 changes: 1 addition & 1 deletion README.md
```diff
@@ -1,5 +1,5 @@
 <p align="center">
-    <a href="#"><img width="380" img src=".assets/kronfluence.png" alt="Kronfluence Logo"/></a>
+    <a href="#"><img width="380" img src=".assets/kronfluence.svg" alt="Kronfluence Logo"/></a>
 </p>
 
 <p align="center">
```
2 changes: 1 addition & 1 deletion dev_requirements.txt
```diff
@@ -1,6 +1,6 @@
 isort==5.13.2
 pylint==3.0.3
 pytest==8.0.0
-black==24.1.1
+ruff==0.3.0
 datasets>=2.17.0
 transformers>=4.37.2
```
16 changes: 4 additions & 12 deletions examples/cifar/pipeline.py
```diff
@@ -91,9 +91,7 @@ def get_cifar10_dataset(
 ):
     assert split in ["train", "eval_train", "valid"]
 
-    normalize = torchvision.transforms.Normalize(
-        mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)
-    )
+    normalize = torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
 
     if split in ["train", "eval_train"]:
         transforms = torchvision.transforms.Compose(
@@ -114,9 +112,7 @@
 
     if split == "train":
         transform_config = [
-            torchvision.transforms.RandomResizedCrop(
-                size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)
-            ),
+            torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)),
             torchvision.transforms.RandomHorizontalFlip(),
         ]
         transform_config.extend([torchvision.transforms.ToTensor(), normalize])
@@ -180,9 +176,7 @@ def get_cifar10_dataloader(
 
     if do_corrupt:
         if split == "valid":
-            raise NotImplementedError(
-                "Performing corruption on the validation dataset is not supported."
-            )
+            raise NotImplementedError("Performing corruption on the validation dataset is not supported.")
         num_corrupt = math.ceil(len(dataset) * 0.1)
         original_targets = np.array(copy.deepcopy(dataset.targets[:num_corrupt]))
         new_targets = torch.randint(
@@ -197,9 +191,7 @@
             size=new_targets[new_targets == original_targets].shape,
             generator=torch.Generator().manual_seed(0),
         ).numpy()
-        new_targets[new_targets == original_targets] = (
-            new_targets[new_targets == original_targets] + offsets
-        ) % 10
+        new_targets[new_targets == original_targets] = (new_targets[new_targets == original_targets] + offsets) % 10
         assert (new_targets == original_targets).sum() == 0
         dataset.targets[:num_corrupt] = list(new_targets)
```
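One detail in the relabeling hunk above deserves a note: after drawing random replacement labels, any label that collides with its original value is shifted by a random offset in [1, 9] modulo 10, which can never map a label back onto itself, so the final assert holds. A standalone sketch of that invariant, using toy tensors rather than the example's actual data:

```python
import torch

original = torch.tensor([3, 7, 1, 7])
new = torch.tensor([3, 2, 1, 9])  # random redraw; positions 0 and 2 collide
collide = new == original

# Adding an offset in [1, 9] modulo 10 never maps a label onto itself.
offsets = torch.randint(1, 10, size=(int(collide.sum()),), generator=torch.Generator().manual_seed(0))
new[collide] = (new[collide] + offsets) % 10
assert (new == original).sum() == 0
```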
12 changes: 3 additions & 9 deletions examples/glue/pipeline.py
```diff
@@ -55,23 +55,17 @@ def get_glue_dataset(
     num_labels = len(label_list)
     assert num_labels == 2
 
-    tokenizer = AutoTokenizer.from_pretrained(
-        "bert-base-cased", use_fast=True, trust_remote_code=True
-    )
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True, trust_remote_code=True)
 
     sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS[data_name]
     padding = "max_length"
     max_seq_length = 128
 
     def preprocess_function(examples):
         texts = (
-            (examples[sentence1_key],)
-            if sentence2_key is None
-            else (examples[sentence1_key], examples[sentence2_key])
+            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
         )
-        result = tokenizer(
-            *texts, padding=padding, max_length=max_seq_length, truncation=True
-        )
+        result = tokenizer(*texts, padding=padding, max_length=max_seq_length, truncation=True)
         if "label" in examples:
             result["labels"] = examples["label"]
         return result
```
8 changes: 2 additions & 6 deletions examples/glue/train.py
```diff
@@ -14,9 +14,7 @@
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Train classification models on MNIST datasets."
-    )
+    parser = argparse.ArgumentParser(description="Train classification models on MNIST datasets.")
 
     parser.add_argument(
         "--dataset_name",
@@ -95,9 +93,7 @@ def main():
     set_seed(args.seed)
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    train_dataset = get_glue_dataset(
-        data_name=args.dataset_name, split="train", data_path=args.dataset_dir
-    )
+    train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
     train_dataloader = DataLoader(
         dataset=train_dataset,
         batch_size=args.train_batch_size,
```
16 changes: 4 additions & 12 deletions examples/imagenet/analyze.py
```diff
@@ -15,9 +15,7 @@
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Influence analysis on ImageNet datasets."
-    )
+    parser = argparse.ArgumentParser(description="Influence analysis on ImageNet datasets.")
 
     parser.add_argument(
         "--dataset_dir",
@@ -57,9 +55,7 @@ def parse_args():
 
 
 class ClassificationTask(Task):
-    def compute_model_output(
-        self, batch: BATCH_DTYPE, model: nn.Module
-    ) -> torch.Tensor:
+    def compute_model_output(self, batch: BATCH_DTYPE, model: nn.Module) -> torch.Tensor:
         inputs, _ = batch
         return model(inputs)
 
@@ -88,15 +84,11 @@ def compute_measurement(
     ) -> torch.Tensor:
         _, labels = batch
 
-        bindex = torch.arange(outputs.shape[0]).to(
-            device=outputs.device, non_blocking=False
-        )
+        bindex = torch.arange(outputs.shape[0]).to(device=outputs.device, non_blocking=False)
         logits_correct = outputs[bindex, labels]
 
         cloned_logits = outputs.clone()
-        cloned_logits[bindex, labels] = torch.tensor(
-            -torch.inf, device=outputs.device, dtype=outputs.dtype
-        )
+        cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=outputs.device, dtype=outputs.dtype)
 
         margins = logits_correct - cloned_logits.logsumexp(dim=-1)
         return -margins.sum()
```
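For readers new to this measurement: `compute_measurement` returns the negated, summed classification margin, that is, the true-class logit minus the log-sum-exp of the remaining logits, with the true class masked to `-inf` before the `logsumexp`. A self-contained sketch with toy logits:

```python
import torch

outputs = torch.randn(4, 10)         # toy batch of logits
labels = torch.tensor([2, 0, 7, 7])

bindex = torch.arange(outputs.shape[0])
logits_correct = outputs[bindex, labels]

masked = outputs.clone()
masked[bindex, labels] = -torch.inf  # exclude the true class from the logsumexp
margins = logits_correct - masked.logsumexp(dim=-1)
measurement = -margins.sum()         # what compute_measurement returns
```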
20 changes: 5 additions & 15 deletions examples/imagenet/ddp_analyze.py
```diff
@@ -24,9 +24,7 @@
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Influence analysis on ImageNet datasets."
-    )
+    parser = argparse.ArgumentParser(description="Influence analysis on ImageNet datasets.")
 
     parser.add_argument(
         "--dataset_dir",
@@ -66,9 +64,7 @@ def parse_args():
 
 
 class ClassificationTask(Task):
-    def compute_model_output(
-        self, batch: BATCH_DTYPE, model: nn.Module
-    ) -> torch.Tensor:
+    def compute_model_output(self, batch: BATCH_DTYPE, model: nn.Module) -> torch.Tensor:
         inputs, _ = batch
         return model(inputs)
 
@@ -97,15 +93,11 @@ def compute_measurement(
     ) -> torch.Tensor:
         _, labels = batch
 
-        bindex = torch.arange(outputs.shape[0]).to(
-            device=outputs.device, non_blocking=False
-        )
+        bindex = torch.arange(outputs.shape[0]).to(device=outputs.device, non_blocking=False)
         logits_correct = outputs[bindex, labels]
 
         cloned_logits = outputs.clone()
-        cloned_logits[bindex, labels] = torch.tensor(
-            -torch.inf, device=outputs.device, dtype=outputs.dtype
-        )
+        cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=outputs.device, dtype=outputs.dtype)
 
         margins = logits_correct - cloned_logits.logsumexp(dim=-1)
         return -margins.sum()
@@ -132,9 +124,7 @@ def main():
     model = prepare_model(model, task)
 
     model = model.to(device=device)
-    model = DistributedDataParallel(
-        model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK
-    )
+    model = DistributedDataParallel(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
 
     analyzer = Analyzer(
         analysis_name=args.analysis_name,
```
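For context on the `LOCAL_RANK` wrapping in the last hunk, here is a minimal sketch of the surrounding DDP setup; it assumes a launcher such as `torchrun` sets `LOCAL_RANK`, and uses a toy model in place of the example's ResNet:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Run with e.g. `torchrun --nproc_per_node=2 ddp_demo.py`; torchrun sets LOCAL_RANK.
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend="nccl")
torch.cuda.set_device(LOCAL_RANK)

model = torch.nn.Linear(16, 4).to(device=torch.device(f"cuda:{LOCAL_RANK}"))
model = DistributedDataParallel(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
```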
12 changes: 3 additions & 9 deletions examples/imagenet/pipeline.py
```diff
@@ -8,9 +8,7 @@
 
 
 def construct_resnet50() -> nn.Module:
-    return torchvision.models.resnet50(
-        weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
-    )
+    return torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
 
 
 def get_imagenet_dataset(
@@ -20,15 +18,11 @@ def get_imagenet_dataset(
 ) -> Dataset:
     assert split in ["train", "eval_train", "valid"]
 
-    normalize = torchvision.transforms.Normalize(
-        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
-    )
+    normalize = torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
 
     if split == "train":
         transform_config = [
-            torchvision.transforms.RandomResizedCrop(
-                size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)
-            ),
+            torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)),
             torchvision.transforms.RandomHorizontalFlip(),
         ]
         transform_config.extend([torchvision.transforms.ToTensor(), normalize])
```
12 changes: 3 additions & 9 deletions examples/uci/analyze.py
```diff
@@ -100,12 +100,8 @@ def main():
     logging.basicConfig(level=logging.INFO)
     logger = logging.getLogger()
 
-    train_dataset = get_regression_dataset(
-        data_name=args.dataset_name, split="train", data_path=args.dataset_dir
-    )
-    eval_dataset = get_regression_dataset(
-        data_name=args.dataset_name, split="valid", data_path=args.dataset_dir
-    )
+    train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
+    eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir)
 
     model = construct_regression_mlp()
 
@@ -147,9 +143,7 @@ def main():
         overwrite_output_dir=True,
     )
 
-    with profile(
-        activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True
-    ) as prof:
+    with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
         with record_function("eigen"):
             analyzer.perform_eigendecomposition(
                 factors_name=args.factor_strategy,
```
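The profiling hunk above keeps the standard `torch.profiler` pattern intact. A self-contained sketch of the same pattern, with a stand-in eigendecomposition in place of `analyzer.perform_eigendecomposition`:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

x = torch.randn(512, 512)
with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    with record_function("eigen"):
        torch.linalg.eigh(x @ x.T)  # x @ x.T is symmetric, so eigh applies

# Summarize CPU time and memory per op, as one might after the eigendecomposition step.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```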
4 changes: 1 addition & 3 deletions examples/uci/pipeline.py
```diff
@@ -74,9 +74,7 @@ def get_regression_dataset(
             y_train_scaled.astype(np.float32),
         )
     else:
-        dataset = RegressionDataset(
-            x_val_scaled.astype(np.float32), y_val_scaled.astype(np.float32)
-        )
+        dataset = RegressionDataset(x_val_scaled.astype(np.float32), y_val_scaled.astype(np.float32))
 
     if indices is not None:
         dataset = torch.utils.data.Subset(dataset, indices)
```
16 changes: 4 additions & 12 deletions examples/uci/train.py
```diff
@@ -12,9 +12,7 @@
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Train regression models on UCI datasets."
-    )
+    parser = argparse.ArgumentParser(description="Train regression models on UCI datasets.")
 
     parser.add_argument(
         "--dataset_name",
@@ -93,19 +91,15 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)
 
-    train_dataset = get_regression_dataset(
-        data_name=args.dataset_name, split="train", data_path=args.dataset_dir
-    )
+    train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
     train_dataloader = DataLoader(
         dataset=train_dataset,
         batch_size=args.train_batch_size,
         shuffle=True,
         drop_last=True,
     )
     model = construct_regression_mlp()
-    optimizer = torch.optim.SGD(
-        model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
-    )
+    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
 
     logger.info("Start training the model.")
     model.train()
@@ -134,9 +128,7 @@ def main():
         shuffle=False,
         drop_last=False,
     )
-    eval_dataset = get_regression_dataset(
-        data_name=args.dataset_name, split="valid", data_path=args.dataset_dir
-    )
+    eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir)
    eval_dataloader = DataLoader(
         dataset=eval_dataset,
         batch_size=args.eval_batch_size,
```
8 changes: 2 additions & 6 deletions kronfluence/analyzer.py
```diff
@@ -28,9 +28,7 @@ def prepare_model(
     return model
 
 
-class Analyzer(
-    CovarianceComputer, EigenComputer, PairwiseScoreComputer, SelfScoreComputer
-):
+class Analyzer(CovarianceComputer, EigenComputer, PairwiseScoreComputer, SelfScoreComputer):
     """
     Handles the computation of all preconditioning factors (e.g., covariance and Lambda matrices for EKFAC)
     and influence scores for a given PyTorch model.
@@ -98,9 +96,7 @@ def _save_model(self) -> None:
             self.logger.info(f"Found existing saved model at {model_save_path}.")
             # Load the existing model's state_dict for comparison.
             loaded_state_dict = load_file(model_save_path)
-            if not verify_models_equivalence(
-                loaded_state_dict, extracted_model.state_dict()
-            ):
+            if not verify_models_equivalence(loaded_state_dict, extracted_model.state_dict()):
                 error_msg = (
                     "Detected a difference between the current model and the one saved at "
                     f"{model_save_path}. Consider using a different `analysis_name` to "
```
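Putting the pieces of this commit together, a hypothetical end-to-end sketch based only on calls visible in the diffs above (`prepare_model(model, task)` and `Analyzer(analysis_name=...)` from examples/imagenet/ddp_analyze.py); the `Task` subclass and any further `Analyzer` keywords are elided, since their full signatures are not shown here:

```python
import torch.nn as nn

from kronfluence.analyzer import Analyzer, prepare_model

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
task = ...  # fill in: a Task subclass such as the ClassificationTask defined above

model = prepare_model(model, task)  # wraps the supported nn.Linear modules
analyzer = Analyzer(
    analysis_name="demo",  # keyword as used in ddp_analyze.py
    model=model,
    task=task,
)
```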
(The remaining changed files in this commit are not shown.)