Skip to content

Commit

Permalink
constant & linear interp. debugged for #Nodes=1
Browse files Browse the repository at this point in the history
  • Loading branch information
Salva4 committed May 29, 2024
1 parent 587c979 commit 5cf97cc
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 48 deletions.
101 changes: 56 additions & 45 deletions examples/morphological_classification/src/main_noDP_NI_callibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,18 @@ def interpolate_weights(
if num_procs == 1:
root_print(rank, 'Code shortcut: #Nodes=1')

def interpolate_weights_from_layers(
coarse_layers, interpolation_coefficient, mode,
):
def interpolate_weights_from_layers(coarse_model_child, k, j, cf, mode):
if mode == 'constant':
coarse_layer, = coarse_layers
return coarse_layer.parameters()
coarse_layer = coarse_model_child.layer_models[k]
# return coarse_layer.parameters()
for coarse_layer_parameter in coarse_layer.parameters():
yield coarse_layer_parameter

elif mode == 'linear':
coarse_layer_on_left, coarse_layer_on_right = coarse_layers

if coarse_layer_on_right is None:
coarse_layer_on_right = coarse_layer_on_left
interpolation_coefficient = j/cf
coarse_layer_on_left = coarse_model_child.layer_models[k]
coarse_layer_on_right = \
[*coarse_model_child.layer_models, coarse_layer_on_left][k+1]

assert len(list(coarse_layer_on_left .parameters())) == \
len(list(coarse_layer_on_right.parameters())) # delete for memory efficiency?
Expand All @@ -295,22 +295,21 @@ def interpolate_weights_from_layers(
+ interpolation_coefficient * \
coarse_layer_on_right_parameter

def interpolate_momentum_from_layers(
coarse_layers, interpolation_coefficient, mode,
):
def interpolate_momentum_from_layers(coarse_model_child, k, j, cf, mode):
if mode == 'constant':
coarse_layer, = coarse_layers
coarse_layer = coarse_model_child.layer_models[k]

for coarse_layer_parameter in coarse_layer.parameters():
coarse_layer_parameter_momentum = \
coarse_optimizer.state[coarse_layer_parameter] \
.get('momentum_buffer')
yield coarse_layer_parameter_momentum

elif mode == 'linear':
coarse_layer_on_left, coarse_layer_on_right = coarse_layers

if coarse_layer_on_right is None:
coarse_layer_on_right = coarse_layer_on_left
interpolation_coefficient = j/cf
coarse_layer_on_left = coarse_model_child.layer_models[k]
coarse_layer_on_right = \
[*coarse_model_child.layer_models, coarse_layer_on_left][k+1]

assert len(list(coarse_layer_on_left .parameters())) == \
len(list(coarse_layer_on_right.parameters())) # delete for memory efficiency?
Expand All @@ -329,14 +328,13 @@ def interpolate_momentum_from_layers(
+ interpolation_coefficient * \
coarse_layer_on_right_parameter_momentum

# def clone(x): return x.clone() if x is not None else None

def replace_layer_weights(
old_parameters, new_parameters, new_momentums=None,
):
for (old_parameter, new_parameter) in zip(
old_parameters, new_parameters,
): old_parameter[:] = new_parameter[:]
):
old_parameter[:] = new_parameter[:]

if new_momentums is not None:
fine_optimizer.state[old_parameter]['momentum_buffer'] = \
Expand All @@ -354,23 +352,17 @@ def replace_layer_weights(
N_coarse = len(coarse_model_child.layer_models)

for k in range(N_coarse):
coarse_layer_on_left = coarse_model_child.layer_models[k]
coarse_layer_on_right = [*coarse_model_child.layer_models, None] \
[k+1]
for j in range(cf):
fine_layer_idx = cf*k + j
if fine_layer_idx >= len(fine_model_child.layer_models): return
fine_layer = fine_model_child.layer_models[fine_layer_idx]
interpolation_coefficient = j/cf

lp_new_parameters = interpolate_weights_from_layers(
layers=(coarse_layer_on_left, coarse_layer_on_right),
interpolation_coefficient=interpolation_coefficient,
coarse_model_child=coarse_model_child, k=k, j=j, cf=cf,
mode=interpolation_mode,
)
lp_new_momentums = interpolate_momentum_from_layers(
layers=(coarse_layer_on_left, coarse_layer_on_right),
interpolation_coefficient=interpolation_coefficient,
coarse_model_child=coarse_model_child, k=k, j=j, cf=cf,
mode=interpolation_mode,
) if interpolate_momentum else None

Expand All @@ -391,8 +383,6 @@ def replace_layer_weights(
new_momentums=open_close_new_momentums,
)

# else: raise Exception('Unknown interpolation-mode')

else: raise Exception('Remaining to implement: linear interp. w/ momentum w/ #Nodes > 1')

# rank_coarse_model_lp_parameters = []
Expand Down Expand Up @@ -510,6 +500,28 @@ def replace_layer_weights(

root_print(rank, '-> Done.')

def print_weights(rank, model):
for child in model.children():
if isinstance(child, torchbraid.LayerParallel):
for k, layer in enumerate(child.layer_models):
print(f'lp-layer={k}')
for p in layer.parameters():
root_print(rank, p.ravel()[:3].tolist() + p.ravel()[-3:].tolist())
else:
print(child)
for p in child.parameters():
root_print(rank, p.ravel()[:3].tolist() + p.ravel()[-3:].tolist())

def print_momentum(rank, model, optimizer):
for child in model.children():
if isinstance(child, torchbraid.LayerParallel):
for k, layer in enumerate(child.layer_models):
print(f'lp-layer={k}')
for p in layer.parameters(): root_print(rank, optimizer.state[p])
else:
print(child)
for p in child.parameters(): root_print(rank, optimizer.state[p])

def main():
## MPI information
comm = MPI.COMM_WORLD
Expand Down Expand Up @@ -622,6 +634,9 @@ def main():

root_print(rank, 'Starting training...')

assert args.ni_interpolate_momentum in ['True', 'False']
interpolate_momentum = args.ni_interpolate_momentum == 'True'

previous_model = None
patience = 10
previous_model_best_accuracy = -1.
Expand All @@ -633,28 +648,24 @@ def main():
)
optimizer = get_optimizer(model, args)

# root_print(rank, 'before')
# for p in model.parameters():
# root_print(rank, p.ravel()[:3].tolist() + p.ravel()[-3:].tolist())
# if level == 0:
# for p in previous_model.parameters():
# root_print(rank, previous_optimizer.state[p])
# root_print(rank, 'f')
# for p in model.parameters():
# root_print(rank, optimizer.state[p])
# root_print(rank, 'before')
# print_weights(rank, previous_model)
# print_momentum(rank, previous_model, previous_optimizer)

if previous_model is not None:
interpolate_weights(
previous_model, model, args.ni_cfactor, rank, num_procs, comm,
previous_optimizer, optimizer, True, args.ni_interpolation,
coarse_model=previous_model, fine_model=model, cf=args.ni_cfactor,
rank=rank, num_procs=num_procs, comm=comm,
interpolation_mode=args.ni_interpolation,
interpolate_momentum=interpolate_momentum,
coarse_optimizer=previous_optimizer, fine_optimizer=optimizer,
)

# root_print(rank, 'after')
# # for p in model.parameters():
# # root_print(rank, p.ravel()[:3].tolist() + p.ravel()[-3:].tolist())
# if level == 0:
# for p in model.parameters():
# root_print(rank, optimizer.state[p])
# root_print(rank, 'after')
# print_weights(rank, model)
# print_momentum(rank, model, optimizer)

root_print(
rank, f'Level={level}, ' \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,10 @@ def parse_args():
parser.add_argument('--debug', action='store_true')

# parser.add_argument('--ni_starting_level', type=int, default=0)
parser.add_argument('--ni_cfactor' , type=int, default=2)
parser.add_argument('--ni_num_levels' , type=int, default=2)
parser.add_argument('--ni_interpolation', type=int, default='constant')
parser.add_argument('--ni_cfactor' , type=int, default=2)
parser.add_argument('--ni_num_levels' , type=int, default=2)
parser.add_argument('--ni_interpolation' , type=str, default='constant')
parser.add_argument('--ni_interpolate_momentum', type=str, default='True')

##
# Do some parameter checking
Expand Down

0 comments on commit 5cf97cc

Please sign in to comment.