-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathspecial_einsum.py
56 lines (48 loc) · 1.17 KB
/
special_einsum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright (c) 2015 James Hensman
# Licensed under the BSD 3-clause license (see LICENSE.txt)
import numpy as np
from scipy import weave
"""
This file provides a weavified version of the function
np.einsum('ij,ik,il->jkl', A, A, B)
see also test_special_einsum.py
"""
code = """
int n,m,mm,d;
double tmp;
for(n=0;n<N;n++){
for(m=0;m<M;m++){
//compute diag
tmp = A(n,m)*A(n,m);
for(d=0;d<D;d++){
res(m,m,d) += tmp*B(n,d);
}
//only compute in lower half
for(mm=0;mm<m;mm++){
tmp = A(n,m)*A(n,mm);
for(d=0;d<D;d++){
res(m,mm,d) += tmp*B(n,d);
}
}
}
}
//make symmpetrical
for(m=0;m<M;m++){
for(mm=0;mm<m;mm++){
for(d=0;d<D;d++){
res(mm,m,d) = res(m,mm,d);
}
}
}
"""
def special_einsum(A, B):
    """Return ``np.einsum('ij,ik,il->jkl', A, A, B)`` computed by a weave kernel.

    ``res[j, k, l] = sum_n A[n, j] * A[n, k] * B[n, l]``. The compiled kernel
    (module-level ``code``) only accumulates the lower triangle of the first
    two axes and then mirrors it, since the result is symmetric there.

    Parameters
    ----------
    A : array_like, shape (N, M)
    B : array_like, shape (N, D)
        Must have the same number of rows as ``A``.

    Returns
    -------
    res : ndarray, shape (M, M, D), dtype float64

    Raises
    ------
    ValueError
        If ``A`` or ``B`` is not 2-D, or if their row counts differ.
    """
    # The kernel accumulates into C doubles. Converting once up front yields a
    # single compiled signature whatever the caller's dtype. It also avoids
    # integer overflow in the A(n,m)*A(n,mm) product for integer inputs.
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    N, M = A.shape  # ValueError unless A is 2-D
    N2, D = B.shape  # ValueError unless B is 2-D
    # Explicit check instead of ``assert``: asserts disappear under
    # ``python -O``, and the kernel reads B(n, d) for every n < N, so a
    # shorter B would be indexed out of bounds.
    if N != N2:
        raise ValueError(
            'A and B must have the same number of rows, got {0} and {1}'
            .format(N, N2))
    # The kernel accumulates with "+=", so the output must start at zero.
    res = np.zeros((M, M, D))
    opts = {'headers': ['<omp.h>'],
            # One flag per element: distutils hands each list element to the
            # compiler as a separate command-line argument.
            'extra_compile_args': ['-fopenmp', '-O3'],
            'extra_link_args': ['-lgomp'],
            'libraries': ['gomp']}
    # weave.inline receives variable *names*; they are resolved against this
    # function's local variables, so the converted A and B are what it sees.
    weave.inline(code, ['N', 'M', 'D', 'res', 'A', 'B'],
                 type_converters=weave.converters.blitz,
                 support_code='#include <omp.h>', **opts)
    return res