diff --git a/netam/framework.py b/netam/framework.py
index 20a487a6..30c021bf 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -699,7 +699,7 @@ def mark_branch_lengths_optimized(self, cycle):
         )
 
     def joint_train(
-        self, epochs=100, cycle_count=2, training_method="full", out_prefix=None
+        self, epochs=100, cycle_count=2, training_method="full", out_prefix=None, optimize_bl_first_cycle=True
     ):
         """
         Do joint optimization of model and branch lengths.
@@ -708,6 +708,10 @@
         If training_method is "yun", then we use Yun's approximation to the branch lengths.
         If training_method is "fixed", then we fix the branch lengths and only optimize the model.
 
+        We give an option to optimize the branch lengths in the first cycle
+        (by default we do). Turning this off can be useful, e.g., if we've
+        loaded in some preoptimized branch lengths.
+
         We reset the optimization after each cycle, and we use a learning rate
         schedule that uses a weighted geometric mean of the current learning
         rate and the initial learning rate that progressively moves towards
@@ -722,7 +726,8 @@
         else:
             raise ValueError(f"Unknown training method {training_method}")
         loss_history_l = []
-        optimize_branch_lengths()
+        if optimize_bl_first_cycle:
+            optimize_branch_lengths()
         self.mark_branch_lengths_optimized(0)
         for cycle in range(cycle_count):
             print(
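
Usage sketch (not part of the patch): assuming a trainer object that exposes the joint_train() method modified above, and a hypothetical helper for loading preoptimized branch lengths, the new optimize_bl_first_cycle flag could be used roughly like this:

    # Sketch only: `burrito` stands for an already-constructed training object
    # that defines joint_train(); `load_branch_lengths` is a hypothetical
    # loader for branch lengths optimized in an earlier run.
    def train_with_preoptimized_branch_lengths(burrito, branch_length_csv_prefix):
        # Load previously optimized branch lengths (assumed helper).
        burrito.load_branch_lengths(branch_length_csv_prefix)
        # Skip the branch length optimization that normally precedes the first
        # cycle, so the loaded values are used as-is going into cycle 0.
        return burrito.joint_train(
            epochs=100,
            cycle_count=2,
            training_method="full",
            optimize_bl_first_cycle=False,
        )

The default (optimize_bl_first_cycle=True) preserves the previous behavior of optimizing branch lengths before the first cycle.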