Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Feb 13, 2023
1 parent e1da7de commit 20cee8f
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 8 deletions.
5 changes: 3 additions & 2 deletions python/dgllife/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

try:
import rdkit
from . import data
from . import utils
except ImportError:
print('RDKit is not installed, which is required for utils related to cheminformatics')

from . import data
from . import utils
4 changes: 2 additions & 2 deletions python/dgllife/model/gnn/attentivefp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def forward(self, g, edge_logits, edge_feats, node_feats):
"""
g = g.local_var()
g.edata['e'] = edge_softmax(g, edge_logits) * self.edge_transform(edge_feats)
g.update_all(fn.copy_edge('e', 'm'), fn.sum('m', 'c'))
g.update_all(fn.copy_e('e', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))

Expand Down Expand Up @@ -123,7 +123,7 @@ def forward(self, g, edge_logits, node_feats):
g.edata['a'] = edge_softmax(g, edge_logits)
g.ndata['hv'] = self.project_node(node_feats)

g.update_all(fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'c'))
g.update_all(fn.u_mul_e('hv', 'a', 'm'), fn.sum('m', 'c'))
context = F.elu(g.ndata['c'])
return F.relu(self.gru(context, node_feats))

Expand Down
2 changes: 1 addition & 1 deletion python/dgllife/model/gnn/weave.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def forward(self, g, node_feats, edge_feats, node_only=False):
# Update node features
node_node_feats = self.activation(self.node_to_node(node_feats))
g.edata['e2n'] = self.activation(self.edge_to_node(edge_feats))
g.update_all(fn.copy_edge('e2n', 'm'), fn.sum('m', 'e2n'))
g.update_all(fn.copy_e('e2n', 'm'), fn.sum('m', 'e2n'))
edge_node_feats = g.ndata.pop('e2n')
new_node_feats = self.activation(self.update_node(
torch.cat([node_node_feats, edge_node_feats], dim=1)))
Expand Down
4 changes: 2 additions & 2 deletions python/dgllife/model/gnn/wln.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def forward(self, g, node_feats, edge_feats):
if g.num_edges() > 0:
# The following lines do not work for a graph without edges.
g.ndata['hv'] = node_feats
g.apply_edges(fn.copy_src('hv', 'he_src'))
g.apply_edges(fn.copy_u('hv', 'he_src'))
concat_edge_feats = torch.cat([g.edata['he_src'], edge_feats], dim=1)
g.edata['he'] = self.project_concatenated_messages(concat_edge_feats)
g.update_all(fn.copy_edge('he', 'm'), fn.sum('m', 'hv_new'))
g.update_all(fn.copy_e('he', 'm'), fn.sum('m', 'hv_new'))
node_feats = self.get_new_node_feats(
torch.cat([node_feats, g.ndata['hv_new']], dim=1))
else:
Expand Down
2 changes: 1 addition & 1 deletion python/dgllife/model/model_zoo/wln_reaction_center.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def forward(self, batch_complete_graphs, node_feats, feat_sum, node_pair_feat):
self.project_node_pair_feature(node_pair_feat)
)
batch_complete_graphs.update_all(
fn.src_mul_edge('hv', 'a', 'm'), fn.sum('m', 'context'))
fn.u_mul_e('hv', 'a', 'm'), fn.sum('m', 'context'))
node_contexts = batch_complete_graphs.ndata.pop('context')

return node_contexts
Expand Down

0 comments on commit 20cee8f

Please sign in to comment.