Commit: init commit

Shenggan committed Feb 28, 2022
1 parent 62ed1c4 commit d5f3875
Showing 32 changed files with 5,643 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -127,3 +127,10 @@ dmypy.json

# Pyre type checker
.pyre/

# vscode
.vscode/

# setup
dist/
build/
59 changes: 59 additions & 0 deletions README.md
@@ -1,2 +1,61 @@
![](/assets/fold.jpg)

# FastFold

![](https://img.shields.io/github/v/release/hpcaitech/FastFold)
[![GitHub license](https://img.shields.io/github/license/hpcaitech/FastFold.svg)](https://github.com/hpcaitech/FastFold/blob/master/LICENSE)
![](https://img.shields.io/badge/Made%20with-ColossalAI-blueviolet?style=flat)

Optimizing Protein Structure Prediction Model Training and Inference on GPU Clusters

FastFold provides a **high-performance implementation of Evoformer** with the following characteristics.

1. Excellent kernel performance on GPU platforms
2. Support for Dynamic Axial Parallelism (DAP)
    * Breaks the memory limit of a single GPU and reduces overall training time
    * Distributed inference significantly speeds up prediction and makes inference on extremely long sequences possible
3. Ease of use
    * Replace a few lines of code and you can use FastFold in your project (see the sketch below)
    * You do not need to worry about how the parallelism is implemented
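
For example, an Evoformer block can be constructed and called directly. This is a minimal sketch that mirrors the constructor and call signature used in `benchmark/perf.py`; the shapes are illustrative placeholders, not requirements:

```python
# Sketch only: mirrors the usage in benchmark/perf.py.
import torch
from fastfold.model import Evoformer

block = Evoformer(d_node=256, d_pair=128).cuda()

# Shapes follow perf.py: MSA (batch, n_seq, n_res, c_m), pair (batch, n_res, n_res, c_z).
node = torch.randn(1, 128, 256, 256, device='cuda')
pair = torch.randn(1, 256, 256, 128, device='cuda')
node_mask = torch.ones(1, 128, 256, device='cuda')
pair_mask = torch.ones(1, 256, 256, device='cuda')

node, pair = block(node, pair, node_mask, pair_mask)
```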

## Installation

You will need Python 3.8 or later and [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.1 or above when installing from source.

We highly recommend creating an Anaconda or Miniconda environment and installing PyTorch with conda:

```shell
conda create -n fastfold python=3.8
conda activate fastfold
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
```
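
Before building FastFold you can optionally verify that PyTorch sees your GPU (plain PyTorch calls, nothing FastFold-specific):

```python
import torch

print(torch.__version__, torch.version.cuda)  # expect a CUDA 11.x build from the install above
assert torch.cuda.is_available(), "FastFold requires an NVIDIA GPU"
```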

You can get the FastFold source and install it with setuptools:

```shell
git clone https://github.com/hpcaitech/FastFold
cd FastFold
python setup.py install --cuda_ext
```
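
As a quick smoke test after installation (a sketch; `fastfold.distributed` is one of the packages added in this commit):

```python
import fastfold
import fastfold.distributed

print("FastFold import OK")
```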

## Performance Benchmark

We have included a performance benchmark script in `./benchmark`. You can use it to benchmark Evoformer performance under different settings.

```shell
cd ./benchmark
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256
```

If you want to compare against [OpenFold](https://github.com/aqlaboratory/openfold), install OpenFold first and run the benchmark with the `--openfold` option:

```shell
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256 --openfold
```
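
To exercise Dynamic Axial Parallelism itself, increase the process count together with `--dap-size`, which `perf.py` accepts (a sketch; assumes at least two GPUs on the node):

```shell
torchrun --nproc_per_node=2 perf.py --msa-length 128 --res-length 256 --dap-size 2
```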

## Cite us

If you use FastFold in your research publication, please cite the following paper.

```
```
Binary file added assets/fold.jpg
191 changes: 191 additions & 0 deletions benchmark/perf.py
@@ -0,0 +1,191 @@
import argparse
import os

import torch
import torch.nn as nn

from fastfold.distributed import init_shadowcore
from fastfold.model import Evoformer


def main():

    parser = argparse.ArgumentParser(description='MSA Attention Standalone Perf Benchmark')
    parser.add_argument("--dap-size", default=1, type=int, help='Degree of Dynamic Axial Parallelism')
    parser.add_argument('--batch-size', default=1, type=int, help='batch size')
    parser.add_argument('--msa-length', default=132, type=int, help='Number of sequences in the MSA')
    parser.add_argument('--res-length',
                        default=256,
                        type=int,
                        help='Number of residues (sequence length)')
    parser.add_argument('--trials', default=50, type=int, help='Number of Trials to Execute')
    parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
    parser.add_argument('--layers',
                        default=12,
                        type=int,
                        help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
    parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
    parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
    parser.add_argument('--heads', default=8, type=int, help='Number of Multihead Attention heads')
    parser.add_argument('--openfold',
                        action='store_true',
                        help='Benchmark the OpenFold Evoformer implementation instead.')
    parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
    parser.add_argument('--prof', action='store_true', help='Profile with torch.profiler.')

    args = parser.parse_args()

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.local_rank = int(os.environ['LOCAL_RANK'])

    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()
    args.global_rank = torch.distributed.get_rank()
    print(
        'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
        % (args.global_rank, args.world_size))
    init_shadowcore(args.dap_size)

    precision = torch.bfloat16
    if args.dap_size > 1:
        # (PyTorch issue) Currently All2All communication does not support the Bfloat16 datatype in PyTorch
        precision = torch.float16

    if not torch.cuda.is_available():
        raise NotImplementedError('Running on CPU is not supported')

    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    if args.openfold:
        from openfold.model.evoformer import EvoformerBlock

        class OpenFoldEvoformer(nn.Module):

            def __init__(self, d_node, d_pair):
                super(OpenFoldEvoformer, self).__init__()
                self.d_node = d_node
                self.d_pair = d_pair

                self.c_hidden_msa_att = int(d_node / 8)
                self.c_hidden_pair_att = int(d_pair / 8)

                self.EvoformerBlock = EvoformerBlock(c_m=d_node,
                                                     c_z=d_pair,
                                                     c_hidden_msa_att=self.c_hidden_msa_att,
                                                     c_hidden_opm=self.c_hidden_msa_att,
                                                     c_hidden_mul=self.d_pair,
                                                     c_hidden_pair_att=self.c_hidden_pair_att,
                                                     no_heads_msa=8,
                                                     no_heads_pair=4,
                                                     transition_n=4,
                                                     msa_dropout=0.15,
                                                     pair_dropout=0.25,
                                                     inf=1e9,
                                                     eps=1e-10)

            def forward(self, node, pair, node_mask, pair_mask):
                node, pair = self.EvoformerBlock(node, pair, node_mask, pair_mask)
                return node, pair

    attn_layers = []
    for idx in range(0, args.layers):
        if args.openfold:
            attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
        else:
            attn_layers.append(Evoformer(d_node=args.cm, d_pair=args.cz))
        attn_layers[idx].cuda()
        attn_layers[idx].to(dtype=precision)

    start_evt_fwd = []
    start_evt_bwd = []
    stop_evt_bwd = []
    for recorded_trial in range(0, args.trials):
        start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
        start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
        stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))

    # Under DAP, the MSA (node) tensor is sharded along the sequence axis and
    # the pair tensor along one residue axis, so each rank holds 1/dap_size.
    inputs_node = torch.randn(args.batch_size,
                              args.msa_length // args.dap_size,
                              args.res_length,
                              args.cm,
                              dtype=precision,
                              device=torch.device("cuda")).requires_grad_(True)
    inputs_pair = torch.randn(args.batch_size,
                              args.res_length // args.dap_size,
                              args.res_length,
                              args.cz,
                              dtype=precision,
                              device=torch.device("cuda")).requires_grad_(True)
    node_mask = torch.ones((args.batch_size, args.msa_length, args.res_length),
                           dtype=precision,
                           device=torch.device("cuda")).requires_grad_(False)
    pair_mask = torch.ones((args.batch_size, args.res_length, args.res_length),
                           dtype=precision,
                           device=torch.device("cuda")).requires_grad_(False)
    grads_pair = torch.randn_like(inputs_pair)  # gradient seed for the pair output

    if args.prof:
        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1,
                                             warmup=args.warmup_trials,
                                             active=args.trials,
                                             repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/fastfold'),
            profile_memory=False,
            record_shapes=False,
            with_stack=False)
        prof.start()

    for trial in range(0, args.trials + args.warmup_trials):
        layer_inputs = inputs_node, inputs_pair
        evt_idx = trial - args.warmup_trials

        torch.distributed.barrier()
        torch.cuda.synchronize()

        if evt_idx >= 0:
            start_evt_fwd[evt_idx].record()

        for lyr_idx in range(0, args.layers):
            layer_inputs = attn_layers[lyr_idx].forward(*layer_inputs, node_mask, pair_mask)

        torch.cuda.synchronize()

        if evt_idx >= 0:
            start_evt_bwd[evt_idx].record()

        if not args.fwd:
            layer_inputs[1].backward(grads_pair)

        if evt_idx >= 0:
            stop_evt_bwd[evt_idx].record()

        if args.prof:
            prof.step()

    if args.prof:
        prof.stop()

    torch.distributed.barrier()
    torch.cuda.synchronize()
    elapsed_time_fwd = 0.0
    elapsed_time_bwd = 0.0
    for evt_idx in range(0, args.trials):
        elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
        elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])

    print("[ MSA Attn ] Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
        args.batch_size, args.msa_length, args.res_length,
        args.cm, args.cz,
        elapsed_time_fwd / (args.trials * args.layers),
        elapsed_time_bwd / (args.trials * args.layers)))


if __name__ == '__main__':
    main()
Empty file added fastfold/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions fastfold/distributed/__init__.py
@@ -0,0 +1,13 @@
from .core import (init_shadowcore, shadowcore_is_initialized, get_tensor_model_parallel_group,
                   get_data_parallel_group, get_tensor_model_parallel_world_size,
                   get_tensor_model_parallel_rank, get_data_parallel_world_size,
                   get_data_parallel_rank, get_tensor_model_parallel_src_rank)
from .comm import (_reduce, _split, _gather, copy, scatter, reduce, gather, col_to_row, row_to_col)

__all__ = [
    'init_shadowcore', 'shadowcore_is_initialized', 'get_tensor_model_parallel_group',
    'get_data_parallel_group', 'get_tensor_model_parallel_world_size',
    'get_tensor_model_parallel_rank', 'get_data_parallel_world_size', 'get_data_parallel_rank',
    'get_tensor_model_parallel_src_rank', '_reduce', '_split', '_gather', 'copy', 'scatter',
    'reduce', 'gather', 'col_to_row', 'row_to_col'
]
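
For reference, a minimal initialization sketch assembled from the calls visible in `benchmark/perf.py` (assumptions: launched via `torchrun`, NCCL backend, a DAP size of 2):

```python
# Sketch under assumptions: two processes launched with torchrun, one GPU each.
import os

import torch
import torch.distributed as dist

from fastfold.distributed import init_shadowcore, get_tensor_model_parallel_rank

torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
dist.init_process_group(backend='nccl', init_method='env://')
init_shadowcore(2)  # DAP size, mirroring init_shadowcore(args.dap_size) in perf.py
print('DAP rank:', get_tensor_model_parallel_rank())
```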
