Skip to content

Commit

Permalink
fix: Use torch.round().long() for timestep comparison line 303 to han…
Browse files Browse the repository at this point in the history
…dle floating-point precision
  • Loading branch information
Liang-ZX committed Jan 15, 2025
1 parent bba59fb commit 1a23bad
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ def set_timesteps(
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps


schedule_timesteps = torch.round(schedule_timesteps).long()
indices = (schedule_timesteps == timestep).nonzero()

# The sigma index that is taken for the **very** first `step`
Expand Down

0 comments on commit 1a23bad

Please sign in to comment.