forked from torch/torch7
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FFI.lua
205 lines (180 loc) · 5.97 KB
/
FFI.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
local ok, ffi = pcall(require, 'ffi')
--- Raise a "bad argument" error when a check does not hold.
-- @param condition value tested for truthiness; nothing happens when it is truthy
-- @param fn function name inserted into the error message
-- @param ud argument index inserted into the error message
-- @param msg explanation placed between parentheses in the error message
-- @param level optional level forwarded to error(); 3 is used when omitted
local function checkArgument(condition, fn, ud, msg, level)
   if condition then
      return
   end
   level = level or 3
   error("bad argument #" .. ud .. " to '" .. fn .. "' (" .. msg .. ")", level)
end
--- Check that two type names are equal, raising a "bad argument" error otherwise.
-- @param expected type name that is required
-- @param actual type name that was received
-- @param fn function name inserted into the error message
-- @param ud argument index inserted into the error message
-- @param level optional error level; 3 is used when omitted
local function checkArgumentType(expected, actual, fn, ud, level)
   if expected == actual then
      return
   end
   level = level or 3
   -- The level is bumped by one because checkArgument is called from this
   -- helper, which adds one more call frame before error() is reached.
   checkArgument(false, fn, ud, expected .. " expected, got " .. actual, level + 1)
end
if ok then
-- Torch type prefix -> C element type substituted into the FFI struct
-- definitions generated further down in this file.
local Real2real = {
   ["Byte"]   = "unsigned char",
   ["Char"]   = "char",
   ["Short"]  = "short",
   ["Int"]    = "int",
   ["Long"]   = "long",
   ["Float"]  = "float",
   ["Double"] = "double",
}
-- Allocator
-- Declares the THAllocator struct (malloc / realloc / free function pointers)
-- to the FFI so that the storage structs declared later can refer to it.
local allocator_cdef = [[
typedef struct THAllocator {
void* (*malloc)(void*, long);
void* (*realloc)(void*, void*, long);
void (*free)(void*, void*);
} THAllocator;
]]
ffi.cdef(allocator_cdef)
-- Storage
-- For every entry of Real2real: declare the TH<Real>Storage struct to the FFI
-- and install "cdata" and "data" accessors on the torch.<Real>Storage metatable.
for Real, real in pairs(Real2real) do
   -- Struct template; 'Real' and 'real' are substituted with the type prefix
   -- and the C element type respectively.
   local template = [[
typedef struct THRealStorage
{
real *data;
long size;
int refcount;
char flag;
THAllocator *allocator;
void *allocatorContext;
} THRealStorage;
]]
   -- Assigning to a single local keeps only the first result of gsub (the string).
   local cdefs = template:gsub('Real', Real):gsub('real', real)
   ffi.cdef(cdefs)

   local Storage = torch.getmetatable('torch.' .. Real .. 'Storage')
   local Storage_tt = ffi.typeof(string.format('TH%sStorage**', Real))

   -- Converts the object to a TH<Real>Storage** and returns what it points to.
   local function deref(self)
      return Storage_tt(self)[0]
   end

   rawset(Storage, "cdata", deref)
   rawset(Storage, "data", function(self)
      return deref(self).data
   end)
