forked from soumith/matio-ffi.torch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinit.lua
201 lines (167 loc) · 5.1 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
require 'torch'
local ffi = require 'ffi'
local mat = require 'matio.ffi'
-- Public module table.
--   ffi             : the raw matio FFI bindings, exposed for advanced use
--   use_lua_strings : optional setting; when true, char arrays are loaded as
--                     Lua strings instead of CharTensors
local matio = {
  ffi = mat,
  use_lua_strings = false,
}

-- Build one entry of the class -> tensor table: the name of the torch
-- constructor to call and the element width (in bytes) used when copying
-- the raw MAT buffer into the tensor.
local function tensor_spec(constructor, sizeof)
  return { constructor = constructor, sizeof = sizeof }
end

-- Mapping of MAT matrix class types to torch tensor types.
-- NOTE(review): unsigned classes share the signed tensor type of the same
-- width and their bytes are copied unchanged, so e.g. uint16 values above
-- 32767 read back as negative numbers in a ShortTensor.
local tensor_types_mapper = {
  [mat.C_CHAR]   = tensor_spec('CharTensor', 1),
  [mat.C_INT8]   = tensor_spec('CharTensor', 1),
  [mat.C_UINT8]  = tensor_spec('ByteTensor', 1),
  [mat.C_INT16]  = tensor_spec('ShortTensor', 2),
  [mat.C_UINT16] = tensor_spec('ShortTensor', 2),
  [mat.C_INT32]  = tensor_spec('IntTensor', 4),
  [mat.C_UINT32] = tensor_spec('IntTensor', 4),
  [mat.C_INT64]  = tensor_spec('LongTensor', 8),
  [mat.C_UINT64] = tensor_spec('LongTensor', 8),
  [mat.C_SINGLE] = tensor_spec('FloatTensor', 4),
  [mat.C_DOUBLE] = tensor_spec('DoubleTensor', 8),
}
--- Convert a numeric or char MAT variable into a torch tensor.
-- The tensor type is chosen from tensor_types_mapper. The raw element
-- buffer is copied into a freshly allocated tensor whose dimensions are the
-- MAT dimensions in reverse order. The dimension order is then reversed back
-- with transposes, because MATLAB stores data column-major.
-- This function takes ownership of `var`: it is released with mat.varFree
-- on every return path, including the error paths.
-- @param file open MAT file handle (unused; kept for a uniform loader signature)
-- @param var  matvar_t pointer to convert
-- @return a torch tensor (a transposed view when it has more than one
--         dimension), or nil if the variable cannot be converted
local function load_tensor(file, var)
  local mapper = tensor_types_mapper[tonumber(var.class_type)]
  if not mapper then
    print('Unsupported type of tensor: ' .. var.class_type)
    mat.varFree(var)
    return nil
  end
  -- For complex variables var.data is not a flat element buffer (real and
  -- imaginary parts are stored separately), so a raw copy of
  -- nElement * sizeof bytes from it would read unrelated memory.
  if var.isComplex ~= 0 then
    print('Complex matrices are not supported')
    mat.varFree(var)
    return nil
  end
  local rank = var.rank
  if rank > 8 or rank < 1 then
    print('Rank of input matrix is invalid: ' .. rank)
    mat.varFree(var)
    return nil
  end
  -- allocate with the MAT dimensions reversed (dims is 0-based, column-major)
  local revsizes = {}
  for i = 1, rank do
    revsizes[i] = tonumber(var.dims[rank - i])
  end
  local out = torch[mapper.constructor]()
  out:resize(torch.LongStorage(revsizes))
  ffi.copy(out:data(), var.data, out:nElement() * mapper.sizeof)
  -- the data now lives in the tensor; the matvar is no longer needed
  mat.varFree(var)
  -- swap dimension i with its mirror to restore MATLAB's dimension order
  local ndim = out:dim()
  for i = 1, math.floor(ndim / 2) do
    out = out:transpose(i, ndim - i + 1)
  end
  return out
end
--- Convert a MAT struct variable into a Lua table keyed by field name.
-- Each field value is converted recursively with mat_read_variable.
-- Only struct element 0 is read (index 0 is passed to
-- varGetStructFieldByIndex), so for a struct array the result holds the
-- fields of its first element only.
-- @param file open MAT file handle, forwarded to the field loaders
-- @param var  matvar_t pointer to a struct variable
-- @return table mapping field name -> converted value (fields whose
--         conversion yields nil are absent)
local function load_struct(file, var)
  local out = {}
  local n_fields = tonumber(mat.varGetNumberOfFields(var))
  local field_names = mat.varGetStructFieldnames(var)
  for i = 0, n_fields - 1 do
    local field_name = ffi.string(field_names[i])
    local field_value = mat.varGetStructFieldByIndex(var, i, 0)
    -- a NULL field pointer cannot be dereferenced; leave the field out
    if field_value ~= nil then
      out[field_name] = mat_read_variable(file, field_value)
    end
  end
  return out
end
--- Convert a MAT cell array into a Lua array.
-- Elements are fetched by increasing 0-based linear index until
-- varGetCell returns NULL, and each is converted with mat_read_variable.
-- Results are stored at Lua indices starting from 1. An element whose
-- conversion returns nil leaves a hole at its index.
-- NOTE(review): the original cell shape is not kept; the linear order is
-- presumably MATLAB's column-major order — confirm against matio docs.
-- @param file open MAT file handle, forwarded to the element loaders
-- @param var  matvar_t pointer to a cell variable
-- @return Lua array of converted elements
local function load_cell(file, var)
  local out = {}
  local index = 0
  local cell = mat.varGetCell(var, index)
  while cell ~= nil do
    index = index + 1
    -- matio index `index - 1` lands at Lua index `index` (1-based)
    out[index] = mat_read_variable(file, cell)
    cell = mat.varGetCell(var, index)
  end
  return out
