# [Example] Add EGES example (dmlc#3756)

* add eges example
* remove csv files and add data link
* Update README.md
* Update main.py
* Update model.py
* Update sampler.py
* Update utils.py

Co-authored-by: Quan (Andy) Gan <[email protected]>
1 parent a0bf5da, commit f5bba28. Showing 6 changed files with 510 additions and 0 deletions.
## .gitignore

```
__pycache__
```
## README.md
# DGL & PyTorch implementation of Enhanced Graph Embedding with Side information (EGES)

## Version
dgl==0.6.1, torch==1.9.0

## Paper
Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba:

https://arxiv.org/pdf/1803.02349.pdf

https://arxiv.org/abs/1803.02349
## How to run
Create a folder named `data`. Download the two csv files from [here](https://github.com/Wang-Yu-Qing/dgl_data/tree/master/eges_data) into the `data` folder.

Run `python main.py` with the default configuration, and output like the following will show up:

```
Using backend: pytorch
Num skus: 33344, num brands: 3662, num shops: 4785, num cates: 79
Epoch 00000 | Step 00000 | Step Loss 0.9117 | Epoch Avg Loss: 0.9117
Epoch 00000 | Step 00100 | Step Loss 0.8736 | Epoch Avg Loss: 0.8801
Epoch 00000 | Step 00200 | Step Loss 0.8975 | Epoch Avg Loss: 0.8785
Evaluate link prediction AUC: 0.6864
Epoch 00001 | Step 00000 | Step Loss 0.8695 | Epoch Avg Loss: 0.8695
Epoch 00001 | Step 00100 | Step Loss 0.8290 | Epoch Avg Loss: 0.8643
Epoch 00001 | Step 00200 | Step Loss 0.8012 | Epoch Avg Loss: 0.8604
Evaluate link prediction AUC: 0.6875
...
Epoch 00029 | Step 00000 | Step Loss 0.7095 | Epoch Avg Loss: 0.7095
Epoch 00029 | Step 00100 | Step Loss 0.7248 | Epoch Avg Loss: 0.7139
Epoch 00029 | Step 00200 | Step Loss 0.7123 | Epoch Avg Loss: 0.7134
Evaluate link prediction AUC: 0.7084
```

The link-prediction AUC on the test graph is computed after each epoch.
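The hyperparameters all come from `utils.init_args`, whose source is not included in this diff; `main.py` reads `args.epochs`, `args.batch_size`, `args.lr`, `args.dim`, `args.walk_length`, `args.num_walks`, `args.window_size`, `args.num_negative`, and `args.log_every`. Assuming `init_args` exposes these as like-named argparse flags (an assumption, since `utils.py` is not shown), a customized run would look like:

```
python main.py --epochs 30 --batch_size 64 --lr 0.001 --dim 16 --walk_length 10 --window_size 2
```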
## Reference
https://github.com/nonva/eges

https://github.com/wangzhegeek/EGES.git
## main.py
```python
import dgl
import torch as th
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics

import utils
from model import EGES
from sampler import Sampler


def train(args, train_g, test_g, sku_info, num_skus, num_brands, num_shops, num_cates):
    sampler = Sampler(
        train_g,
        args.walk_length,
        args.num_walks,
        args.window_size,
        args.num_negative
    )
    # For each node in the graph, sample positive and negative
    # pairs and feed them into the model.
    # (Nodes are batched before sampling.)
    dataloader = DataLoader(
        th.arange(train_g.num_nodes()),
        # this is the batch size of input nodes
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=lambda x: sampler.sample(x, sku_info)
    )
    model = EGES(args.dim, num_skus, num_brands, num_shops, num_cates)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        epoch_total_loss = 0
        for step, (srcs, dsts, labels) in enumerate(dataloader):
            # the batch size of output pairs is not fixed
            # TODO: shuffle the triples?
            srcs_embeds, dsts_embeds = model(srcs, dsts)
            loss = model.loss(srcs_embeds, dsts_embeds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_total_loss += loss.item()

            if step % args.log_every == 0:
                print('Epoch {:05d} | Step {:05d} | Step Loss {:.4f} | Epoch Avg Loss: {:.4f}'.format(
                    epoch, step, loss.item(), epoch_total_loss / (step + 1)))

        evaluate(model, test_g, sku_info)

    return model


def evaluate(model, test_graph, sku_info):
    preds, labels = [], []
    for edge in test_graph:
        src = th.tensor(sku_info[edge.src.numpy()[0]]).view(1, 4)
        dst = th.tensor(sku_info[edge.dst.numpy()[0]]).view(1, 4)
        # (1, dim)
        src = model.query_node_embed(src)
        dst = model.query_node_embed(dst)
        # (1, dim) * (1, dim) -> sum -> scalar score in (0, 1)
        logit = th.sigmoid(th.sum(src * dst))
        preds.append(logit.detach().numpy().tolist())
        labels.append(edge.label)

    fpr, tpr, thresholds = metrics.roc_curve(labels, preds, pos_label=1)

    print("Evaluate link prediction AUC: {:.4f}".format(metrics.auc(fpr, tpr)))


if __name__ == "__main__":
    args = utils.init_args()

    valid_sku_raw_ids = utils.get_valid_sku_set(args.item_info_data)

    g, sku_encoder, sku_decoder = utils.construct_graph(
        args.action_data,
        args.session_interval_sec,
        valid_sku_raw_ids
    )

    train_g, test_g = utils.split_train_test_graph(g)

    sku_info_encoder, sku_info_decoder, sku_info = \
        utils.encode_sku_fields(args.item_info_data, sku_encoder, sku_decoder)

    num_skus = len(sku_encoder)
    num_brands = len(sku_info_encoder["brand"])
    num_shops = len(sku_info_encoder["shop"])
    num_cates = len(sku_info_encoder["cate"])

    print(
        "Num skus: {}, num brands: {}, num shops: {}, num cates: {}".format(
            num_skus, num_brands, num_shops, num_cates)
    )

    model = train(args, train_g, test_g, sku_info,
                  num_skus, num_brands, num_shops, num_cates)
```
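To make the data flow concrete, here is a hedged illustration of the tensors exchanged between `Sampler.sample` and `EGES` (the shapes follow from the code above; the ids and sizes are made up):

```python
import torch as th
from model import EGES

# Each row of srcs/dsts is [sku_id, brand_id, shop_id, cate_id], as
# produced by Sampler.sample; labels are 1 for positive pairs drawn
# from random walks and 0 for sampled negatives.
srcs = th.tensor([[0, 5, 7, 2], [1, 3, 0, 4]])
dsts = th.tensor([[9, 5, 7, 2], [4, 8, 1, 4]])
labels = th.tensor([1, 0])

model = EGES(dim=16, num_nodes=10, num_brands=9, num_shops=8, num_cates=5)
srcs_embeds, dsts_embeds = model(srcs, dsts)          # each: (2, 16)
loss = model.loss(srcs_embeds, dsts_embeds, labels)   # scalar
```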
## model.py
```python
import torch as th


class EGES(th.nn.Module):
    def __init__(self, dim, num_nodes, num_brands, num_shops, num_cates):
        super(EGES, self).__init__()
        self.dim = dim
        # embeddings for nodes and their side information
        base_embeds = th.nn.Embedding(num_nodes, dim)
        brand_embeds = th.nn.Embedding(num_brands, dim)
        shop_embeds = th.nn.Embedding(num_shops, dim)
        cate_embeds = th.nn.Embedding(num_cates, dim)
        # use a ModuleList so the embedding tables are registered as
        # submodules and picked up by model.parameters()
        self.embeds = th.nn.ModuleList([base_embeds, brand_embeds, shop_embeds, cate_embeds])
        # attention weights over each node's side information
        self.side_info_weights = th.nn.Embedding(num_nodes, 4)

    def forward(self, srcs, dsts):
        # each row of srcs/dsts: sku_id, brand_id, shop_id, cate_id
        srcs = self.query_node_embed(srcs)
        dsts = self.query_node_embed(dsts)

        return srcs, dsts

    def query_node_embed(self, nodes):
        """
        @nodes: tensor of shape (batch_size, num_side_info)
        """
        batch_size = nodes.shape[0]
        # query the side-info weights by sku id, (batch_size, 4)
        side_info_weights = th.exp(self.side_info_weights(nodes[:, 0]))
        # merge all embeddings
        side_info_weighted_embeds_sum = []
        side_info_weights_sum = []
        for i in range(4):
            # weights for the i-th side info, (batch_size, ) -> (batch_size, 1)
            i_th_side_info_weights = side_info_weights[:, i].view((batch_size, 1))
            # batch of i-th side-info embeddings * their weights, (batch_size, dim)
            side_info_weighted_embeds_sum.append(i_th_side_info_weights * self.embeds[i](nodes[:, i]))
            side_info_weights_sum.append(i_th_side_info_weights)
        # stack: (batch_size, 4, dim), sum: (batch_size, dim)
        side_info_weighted_embeds_sum = th.sum(th.stack(side_info_weighted_embeds_sum, dim=1), dim=1)
        # stack: (batch_size, 4, 1), sum: (batch_size, 1)
        side_info_weights_sum = th.sum(th.stack(side_info_weights_sum, dim=1), dim=1)
        # weighted average, (batch_size, dim)
        H = side_info_weighted_embeds_sum / side_info_weights_sum

        return H

    def loss(self, srcs, dsts, labels):
        # binary cross-entropy on the inner product of the two embeddings
        dots = th.sigmoid(th.sum(srcs * dsts, dim=1))
        dots = th.clamp(dots, min=1e-7, max=1 - 1e-7)

        return th.mean(- (labels * th.log(dots) + (1 - labels) * th.log(1 - dots)))
```
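For reference, `query_node_embed` computes the attention-weighted average of the base and side-information embeddings described in the EGES paper, and `loss` is binary cross-entropy on the inner product of the two aggregated embeddings (the code clamps the sigmoid output to `[1e-7, 1 - 1e-7]` for numerical stability). With $a_v^j$ the learned weight and $\mathbf{W}_v^j$ the $j$-th embedding of node $v$, and $B$ a batch of labeled pairs:

$$
\mathbf{H}_v = \frac{\sum_{j} e^{a_v^j}\,\mathbf{W}_v^j}{\sum_{j} e^{a_v^j}},
\qquad
\mathcal{L} = -\frac{1}{|B|}\sum_{(u,v,y)\in B}\Big[y\log\sigma(\mathbf{H}_u^{\top}\mathbf{H}_v) + (1-y)\log\big(1-\sigma(\mathbf{H}_u^{\top}\mathbf{H}_v)\big)\Big]
$$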
## sampler.py
```python
import dgl
import numpy as np
import torch as th


class Sampler:
    def __init__(self,
                 graph,
                 walk_length,
                 num_walks,
                 window_size,
                 num_negative):
        self.graph = graph
        self.walk_length = walk_length
        self.num_walks = num_walks
        self.window_size = window_size
        self.num_negative = num_negative
        self.node_weights = self.compute_node_sample_weight()

    def sample(self, batch, sku_info):
        """
        Given a batch of target nodes, sample positive
        pairs and negative pairs from the graph.
        """
        batch = np.repeat(batch, self.num_walks)

        pos_pairs = self.generate_pos_pairs(batch)
        neg_pairs = self.generate_neg_pairs(pos_pairs)

        # look up each sku's side information by its id
        srcs, dsts, labels = [], [], []
        for pair in pos_pairs + neg_pairs:
            src, dst, label = pair
            src_info = sku_info[src]
            dst_info = sku_info[dst]

            srcs.append(src_info)
            dsts.append(dst_info)
            labels.append(label)

        return th.tensor(srcs), th.tensor(dsts), th.tensor(labels)

    def filter_padding(self, traces):
        # dgl.sampling.random_walk pads early-stopped walks with -1;
        # drop the padding entries
        for i in range(len(traces)):
            traces[i] = [x for x in traces[i] if x != -1]

    def generate_pos_pairs(self, nodes):
        """
        For the walk [1, 2, 3, 4] and center node 2,
        window_size=1 generates the positive pairs
        (2, 1) and (2, 3).
        """
        # random walk
        traces, types = dgl.sampling.random_walk(
            g=self.graph,
            nodes=nodes,
            length=self.walk_length,
            prob="weight"
        )
        traces = traces.tolist()
        self.filter_padding(traces)

        # skip-gram: pair each center node with every neighbor
        # inside its window, labeled 1 (positive)
        pairs = []
        for trace in traces:
            for i in range(len(trace)):
                center = trace[i]
                left = max(0, i - self.window_size)
                right = min(len(trace), i + self.window_size + 1)
                pairs.extend([[center, x, 1] for x in trace[left:i]])
                pairs.extend([[center, x, 1] for x in trace[i + 1:right]])

        return pairs

    def compute_node_sample_weight(self):
        """
        Use node in-degree as the negative-sampling weight.
        """
        return self.graph.in_degrees().float()

    def generate_neg_pairs(self, pos_pairs):
        """
        Sample negative destination nodes in proportion to node
        in-degree, so frequently visited nodes have a larger
        chance of being drawn as negatives.
        """
        # sample `self.num_negative` negative dst nodes
        # for each positive pair's src node
        negs = th.multinomial(
            self.node_weights,
            len(pos_pairs) * self.num_negative,
            replacement=True
        ).tolist()

        tar = np.repeat([pair[0] for pair in pos_pairs], self.num_negative)
        assert len(tar) == len(negs)
        neg_pairs = [[x, y, 0] for x, y in zip(tar, negs)]

        return neg_pairs
```
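To see exactly which positive pairs the windowing loop emits, here is a minimal, self-contained rerun of the skip-gram section of `generate_pos_pairs` on a hand-written walk (no graph or DGL needed):

```python
# The same windowing loop as in Sampler.generate_pos_pairs,
# applied to the walk from its docstring.
trace = [1, 2, 3, 4]
window_size = 1
pairs = []
for i in range(len(trace)):
    center = trace[i]
    left = max(0, i - window_size)
    right = min(len(trace), i + window_size + 1)
    pairs.extend([[center, x, 1] for x in trace[left:i]])
    pairs.extend([[center, x, 1] for x in trace[i + 1:right]])

print(pairs)
# [[1, 2, 1], [2, 1, 1], [2, 3, 1], [3, 2, 1], [3, 4, 1], [4, 3, 1]]
```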
## utils.py

*(Diff of the sixth changed file, `utils.py`, is not shown; its content did not load.)*