Skip to content

Commit

Permalink
bl opt default
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jun 14, 2024
1 parent ecb19fc commit 6dd1fcc
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def mark_branch_lengths_optimized(self, cycle):
)

def joint_train(
self, epochs=100, cycle_count=2, training_method="full", out_prefix=None
self, epochs=100, cycle_count=2, training_method="full", out_prefix=None, optimize_bl_first_cycle=True
):
"""
Do joint optimization of model and branch lengths.
Expand All @@ -708,6 +708,10 @@ def joint_train(
If training_method is "yun", then we use Yun's approximation to the branch lengths.
If training_method is "fixed", then we fix the branch lengths and only optimize the model.
We give an option to optimize the branch lengths in the first cycle (by
default we do). But, this can be useful to turn off e.g. if we've loaded
in some preoptimized branch lengths.
We reset the optimization after each cycle, and we use a learning rate
schedule that uses a weighted geometric mean of the current learning
rate and the initial learning rate that progressively moves towards
Expand All @@ -722,7 +726,8 @@ def joint_train(
else:
raise ValueError(f"Unknown training method {training_method}")
loss_history_l = []
optimize_branch_lengths()
if optimize_bl_first_cycle:
optimize_branch_lengths()
self.mark_branch_lengths_optimized(0)
for cycle in range(cycle_count):
print(
Expand Down

0 comments on commit 6dd1fcc

Please sign in to comment.