Skip to content

Commit

Permalink
Fix batching.
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Jan 12, 2025
1 parent 3435086 commit 62356d9
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
5 changes: 4 additions & 1 deletion alignn/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def collate(samples: List[Tuple[dgl.DGLGraph, torch.Tensor]]):
graphs, lattices, labels = map(list, zip(*samples))
# graphs, lgs, labels = map(list, zip(*samples))
batched_graph = dgl.batch(graphs)
return batched_graph, torch.stack(lattices), torch.tensor(labels)
if len(labels[0].size()) > 0:
return batched_graph, torch.stack(lattices), torch.stack(labels)
else:
return batched_graph, torch.stack(lattices), torch.tensor(labels)

@staticmethod
def collate_line_graph(
Expand Down
1 change: 1 addition & 0 deletions alignn/models/ealignn_atomwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def forward(
stress = self.config.stress_multiplier * torch.stack(stresses)
if self.classification:
out = self.softmax(out)
# print('out',out)
result["out"] = out
result["additional"] = additional_out
result["grad"] = forces
Expand Down
96 changes: 96 additions & 0 deletions alignn/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,102 @@ def lightweight_line_graph(
filter_condition: Callable[[torch.Tensor], torch.Tensor],
) -> dgl.DGLGraph:
"""Make the line graphs lightweight with preserved node ordering.
Handles both batched and unbatched graphs.
Args:
input_graph: Input DGL graph (can be batched)
feature_name: Name of the edge feature to filter on
filter_condition: Function that takes edge features and returns boolean mask
Returns:
New DGL graph with filtered edges while preserving original node ordering
"""
# Check if graph is batched
is_batched = (
hasattr(input_graph, "batch_size") and input_graph.batch_size > 1
)

if is_batched:
# Get the batch size and number of nodes per graph
batch_size = input_graph.batch_size
graph_list = dgl.unbatch(input_graph)
processed_graphs = []

# Process each graph individually
for g in graph_list:
# Get active edges for this graph
g_active_edges = torch.logical_not(
filter_condition(g.edata[feature_name])
)

# Get filtered edges
g_src, g_dst = g.edges()
g_src = g_src[g_active_edges]
g_dst = g_dst[g_active_edges]

# Get edge IDs for active edges
g_edge_ids = g_active_edges.nonzero().squeeze()

# Create new graph with same number of nodes
new_g = dgl.graph(
(g_src, g_dst),
num_nodes=g.num_nodes(),
device=input_graph.device,
)

# Copy edge IDs
new_g.edata["edge_ids"] = g_edge_ids

# Copy node features
for node_feature, node_value in g.ndata.items():
new_g.ndata[node_feature] = node_value

# Copy filtered edge features
for edge_feature, edge_value in g.edata.items():
new_g.edata[edge_feature] = edge_value[g_active_edges]

processed_graphs.append(new_g)

# Batch the processed graphs back together
return dgl.batch(processed_graphs)

else:
# Handle single graph case (original implementation)
active_edges = torch.logical_not(
filter_condition(input_graph.edata[feature_name])
)

source_nodes, destination_nodes = input_graph.edges()
source_nodes, destination_nodes = (
source_nodes[active_edges],
destination_nodes[active_edges],
)

edge_ids = active_edges.nonzero().squeeze()

new_graph = dgl.graph(
(source_nodes, destination_nodes),
num_nodes=input_graph.num_nodes(),
device=input_graph.device,
)

new_graph.edata["edge_ids"] = edge_ids

for node_feature, node_value in input_graph.ndata.items():
new_graph.ndata[node_feature] = node_value

for edge_feature, edge_value in input_graph.edata.items():
new_graph.edata[edge_feature] = edge_value[active_edges]

return new_graph


def lightweight_line_graph1(
input_graph: dgl.DGLGraph,
feature_name: str,
filter_condition: Callable[[torch.Tensor], torch.Tensor],
) -> dgl.DGLGraph:
"""Make the line graphs lightweight with preserved node ordering.
Args:
input_graph: Input DGL graph
Expand Down

0 comments on commit 62356d9

Please sign in to comment.