-
Notifications
You must be signed in to change notification settings - Fork 0
/
CAMP_C.pyx
110 lines (87 loc) · 2.68 KB
/
CAMP_C.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 18 13:14:52 2014
@author: Rafael
"""
import numpy as np
import numpy.linalg as la
import scipy.linalg.blas
from cvxopt import matrix
cimport numpy as np
cimport cython
from blas_types cimport dgemm_t, zgemm_t, ddot_t, dgemv_t, zdotu_t, zgemv_t
from numpy cimport PyArray_ZEROS
from numpy cimport float64_t, ndarray, complex128_t, complex64_t
from numpy import float64, ndarray, complex128, complex64
# C-level aliases for NumPy scalar types; ``dcomplex`` is the element type of
# the typed memoryview arguments of CAMP_impl.
ctypedef float64_t DOUBLE
ctypedef complex128_t dcomplex
ctypedef complex64_t COMPLEX64
# NOTE(review): FORTRAN is not referenced anywhere else in this file; presumably
# a column-major ordering flag for BLAS calls -- confirm before relying on it.
cdef int FORTRAN = 1
# Helper declared in the project-local header f2pyptr.h. Presumably it unwraps
# the ``_cpointer`` object of an f2py routine into a raw C pointer -- verify
# against the header. ``except NULL`` makes Cython treat a NULL return as a
# raised Python exception.
cdef extern from "f2pyptr.h":
    void *f2py_pointer(object) except NULL
# Initialise the NumPy C API for this module (needed before using NumPy C-level
# functions such as the cimported PyArray_ZEROS).
np.import_array()
def get_func(name, dt):
    """Return the low-level C pointer object of a SciPy BLAS routine.

    ``name`` is the generic BLAS routine name (e.g. ``'gemm'``) and ``dt`` the
    dtype used by ``scipy.linalg.blas.get_blas_funcs`` to pick the typed
    variant (``float64`` -> ``d``, ``complex128`` -> ``z``). The result is the
    ``_cpointer`` attribute of the f2py wrapper, suitable for ``f2py_pointer``.
    """
    blas_routine = scipy.linalg.blas.get_blas_funcs(name, dtype=dt)
    return blas_routine._cpointer
# Raw BLAS function pointers taken from SciPy's f2py wrappers (via get_func and
# f2py_pointer) and cast to the C signatures declared in the project-local
# ``blas_types`` module.
# NOTE(review): none of these pointers is referenced elsewhere in this file --
# CAMP_impl uses np.dot instead. Confirm whether they are still needed.
cdef dgemm_t *dgemm = <dgemm_t*>f2py_pointer(get_func('gemm', float64))    # real matrix-matrix product
cdef zgemm_t *zgemm = <zgemm_t*>f2py_pointer(get_func('gemm', complex128)) # complex matrix-matrix product
cdef ddot_t *ddot = <ddot_t*> f2py_pointer(get_func('dot', float64))       # real dot product
cdef dgemv_t *dgemv = <dgemv_t*>f2py_pointer(get_func('gemv', float64))    # real matrix-vector product
cdef zdotu_t *zdotu = <zdotu_t*>f2py_pointer(get_func('dotu', complex128)) # complex dot product, unconjugated
cdef zgemv_t *zgemv = <zgemv_t*>f2py_pointer(get_func('gemv', complex128)) # complex matrix-vector product
def STc(x, theta, copy=True):
    """Complex soft-thresholding operator.

    Every entry keeps its phase and has its magnitude reduced by ``theta``:
    entries with ``|x| < theta`` become 0, all others become
    ``x * (1 - theta / |x|)``.

    Parameters
    ----------
    x : ndarray
        Input array (complex or real).
    theta : float
        Threshold (expected to be non-negative).
    copy : bool, optional
        If True (default) ``x`` is left untouched. If False, the entries of
        ``x`` below the threshold are set to 0 in place; the returned array is
        a new array in both cases.

    Returns
    -------
    ndarray
        Soft-thresholded array with the same shape as ``x``.
    """
    z = x.copy() if copy else x
    az = abs(z)
    z[az < theta] = 0
    # Entries with |z| == 0 must stay exactly 0. Dividing them by 1 instead of
    # by 0 prevents 0/0 from producing NaN (e.g. zero inputs or theta == 0).
    safe_az = np.where(az > 0, az, 1)
    return z - theta * (z / safe_az)
def dSTc(x, theta, eps=1e-12):
    """Partial derivatives of the complex soft-thresholding operator ``STc``.

    Write ``z = a + i*b``. Where ``|z| >= theta`` the real and imaginary parts
    of ``STc(z)`` are ``a*(1 - theta/|z|)`` and ``b*(1 - theta/|z|)``. This
    function returns, element-wise::

        d1R = d Re(STc) / d a = 1 - theta * b**2 / |z|**3
        d2I = d Im(STc) / d b = 1 - theta * a**2 / |z|**3

    Both are 0 wherever ``|z| < theta``.

    Parameters
    ----------
    x : ndarray
        Points at which to evaluate the derivatives (complex or real).
    theta : float
        Threshold used by ``STc``.
    eps : float, optional
        Small constant added to ``|z|**3`` to avoid division by zero.

    Returns
    -------
    (d1R, d2I) : tuple of ndarray
        Real arrays with the same shape as ``x``.
    """
    az = np.abs(x)
    az3 = az ** 3 + eps
    re = np.real(x)
    im = np.imag(x)
    d1R = 1 - (theta * im ** 2) / az3
    d2I = 1 - (theta * re ** 2) / az3
    # The thresholding output is constant (0) below the threshold.
    below = az < theta
    d1R[below] = 0
    d2I[below] = 0
    return (d1R, d2I)
def CAMP(A, y, beta, verbose=False):
    """Estimate x from measurements y ~= A x with the CAMP iteration.

    Python entry point: converts the inputs to the layout required by the
    typed ``CAMP_impl`` (2-D, Fortran-ordered, complex128) and runs it.

    Parameters
    ----------
    A : array_like, shape (M, N)
        Measurement matrix. Real or lower-precision input is promoted to
        complex128.
    y : array_like, shape (M,) or (M, K)
        Measurements. A 1-D vector is treated as a single column.
    beta : float
        Threshold multiplier, passed through to ``CAMP_impl``.
    verbose : bool, optional
        Print iteration progress.

    Returns
    -------
    ndarray
        The estimate produced by ``CAMP_impl`` (shape (N, 1) for a single
        measurement column).

    Raises
    ------
    ValueError
        If ``A`` is not 2-D or ``y`` does not have ``M`` rows.
    """
    # CAMP_impl takes ``dcomplex[::1, :]`` memoryviews, which only accept 2-D,
    # Fortran-contiguous complex128 buffers.
    A = np.asfortranarray(A, dtype=np.complex128)
    if A.ndim != 2:
        raise ValueError("A must be 2-D, got %d dimension(s)" % A.ndim)
    M, N = A.shape
    y = np.asarray(y)
    if y.ndim == 1:
        y = y[:, np.newaxis]
    y = np.asfortranarray(y, dtype=np.complex128)
    if y.ndim != 2 or y.shape[0] != M:
        raise ValueError("y must have %d rows to match A, got shape %r"
                         % (M, y.shape))
    return CAMP_impl(A, y, beta, M, N, verbose)
cdef CAMP_impl(dcomplex[::1,:] A,
               dcomplex[::1,:] y,
               double beta,
               int M,
               int N,
               int verbose):
    """Run the complex approximate message passing (CAMP) iteration.

    Each iteration computes
        tz        = A^H z + x_old
        sigma_hat = beta / sqrt(2) * median(|tz|)
        x         = STc(tz, sigma_hat)
        z         = y - A x + z * (sum(dR) + sum(dI)) / (2 N)
    where (dR, dI) = dSTc(tz, sigma_hat). The loop stops once
    ||x - x_old|| <= 1e-12 * ||x|| or after roughly 1000 iterations.

    A : (M, N) Fortran-ordered complex128 memoryview.
    y : (M, K) Fortran-ordered complex128 memoryview (K == 1 normally).
    beta : threshold multiplier.
    M : number of rows of A (not used by the body).
    N : number of columns of A.
    verbose : print progress every 10 iterations and once at the end.
    Returns the final estimate x, an (N, K) complex ndarray.
    """
    # Typed memoryviews provide no ndarray methods (e.g. ``.conj()``), so work
    # on ndarray views of the same buffers instead (no data copy).
    A_arr = np.asarray(A)
    y_arr = np.asarray(y)
    # Loop-invariant conjugate transpose of A, computed once.
    AH = A_arr.T.conj()
    x_old = np.zeros((N, 1))
    z = y_arr.copy()
    eps = 1e-12          # relative convergence tolerance
    max_iter = 1000
    it = 0               # completed, non-final iterations
    beta2 = beta / np.sqrt(2)
    while True:
        tz = np.dot(AH, z) + x_old
        sigma_hat = beta2 * np.median(abs(tz))
        x = STc(tz, sigma_hat)
        (dR, dI) = dSTc(tz, sigma_hat)
        # Residual update with the Onsager correction term.
        # NOTE(review): the published CAMP recursion scales this term by
        # N/M times the mean over N entries, i.e. 1/(2*M); this code uses
        # 1/(2*N). The two agree only when M == N -- confirm intent.
        z = y_arr - np.dot(A_arr, x) + z * (np.sum(dR) + np.sum(dI)) / (2 * N)
        n = la.norm(x - x_old, 2)
        if it > max_iter:
            break
        if n <= eps * la.norm(x, 2):
            break
        # Count every iteration, not only in verbose mode. Otherwise the
        # max_iter guard never fires and a non-converging run never ends.
        it += 1
        if verbose and (it % 10) == 0:
            print("AMP iteration: %d (error %g, %g)" % (it, n, sigma_hat))
        x_old = x
    if verbose:
        it += 1
        print("AMP iteration: %d (error %g, %g)" % (it, n, sigma_hat))
    return x