Commit 1297a4b
LA: scaling w/o masks (similar result) + monitoring all steps
Salva4 committed Jun 26, 2024
1 parent 5314f11 commit 1297a4b
Showing 9 changed files with 605 additions and 7 deletions.
16 changes: 12 additions & 4 deletions examples/linear_algebra/src/main_noDP.py
@@ -65,9 +65,10 @@
from torchvision import datasets, transforms
import sys

# from network_architecture import parse_args, ParallelNet, SerialNet; print('split')
from network_architecture_joint import parse_args, ParallelNet, SerialNet; print('joint')
from network_architecture import parse_args, ParallelNet, SerialNet; print('split')
# from network_architecture_joint import parse_args, ParallelNet, SerialNet; print('joint')
# from network_architecture_semijoint import parse_args, ParallelNet, SerialNet; print('semijoint')
# from network_architecture_womasks import parse_args, ParallelNet, SerialNet; print('split-womasks')
from mpi4py import MPI

from cosine_warmup_scheduler import CosineWarmupScheduler
@@ -293,10 +294,17 @@ def main():
).to(device)

# Detailed XBraid timings are output to these files for the forward and backward phases
model.parallel_nn.fwd_app.setBraidTimers(flag=1)
model.parallel_nn.fwd_app.setTimerFile(
f'b_fwd_s_{args.steps}_bs_{args.batch_size}_p_{num_procs}')
#f'b_fwd_s_{args.steps}_bs_{args.batch_size}_p_{num_procs}'
'/users/msalvado/fwd'
)
model.parallel_nn.bwd_app.setBraidTimers(flag=1)
model.parallel_nn.bwd_app.setTimerFile(
f'b_bwd_s_{args.steps}_bs_{args.batch_size}_p_{num_procs}')
#f'b_bwd_s_{args.steps}_bs_{args.batch_size}_p_{num_procs}'
'/users/msalvado/bwd'
)
print('model.parallel_nn.bwd_app braid and timers initialized')

else:
assert num_procs == 1, 'If enforce_serial, num_procs must be 1'
5 changes: 5 additions & 0 deletions examples/linear_algebra/src/model/model_utils/F_dec_MHA_FF.py
@@ -1,3 +1,4 @@
import time
import torch.nn as nn

class F_dec_MHA_FF(nn.TransformerDecoderLayer):
@@ -10,11 +11,15 @@ def __init__(self, d_model, nhead, dim_feedforward, dropout, batch_first):
def forward(
self, x, memory, mem_key_padding_mask,
):
t0 = time.time()
MHA_x = self.mha_block(
x, mem=memory, attn_mask=None,
key_padding_mask=mem_key_padding_mask,
)
t1 = time.time()
FF_x = self.ff_block(x + MHA_x)
t2 = time.time()
if 1: print(f'MHA-time={t1-t0:.4f}, FF-time={t2-t1:.4f}')

return MHA_x + FF_x

4 changes: 4 additions & 0 deletions examples/linear_algebra/src/model/model_utils/F_dec_SA.py
@@ -1,3 +1,4 @@
import time
import torch.nn as nn

class F_dec_SA(nn.TransformerDecoderLayer):
@@ -9,9 +10,12 @@ def __init__(self, d_model, nhead, dropout, batch_first):
def forward(
self, x, tgt_mask, tgt_key_padding_mask,
):
t0 = time.time()
SA_x = self.sa_block(
x, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask,
)
t1 = time.time()
if 1: print(f'DEC: SA-time={t1-t0:.4f}')

return SA_x

6 changes: 6 additions & 0 deletions examples/linear_algebra/src/model/model_utils/F_enc.py
@@ -1,3 +1,4 @@
import time
import torch.nn as nn

class F_enc(nn.TransformerEncoderLayer):
@@ -8,10 +9,15 @@ def __init__(self, d_model, nhead, dim_feedforward, dropout, batch_first):
)

def forward(self, x, src_mask, src_key_padding_mask):
t0 = time.time()
SA_x = self.sa_block(
x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask,
)
t1 = time.time()
FF_x = self.ff_block(x + SA_x)
t2 = time.time()

if 1: print(f'ENC: SA_time={t1-t0:.4f}, FF_time={t2-t1:.4f}')

return SA_x + FF_x

5 changes: 5 additions & 0 deletions examples/linear_algebra/src/model_utils/F_dec_MHA_FF.py
@@ -1,3 +1,4 @@
import time
import torch.nn as nn

class F_dec_MHA_FF(nn.TransformerDecoderLayer):
@@ -10,11 +11,15 @@ def __init__(self, d_model, nhead, dim_feedforward, dropout, batch_first):
def forward(
self, x, memory, mem_key_padding_mask,
):
t0 = time.time()
MHA_x = self.mha_block(
x, mem=memory, attn_mask=None,
key_padding_mask=mem_key_padding_mask,
)
t1 = time.time()
FF_x = self.ff_block(x + MHA_x)
t2 = time.time()
if 1: print(f'MHA-time={t1-t0:.4f}, FF-time={t2-t1:.4f}')

return MHA_x + FF_x

4 changes: 4 additions & 0 deletions examples/linear_algebra/src/model_utils/F_dec_SA.py
@@ -1,3 +1,4 @@
import time
import torch.nn as nn

class F_dec_SA(nn.TransformerDecoderLayer):
@@ -9,9 +10,12 @@ def __init__(self, d_model, nhead, dropout, batch_first):
def forward(
self, x, tgt_mask, tgt_key_padding_mask,
):
t0 = time.time()
SA_x = self.sa_block(
x, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask,
)
t1 = time.time()
if 1: print(f'DEC: SA-time={t1-t0:.4f}')

return SA_x

6 changes: 6 additions & 0 deletions examples/linear_algebra/src/model_utils/F_enc.py
@@ -1,3 +1,4 @@
import time
import torch.nn as nn

class F_enc(nn.TransformerEncoderLayer):
@@ -8,10 +9,15 @@ def __init__(self, d_model, nhead, dim_feedforward, dropout, batch_first):
)

def forward(self, x, src_mask, src_key_padding_mask):
t0 = time.time()
SA_x = self.sa_block(
x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask,
)
t1 = time.time()
FF_x = self.ff_block(x + SA_x)
t2 = time.time()

if 1: print(f'ENC: SA_time={t1-t0:.4f}, FF_time={t2-t1:.4f}')

return SA_x + FF_x

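Note: the instrumentation added in the six layer files above repeats the same time.time() bracketing around each sub-block. A minimal sketch of an equivalent reusable helper, written as a context manager (hypothetical, not part of this commit), could collapse that pattern into a single with statement:

import time
from contextlib import contextmanager

@contextmanager
def block_timer(label, enabled=True):
    # Prints the wall-clock time spent inside the with-block.
    t0 = time.time()
    try:
        yield
    finally:
        if enabled:
            print(f'{label}-time={time.time() - t0:.4f}')

# Usage inside a forward() such as F_enc.forward:
# with block_timer('ENC: SA'):
#     SA_x = self.sa_block(x, attn_mask=src_mask,
#                          key_padding_mask=src_key_padding_mask)
# with block_timer('ENC: FF'):
#     FF_x = self.ff_block(x + SA_x)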
13 changes: 10 additions & 3 deletions examples/linear_algebra/src/network_architecture.py
@@ -331,23 +331,30 @@ def forward(self, src, tgt):
# (x, y, tgt_attention_mask, src_padding_mask, tgt_padding_mask,
# mem_padding_mask,) = self.compose(self.open_nn, src, tgt)
# x = torch.stack((x, y))
t0_open_layer_time = time.time()
x = self.compose(self.open_nn, src, tgt)
t1_open_layer_time = time.time()

t0_masks_comm_time = time.time()
tgt_attention_mask, src_padding_mask, tgt_padding_mask, mem_padding_mask = \
self.comm_lp.bcast([tgt_attention_mask, src_padding_mask,
tgt_padding_mask , mem_padding_mask,], root=0)
t1_masks_comm_time = time.time()

t0_continuous_block_time = time.time()
x = self.parallel_nn(x)
t1_continuous_block_time = time.time()

t0_close_layer_time = time.time()
mem, y = x
y = self.compose(self.close_nn, y)
t1_close_layer_time = time.time()

lp_rank = self.comm_lp.Get_rank()
dp_rank = self.comm_dp.Get_rank() if self.comm_dp is not None else None
if 0: print(f'''lp_rank={lp_rank}, dp_rank={dp_rank}: {
t1_continuous_block_time - t0_continuous_block_time :.4f
}''')
if 1:
# print(f'''lp_rank={lp_rank}, dp_rank={dp_rank}: {t1_continuous_block_time - t0_continuous_block_time :.4f}''')
print(f'''lp_rank={lp_rank}, dp_rank={dp_rank}, open={t1_open_layer_time - t0_open_layer_time:.4f}, masks-comm={t1_masks_comm_time - t0_masks_comm_time}, CB={t1_continuous_block_time - t0_continuous_block_time :.4f}, close={t1_close_layer_time - t0_close_layer_time}''')

return y

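Note: the new forward() above prints open-layer, mask-broadcast, continuous-block, and close-layer timings from every rank. If a single aggregated report is preferred over per-rank prints, a gather-based sketch along these lines could be used (assuming an mpi4py communicator such as self.comm_lp; the helper name report_timings is hypothetical):

def report_timings(comm, timings):
    # Gather a dict of per-rank timings on rank 0 and print min/max across ranks.
    all_timings = comm.gather(timings, root=0)  # list of dicts on rank 0, None elsewhere
    if comm.Get_rank() == 0:
        for key in timings:
            values = [t[key] for t in all_timings]
            print(f'{key}: min={min(values):.4f}, max={max(values):.4f}')

# Example call at the end of forward():
# report_timings(self.comm_lp, {
#     'open': t1_open_layer_time - t0_open_layer_time,
#     'masks-comm': t1_masks_comm_time - t0_masks_comm_time,
#     'CB': t1_continuous_block_time - t0_continuous_block_time,
#     'close': t1_close_layer_time - t0_close_layer_time,
# })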