end
-- Tensor
for Real, real in pairs(Real2real) do
-- Struct template for TH<Real>Tensor; 'Real' is substituted with the type
-- prefix of the current loop iteration before it is handed to the FFI.
local template = [[
typedef struct THRealTensor
{
long *size;
long *stride;
int nDimension;
THRealStorage *storage;
long storageOffset;
int refcount;
char flag;
} THRealTensor;
]]
-- Assigning to a single local keeps only the first result of gsub (the string).
local cdefs = template:gsub('Real', Real):gsub('real', real)
ffi.cdef(cdefs)
-- Metatable of torch.<Real>Tensor and the FFI type used to reach its C struct.
local Tensor = torch.getmetatable('torch.' .. Real .. 'Tensor')
local Tensor_tt = ffi.typeof(string.format('TH%sTensor**', Real))
-- Tensor:cdata() converts the tensor object to a TH<Real>Tensor** and returns
-- what it points to. A nil/false receiver yields nil.
rawset(Tensor, "cdata", function(self)
   if self then
      return Tensor_tt(self)[0]
   end
   return nil
end)
-- Tensor:data() returns the storage data pointer advanced by storageOffset,
-- or nil when the receiver is nil/false or the tensor has no storage.
rawset(Tensor, "data", function(self)
   if not self then
      return nil
   end
   local th = Tensor_tt(self)[0]
   if th.storage == nil then
      return nil
   end
   return th.storage.data + th.storageOffset
end)
-- faster apply (contiguous case)
-- Calls func on every element through the raw data pointer when the tensor is
-- contiguous; otherwise defers to the previously registered apply.
-- A truthy result from func is written back to the element.
local apply = Tensor.apply
rawset(Tensor, "apply", function(self, func)
   if not (self:isContiguous() and self.data) then
      return apply(self, func)
   end
   local ptr = self:data()
   local last = self:nElement() - 1
   for idx = 0, last do
      -- tonumber() required for long...
      local out = func(tonumber(ptr[idx]))
      if out then
         ptr[idx] = out
      end
   end
   return self
end)
-- faster map (contiguous case)
-- Calls func(selfValue, srcValue) element by element through the raw data
-- pointers when both tensors are contiguous; otherwise defers to the
-- previously registered map. A truthy result is written back into self.
-- src must be a tensor of the same type as self.
local map = Tensor.map
rawset(Tensor, "map", function(self, src, func)
   checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
   checkArgumentType(self:type(), src:type(), "map", 1)
   local fast = self:isContiguous() and src:isContiguous() and self.data and src.data
   if not fast then
      return map(self, src, func)
   end
   local dst_ptr = self:data()
   local src_ptr = src:data()
   assert(src:nElement() == self:nElement(), 'size mismatch')
   local last = self:nElement() - 1
   for idx = 0, last do
      -- tonumber() required for long...
      local out = func(tonumber(dst_ptr[idx]), tonumber(src_ptr[idx]))
      if out then
         dst_ptr[idx] = out
      end
   end
   return self
end)
-- faster map2 (contiguous case)
-- Calls func(selfValue, src1Value, src2Value) element by element through the
-- raw data pointers when all three tensors are contiguous; otherwise defers to
-- the previously registered map2. A truthy result is written back into self.
-- src1 and src2 must be tensors of the same type as self.
local map2 = Tensor.map2
rawset(Tensor,
       "map2",
       function(self, src1, src2, func)
          -- Argument errors are reported under the name "map2" so that the
          -- message names the function the caller actually invoked.
          checkArgument(torch.isTensor(src1), "map2", 1, "tensor expected")
          checkArgument(torch.isTensor(src2), "map2", 2, "tensor expected")
          checkArgumentType(self:type(), src1:type(), "map2", 1)
          checkArgumentType(self:type(), src2:type(), "map2", 2)
          if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
             local self_d = self:data()
             local src1_d = src1:data()
             local src2_d = src2:data()
             assert(src1:nElement() == self:nElement(), 'size mismatch')
             assert(src2:nElement() == self:nElement(), 'size mismatch')
             for i = 0, self:nElement() - 1 do
                -- tonumber() required for long...
                local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i]))
                if res then
                   self_d[i] = res
                end
             end
             return self
          else
             return map2(self, src1, src2, func)
          end
       end)
end
-- torch.data
-- Returns self:data(), or that value cast to 'intptr_t' when asnumber is
-- truthy. A nil/false self yields nil.
-- will fail if :data() is not defined
function torch.data(self, asnumber)
   if not self then
      return nil
   end
   local ptr = self:data()
   if not asnumber then
      return ptr
   end
   return ffi.cast('intptr_t', ptr)
end
-- torch.cdata
-- Returns self:cdata(), or that value cast to 'intptr_t' when asnumber is
-- truthy. A nil/false self yields nil.
-- will fail if :cdata() is not defined
function torch.cdata(self, asnumber)
   if not self then
      return nil
   end
   local struct = self:cdata()
   if not asnumber then
      return struct
   end
   return ffi.cast('intptr_t', struct)
end
end