end
--- Convert a MAT character array into a Lua string.
-- The character buffer is copied using the variable's explicit byte count
-- (var.nbytes) rather than relying on a NUL terminator, which the MAT data
-- buffer does not provide. An absent (NULL) buffer yields ''.
-- `var` is released with mat.varFree once its bytes have been copied.
-- NOTE(review): if matio returns 16-bit character data for this file, the
-- string contains the raw 2-byte code units — confirm for the matio
-- version in use.
-- @param file open MAT file handle (unused; kept for a uniform loader signature)
-- @param var  matvar_t pointer to a char variable
-- @return Lua string holding the raw character bytes
local function load_string(file, var)
  local str = ''
  if var.data ~= nil then
    str = ffi.string(var.data, tonumber(var.nbytes))
  end
  mat.varFree(var)
  return str
end
--- Convert any supported MAT variable into a Lua value.
-- Dispatches on the variable's class type:
--   * char arrays become Lua strings when matio.use_lua_strings is true;
--   * classes listed in tensor_types_mapper become torch tensors;
--   * cell arrays become Lua arrays;
--   * structs become tables keyed by field name.
-- Kept as a global (not a local) because load_struct and load_cell, which
-- are defined earlier in this file, look it up by name at call time.
-- @param file open MAT file handle
-- @param var  matvar_t pointer; a NULL pointer yields nil
-- @return the converted value, or nil if var is NULL or its type is unsupported
function mat_read_variable(file, var)
  if var == nil then
    return nil
  end
  local class_type = tonumber(var.class_type)
  -- will load C_CHAR sequence as a lua string, instead of torch tensor
  if matio.use_lua_strings == true and class_type == mat.C_CHAR then
    return load_string(file, var)
  end
  if tensor_types_mapper[class_type] then
    return load_tensor(file, var)
  end
  if class_type == mat.C_CELL then
    return load_cell(file, var)
  end
  if class_type == mat.C_STRUCT then
    return load_struct(file, var)
  end
  print('Unsupported variable type: ' .. var.class_type)
  -- leaf loaders release the variables they convert; do the same here so
  -- an unsupported variable is not leaked
  mat.varFree(var)
  return nil
end
-- Collect the names of every variable stored in an open MAT file, in file
-- order. Headers without a name are skipped.
local function list_variable_names(file)
  local names = {}
  local info = mat.varReadNextInfo(file)
  while info ~= nil do
    -- a NULL name cannot be passed to ffi.string
    if info.name ~= nil then
      names[#names + 1] = ffi.string(info.name)
    end
    -- only the name is needed; free the header so listing the file does not
    -- leak one matvar per variable (assumes varReadNextInfo returns a
    -- caller-owned matvar, as matio's Mat_VarReadNextInfo does)
    mat.varFree(info)
    info = mat.varReadNextInfo(file)
  end
  return names
end

-- Read and convert each named variable from an open MAT file.
-- Returns a table mapping name -> converted value. Missing variables are
-- reported with print(). Variables whose conversion yields nil are omitted.
local function read_variables(file, names)
  local out = {}
  for _, varname in ipairs(names) do
    local var = mat.varRead(file, varname)
    if var ~= nil then
      local x = mat_read_variable(file, var)
      if x ~= nil then
        out[varname] = x
      end
    else
      print('Could not find variable with name: ' .. varname .. ' in file: ' .. ffi.string(mat.getFilename(file)))
    end
  end
  return out
end

--[[
Load all variables (or just the requested ones) from a given .mat file.
It supports structs, cell arrays and tensors of the appropriate types.
Sequences of characters can optionally be loaded as lua strings instead
of torch CharTensors (set matio.use_lua_strings = true).

matio.load(filename, variableName)    -- returns the variable itself
matio.load(filename)                  -- returns a table of all variables
matio.load(filename, {'var1','var2'}) -- returns a table keyed by name

Returns nil if the file cannot be opened or contains no variables.
Variables that cannot be found or converted are reported with print()
and left out of the result. Raises an error if `name` is not a string,
a table or nil. The file is always closed before returning or raising.

Example:
local img1 = matio.load('data.mat', 'img1')
--]]
function matio.load(filename, name)
  -- validate the request before opening the file so nothing can leak
  local names
  local string_name = false
  if not name then
    names = nil -- load everything; names are discovered from the file
  elseif type(name) == 'string' then
    names = {name}
    string_name = true
  elseif type(name) == 'table' then
    names = name
  else
    error('matio.load: name must be a string, a table of strings or nil, got ' .. type(name), 2)
  end

  local file = mat.open(filename, mat.ACC_RDONLY)
  if file == nil then
    print('File could not be opened: ' .. filename)
    return
  end

  -- run all reads under pcall so the file is closed even if a conversion
  -- raises an error
  local ok, result = pcall(function()
    if names == nil or #names == 0 then
      -- build a fresh list so a caller-supplied empty table is not mutated
      names = list_variable_names(file)
    end
    if #names == 0 then
      print('No variables in this file')
      return nil
    end
    return read_variables(file, names)
  end)
  mat.close(file)
  if not ok then
    error(result, 0)
  end
  if result == nil then
    return
  end

  -- conserve backward compatibility: a single string name returns the value
  if string_name then
    return result[names[1]]
  end
  return result
end
return matio