Skip to content

Commit

Permalink
monitoring all steps
Browse files Browse the repository at this point in the history
  • Loading branch information
Salva4 committed Jul 3, 2024
1 parent 1297a4b commit a6d44a4
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 6 deletions.
1 change: 1 addition & 0 deletions examples/linear_algebra/src/main_noDP.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def train_epoch(
print(f'rank={rank}, Batch fwd pass time: {batch_fwd_pass_end - batch_fwd_pass_start}')
print(f'rank={rank}, Batch bwd pass time: {batch_bwd_pass_end - batch_bwd_pass_start}')
if batch_idx == 11: sys.exit()
if batch_idx == 2: sys.exit()

predictions = output.argmax(dim=-1)
correct = (
Expand Down
2 changes: 1 addition & 1 deletion examples/linear_algebra/src/model_utils/F_dec_MHA_FF.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def forward(
t1 = time.time()
FF_x = self.ff_block(x + MHA_x)
t2 = time.time()
if 1: print(f'MHA-time={t1-t0:.4f}, FF-time={t2-t1:.4f}')
if 0: print(f'MHA-time={t1-t0:.4f}, FF-time={t2-t1:.4f}')

return MHA_x + FF_x

Expand Down
2 changes: 1 addition & 1 deletion examples/linear_algebra/src/model_utils/F_dec_SA.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def forward(
x, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask,
)
t1 = time.time()
if 1: print(f'DEC: SA-time={t1-t0:.4f}')
if 0: print(f'DEC: SA-time={t1-t0:.4f}')

return SA_x

Expand Down
2 changes: 1 addition & 1 deletion examples/linear_algebra/src/model_utils/F_enc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def forward(self, x, src_mask, src_key_padding_mask):
FF_x = self.ff_block(x + SA_x)
t2 = time.time()

if 1: print(f'ENC: SA_time={t1-t0:.4f}, FF_time={t2-t1:.4f}')
if 0: print(f'ENC: SA_time={t1-t0:.4f}, FF_time={t2-t1:.4f}')

return SA_x + FF_x

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,12 @@ def train_epoch(
optimizer.step()
if batch_scheduler is not None: batch_scheduler.step()

if 0:
if 1:
print(f'rank={rank}, Batch idx: {batch_idx}')
print(f'rank={rank}, Batch fwd pass time: {batch_fwd_pass_end - batch_fwd_pass_start}')
print(f'rank={rank}, Batch bwd pass time: {batch_bwd_pass_end - batch_bwd_pass_start}')
if batch_idx == 11: import sys; sys.exit()
if batch_idx == 2: import sys; sys.exit()

predictions = output.argmax(dim=-1)
correct = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def forward(self, x):
x = self.parallel_nn(x)
t1 = time.time()
x = self.compose(self.close_nn, x)
if 0: print(f'rank={self.rank}, CB-time: {t1 - t0} seconds')
if 1: print(f'lp_rank={self.rank}, CB-time: {t1 - t0} seconds')

return x

Expand Down Expand Up @@ -328,7 +328,7 @@ def parse_args():

# parser.add_argument('--ni_starting_level', type=int, default=0)
parser.add_argument('--ni_cfactor' , type=int, default=2)
parser.add_argument('--ni_num_levels' , type=int, default=2)
parser.add_argument('--ni_num_levels' , type=int, default=1)
parser.add_argument('--ni_interpolation' , type=str, default='constant')
parser.add_argument('--ni_interpolate_momentum', type=str, default='True')

Expand Down
32 changes: 32 additions & 0 deletions installation_torchbraid_parallel_env3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
cd
conda deactivate
rm -rf .cache/
rm -rf .local/lib/python3.9/site-packages/
module load daint-gpu cray-python
cd $SCRATCH
rm -rf envs/tb3_18_02_2024
rm -rf torchbraid3_18_02_2024
mkdir torchbraid3_18_02_2024
cd torchbraid3_18_02_2024
git clone https://github.com/Multilevel-NN/torchbraid.git -b transformer2
export MPICC=cc
conda create --prefix $SCRATCH/envs/tb3_18_02_2024 cudatoolkit-dev=11.7 pytorch-cuda=11.7 pytorch=1.13 -c conda-forge -c nvidia -c pytorch -y
conda activate $SCRATCH/envs/tb3_18_02_2024
LDFLAGS="-L$CONDA_PREFIX/lib/ -lcudart -lcuda" MPICC=cc pip install mpi4py --no-cache --force-reinstall
/opt/python/3.9.4.1/bin/python3.9 -m pip install --upgrade pip
export MPICC=cc
cd torchbraid
vim setup.py # --> 58:'numpy==1.22', 59:'torch', 60:'torchvision', 28: [...,'MPICC=cc'] # 1.16.5 !!
# install_requires = [
#     'setuptools',
#     'mpi4py',
#     'cython==0.29.32',
#     'numpy',
#     'torch==2.0.1',
#     'torchvision==0.15.2',
#     'matplotlib'
# ]
cd ..
MPICC=cc CC=cc LDSHARED="cc -pthread -shared -Wl,-rpath,/opt/python/3.9.4.1/lib,-build-id -fPIC" pip install torchbraid/
export MPICH_RDMA_ENABLED_CUDA=1
srun -lu -N 4 -A c24 -C gpu python -u torchbraid/examples/mnist/mnist_script.py

0 comments on commit a6d44a4

Please sign in to comment.