[Example] Add EGES example (dmlc#3756)
* add eges example

* remove csv files and add data link

* Update README.md

* Update main.py

* Update model.py

* Update sampler.py

* Update utils.py

* Update model.py

Co-authored-by: Quan (Andy) Gan <[email protected]>
Wang-Yu-Qing and BarclayII authored Apr 14, 2022
1 parent a0bf5da commit f5bba28
Showing 6 changed files with 510 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/pytorch/eges/.gitignore
@@ -0,0 +1,2 @@
__pycache__

41 changes: 41 additions & 0 deletions examples/pytorch/eges/README.md
@@ -0,0 +1,41 @@
# DGL & Pytorch implementation of Enhanced Graph Embedding with Side information (EGES)

## Version
dgl==0.6.1, torch==1.9.0

## Paper
Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba:

https://arxiv.org/pdf/1803.02349.pdf

https://arxiv.org/abs/1803.02349
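
The core of EGES is a weighted aggregation of an item's base embedding and its side-information (brand, shop, category) embeddings, as implemented in `model.py` below; a sketch of the hidden representation from the paper, where $W_v^j$ is the $j$-th embedding of item $v$ and $a_v^j$ its learnable weight:

$$
H_v = \frac{\sum_{j=0}^{n} e^{a_v^j} W_v^j}{\sum_{j=0}^{n} e^{a_v^j}}
$$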

## How to run
Create a folder named `data`. Download the two CSV files from [here](https://github.com/Wang-Yu-Qing/dgl_data/tree/master/eges_data) into the `data` folder.

Run `python main.py` with the default configuration, and messages like the following will show up:

```
Using backend: pytorch
Num skus: 33344, num brands: 3662, num shops: 4785, num cates: 79
Epoch 00000 | Step 00000 | Step Loss 0.9117 | Epoch Avg Loss: 0.9117
Epoch 00000 | Step 00100 | Step Loss 0.8736 | Epoch Avg Loss: 0.8801
Epoch 00000 | Step 00200 | Step Loss 0.8975 | Epoch Avg Loss: 0.8785
Evaluate link prediction AUC: 0.6864
Epoch 00001 | Step 00000 | Step Loss 0.8695 | Epoch Avg Loss: 0.8695
Epoch 00001 | Step 00100 | Step Loss 0.8290 | Epoch Avg Loss: 0.8643
Epoch 00001 | Step 00200 | Step Loss 0.8012 | Epoch Avg Loss: 0.8604
Evaluate link prediction AUC: 0.6875
...
Epoch 00029 | Step 00000 | Step Loss 0.7095 | Epoch Avg Loss: 0.7095
Epoch 00029 | Step 00100 | Step Loss 0.7248 | Epoch Avg Loss: 0.7139
Epoch 00029 | Step 00200 | Step Loss 0.7123 | Epoch Avg Loss: 0.7134
Evaluate link prediction AUC: 0.7084
```

The link-prediction AUC on the test graph is computed after each epoch.
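
For reference, the AUC comes from scikit-learn's ROC utilities applied to the per-edge sigmoid scores (see `evaluate` in `main.py`); a minimal self-contained sketch with toy scores:

```
from sklearn import metrics

# toy scores and labels for illustration; main.py fills these per test edge
labels = [1, 1, 0, 0]
preds = [0.9, 0.6, 0.4, 0.2]
fpr, tpr, _ = metrics.roc_curve(labels, preds, pos_label=1)
print("AUC: {:.4f}".format(metrics.auc(fpr, tpr)))  # 1.0000 for this toy data
```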

## Reference
https://github.com/nonva/eges

https://github.com/wangzhegeek/EGES.git
100 changes: 100 additions & 0 deletions examples/pytorch/eges/main.py
@@ -0,0 +1,100 @@
import dgl
import torch as th
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics

import utils
from model import EGES
from sampler import Sampler


def train(args, train_g, test_g, sku_info, num_skus, num_brands, num_shops, num_cates):
sampler = Sampler(
train_g,
args.walk_length,
args.num_walks,
args.window_size,
args.num_negative
)
    # For each node in a batch, sample positive and negative
    # pairs and feed them into the model
    # (the graph's nodes are batched by the DataLoader below).
dataloader = DataLoader(
th.arange(train_g.num_nodes()),
# this is the batch_size of input nodes
batch_size=args.batch_size,
shuffle=True,
collate_fn=lambda x: sampler.sample(x, sku_info)
)
model = EGES(args.dim, num_skus, num_brands, num_shops, num_cates)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

for epoch in range(args.epochs):
epoch_total_loss = 0
for step, (srcs, dsts, labels) in enumerate(dataloader):
            # the number of output pairs varies from batch to batch
# TODO: shuffle the triples?
srcs_embeds, dsts_embeds = model(srcs, dsts)
loss = model.loss(srcs_embeds, dsts_embeds, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()

epoch_total_loss += loss.item()

if step % args.log_every == 0:
print('Epoch {:05d} | Step {:05d} | Step Loss {:.4f} | Epoch Avg Loss: {:.4f}'.format(
epoch, step, loss.item(), epoch_total_loss / (step + 1)))

        evaluate(model, test_g, sku_info)

return model


def evaluate(model, test_graph, sku_info):
preds, labels = [], []
for edge in test_graph:
src = th.tensor(sku_info[edge.src.numpy()[0]]).view(1, 4)
dst = th.tensor(sku_info[edge.dst.numpy()[0]]).view(1, 4)
# (1, dim)
src = model.query_node_embed(src)
dst = model.query_node_embed(dst)
# (1, dim) -> (1, dim) -> (1, )
logit = th.sigmoid(th.sum(src * dst))
preds.append(logit.detach().numpy().tolist())
labels.append(edge.label)

fpr, tpr, thresholds = metrics.roc_curve(labels, preds, pos_label=1)

print("Evaluate link prediction AUC: {:.4f}".format(metrics.auc(fpr, tpr)))


if __name__ == "__main__":
args = utils.init_args()

valid_sku_raw_ids = utils.get_valid_sku_set(args.item_info_data)

g, sku_encoder, sku_decoder = utils.construct_graph(
args.action_data,
args.session_interval_sec,
valid_sku_raw_ids
)

train_g, test_g = utils.split_train_test_graph(g)

sku_info_encoder, sku_info_decoder, sku_info = \
utils.encode_sku_fields(args.item_info_data, sku_encoder, sku_decoder)

num_skus = len(sku_encoder)
num_brands = len(sku_info_encoder["brand"])
num_shops = len(sku_info_encoder["shop"])
num_cates = len(sku_info_encoder["cate"])

    print(
        "Num skus: {}, num brands: {}, num shops: {}, num cates: {}".format(
            num_skus, num_brands, num_shops, num_cates
        )
    )

    model = train(args, train_g, test_g, sku_info, num_skus, num_brands, num_shops, num_cates)
53 changes: 53 additions & 0 deletions examples/pytorch/eges/model.py
@@ -0,0 +1,53 @@
import torch as th


class EGES(th.nn.Module):
def __init__(self, dim, num_nodes, num_brands, num_shops, num_cates):
super(EGES, self).__init__()
self.dim = dim
# embeddings for nodes
base_embeds = th.nn.Embedding(num_nodes, dim)
brand_embeds = th.nn.Embedding(num_brands, dim)
shop_embeds = th.nn.Embedding(num_shops, dim)
cate_embeds = th.nn.Embedding(num_cates, dim)
        # wrap in ModuleList so the embeddings are registered as sub-modules
        # and picked up by model.parameters() for the optimizer
        self.embeds = th.nn.ModuleList([base_embeds, brand_embeds, shop_embeds, cate_embeds])
# weights for each node's side information
self.side_info_weights = th.nn.Embedding(num_nodes, 4)

def forward(self, srcs, dsts):
# srcs: sku_id, brand_id, shop_id, cate_id
srcs = self.query_node_embed(srcs)
dsts = self.query_node_embed(dsts)

return srcs, dsts

def query_node_embed(self, nodes):
"""
@nodes: tensor of shape (batch_size, num_side_info)
"""
batch_size = nodes.shape[0]
# query side info weights, (batch_size, 4)
side_info_weights = th.exp(self.side_info_weights(nodes[:, 0]))
# merge all embeddings
side_info_weighted_embeds_sum = []
side_info_weights_sum = []
for i in range(4):
# weights for i-th side info, (batch_size, ) -> (batch_size, 1)
i_th_side_info_weights = side_info_weights[:, i].view((batch_size, 1))
# batch of i-th side info embedding * its weight, (batch_size, dim)
side_info_weighted_embeds_sum.append(i_th_side_info_weights * self.embeds[i](nodes[:, i]))
side_info_weights_sum.append(i_th_side_info_weights)
        # stack: (batch_size, 4, dim), sum: (batch_size, dim)
        side_info_weighted_embeds_sum = th.sum(th.stack(side_info_weighted_embeds_sum, dim=1), dim=1)
        # stack: (batch_size, 4, 1), sum: (batch_size, 1)
        side_info_weights_sum = th.sum(th.stack(side_info_weights_sum, dim=1), dim=1)
# (batch_size, dim)
H = side_info_weighted_embeds_sum / side_info_weights_sum

return H

def loss(self, srcs, dsts, labels):
        dots = th.sigmoid(th.sum(srcs * dsts, dim=1))
dots = th.clamp(dots, min=1e-7, max=1 - 1e-7)

return th.mean(- (labels * th.log(dots) + (1 - labels) * th.log(1 - dots)))
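
A minimal usage sketch of the model above, with hypothetical toy sizes and ids (not from the repo's data):

```
import torch as th

from model import EGES

# toy sizes: 100 skus, 10 brands, 10 shops, 5 categories, 16-dim embeddings
model = EGES(16, 100, 10, 10, 5)
# each row: [sku_id, brand_id, shop_id, cate_id]
srcs = th.tensor([[0, 1, 2, 3], [4, 5, 6, 4]])
dsts = th.tensor([[8, 1, 2, 3], [9, 5, 6, 4]])
srcs_embeds, dsts_embeds = model(srcs, dsts)   # each of shape (2, 16)
loss = model.loss(srcs_embeds, dsts_embeds, th.tensor([1.0, 0.0]))
```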
99 changes: 99 additions & 0 deletions examples/pytorch/eges/sampler.py
@@ -0,0 +1,99 @@
import dgl
import numpy as np
import torch as th


class Sampler:
def __init__(self,
graph,
walk_length,
num_walks,
window_size,
num_negative):
self.graph = graph
self.walk_length = walk_length
self.num_walks = num_walks
self.window_size = window_size
self.num_negative = num_negative
self.node_weights = self.compute_node_sample_weight()

def sample(self, batch, sku_info):
"""
Given a batch of target nodes, sample postive
pairs and negative pairs from the graph
"""
batch = np.repeat(batch, self.num_walks)

pos_pairs = self.generate_pos_pairs(batch)
neg_pairs = self.generate_neg_pairs(pos_pairs)

# get sku info with id
srcs, dsts, labels = [], [], []
for pair in pos_pairs + neg_pairs:
src, dst, label = pair
src_info = sku_info[src]
dst_info = sku_info[dst]

srcs.append(src_info)
dsts.append(dst_info)
labels.append(label)

return th.tensor(srcs), th.tensor(dsts), th.tensor(labels)

    def filter_padding(self, traces):
        # dgl.sampling.random_walk pads early-terminated walks with -1
        for i in range(len(traces)):
            traces[i] = [x for x in traces[i] if x != -1]

def generate_pos_pairs(self, nodes):
"""
For seq [1, 2, 3, 4] and node NO.2,
the window_size=1 will generate:
(1, 2) and (2, 3)
"""
# random walk
traces, types = dgl.sampling.random_walk(
g=self.graph,
nodes=nodes,
length=self.walk_length,
prob="weight"
)
traces = traces.tolist()
self.filter_padding(traces)

# skip-gram
pairs = []
for trace in traces:
for i in range(len(trace)):
center = trace[i]
left = max(0, i - self.window_size)
right = min(len(trace), i + self.window_size + 1)
pairs.extend([[center, x, 1] for x in trace[left:i]])
pairs.extend([[center, x, 1] for x in trace[i+1:right]])

return pairs

def compute_node_sample_weight(self):
"""
Using node degree as sample weight
"""
return self.graph.in_degrees().float()

def generate_neg_pairs(self, pos_pairs):
"""
Sample based on node freq in traces, frequently shown
nodes will have larger chance to be sampled as
negative node.
"""
# sample `self.num_negative` neg dst node
# for each pos node pair's src node.
negs = th.multinomial(
self.node_weights,
len(pos_pairs) * self.num_negative,
replacement=True
).tolist()

tar = np.repeat([pair[0] for pair in pos_pairs], self.num_negative)
        assert len(tar) == len(negs)
neg_pairs = [[x, y, 0] for x, y in zip(tar, negs)]

return neg_pairs
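
A minimal usage sketch of the sampler above, on a hypothetical toy graph (sizes and ids are illustrative only):

```
import dgl
import torch as th

from sampler import Sampler

# a 4-node cycle with unit edge weights (random_walk samples with prob="weight")
g = dgl.graph((th.tensor([0, 1, 2, 3]), th.tensor([1, 2, 3, 0])))
g.edata["weight"] = th.ones(g.num_edges())
# fake side info per node: [sku_id, brand_id, shop_id, cate_id]
sku_info = {i: [i, 0, 0, 0] for i in range(4)}

sampler = Sampler(g, walk_length=3, num_walks=2, window_size=1, num_negative=2)
srcs, dsts, labels = sampler.sample(list(range(g.num_nodes())), sku_info)
print(srcs.shape, dsts.shape, labels.shape)  # (N, 4), (N, 4), (N,)
```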
