
Commit

cleaning cruft
matsen committed Jan 15, 2025
1 parent a8dc930 commit 7d6dac2
Showing 1 changed file with 0 additions and 51 deletions.
51 changes: 0 additions & 51 deletions netam/dcsm.py
@@ -31,57 +31,6 @@
)


def masked_log_subtract(log_probs, mask, parent_indices, eps=1e-8):
    """Calculate log(1 - sum(exp(log_probs_children))) in a numerically stable way.

    Args:
        log_probs: Tensor of shape [B, L, C] (batch, sequence length, classes).
        mask: Boolean tensor of shape [B, L] indicating valid positions.
        parent_indices: Tensor of shape [B, L] indicating parent indices.
        eps: Small value for numerical stability.

    Returns:
        Tensor of shape [B, L, C] with parent log probabilities updated.
    """
    # 1. Mask out the parent position in log_probs with a very negative value
    #    (-1e9 here; be mindful of the logit range) so it drops out of the sum.
    masked_log_probs = log_probs.masked_fill(
        torch.zeros_like(log_probs, dtype=torch.bool)
        .scatter_(-1, parent_indices.unsqueeze(-1), 1)
        .to(device=log_probs.device, dtype=torch.bool),
        -1e9,
    )

    # 2. Calculate log(sum(exp(log_probs_children))) using the log-sum-exp trick.
    log_sum_exp_children = torch.logsumexp(masked_log_probs, dim=-1, keepdim=True)

    # 3. Calculate log(1 - sum(exp(log_probs_children))), i.e. log1mexp(a) = log(1 - exp(a))
    #    with a = log_sum_exp_children. Switch formulas at a = log(0.5): when a is close
    #    to 0 (exp(a) close to 1), use an eps-guarded direct computation; otherwise
    #    log1p(-exp(a)) is accurate.
    log_parent_probs = torch.where(
        log_sum_exp_children > torch.log(torch.tensor(0.5)),  # threshold at log(0.5)
        torch.log(eps + 1.0 - torch.exp(log_sum_exp_children)),
        torch.log1p(-torch.exp(log_sum_exp_children)),
    )

    # 4. Scatter log_parent_probs back into the parent slots of log_probs.
    log_probs = log_probs.scatter(-1, parent_indices.unsqueeze(-1), log_parent_probs)

    # 5. Zero out positions that are not valid according to the mask.
    log_probs = torch.where(
        mask.unsqueeze(-1), log_probs, torch.tensor(0.0, device=log_probs.device)
    )

    return log_probs


class DCSMDataset(DXSMDataset):

def __init__(
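As an aside, here is a minimal sketch (not part of the commit) of how the removed masked_log_subtract helper could be exercised on a toy batch, assuming the pre-removal version of netam/dcsm.py where the function is still in scope; the shapes and values below are illustrative assumptions only.

import torch

# Toy batch: B=1 sequence of length L=2 over C=4 classes.
log_probs = torch.log_softmax(torch.randn(1, 2, 4), dim=-1)
mask = torch.tensor([[True, False]])        # second position is padding
parent_indices = torch.tensor([[2, 0]])     # parent class index at each position

out = masked_log_subtract(log_probs, mask, parent_indices)
# At each valid position, the parent slot now holds log(1 - sum of the other
# class probabilities); padded positions are filled with zeros.
print(out.shape)  # torch.Size([1, 2, 4])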

0 comments on commit 7d6dac2
