-
Notifications
You must be signed in to change notification settings - Fork 23
/
multipart-post.lua
175 lines (157 loc) · 4.69 KB
/
multipart-post.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
local ltn12 = require "ltn12"
local url = require "socket.url"
local unpack = table.unpack or unpack
--- Module table returned to callers.
-- CHARSET and LANGUAGE are read by section_header each time it writes the
-- `filename*=charset'language'value` parameter of a file part, so callers
-- may override them on the returned module.
local _M = {
    CHARSET = "UTF-8",
    LANGUAGE = "",
}
--- Format `p` with the trailing arguments via string.format.
-- When no arguments follow, `p` is returned verbatim, so a lone pattern
-- may contain literal `%` characters without escaping.
local function fmt(p, ...)
    local nargs = select("#", ...)
    if nargs > 0 then
        return string.format(p, ...)
    end
    return p
end
--- Format `p` (see fmt) and append the result as the next element of `t`.
local function tprintf(t, p, ...)
    local piece = fmt(p, ...)
    t[#t + 1] = piece
end
--- Percent-encode the bytes that would break a quoted header parameter.
-- LF, CR and the double quote become %0A, %0D and %22 (the escaping the
-- HTML multipart/form-data encoding algorithm applies to field names and
-- filenames). This stops a value from ending the quoted string early or
-- injecting extra header lines into the part. All other bytes pass through
-- unchanged.
local function escape_quoted(s)
    return (string.gsub(tostring(s), '[\r\n"]', {
        ["\r"] = "%0D",
        ["\n"] = "%0A",
        ['"'] = "%22",
    }))
end

--- Append the MIME headers of one form-data part to buffer `r`.
-- Writes the content-disposition line for field `k`, adding `filename` and
-- `filename*` parameters when `extra.filename` is set. It then writes the
-- optional content-type and content-transfer-encoding lines and the blank
-- line that ends the part headers.
-- @tparam table r sequence of string pieces, appended to in place
-- @param k field name (escaped before being quoted)
-- @tparam table extra may hold `filename`, `content_type` and
--   `content_transfer_encoding`; pass `{}` for a plain field
local function section_header(r, k, extra)
    tprintf(r, "content-disposition: form-data; name=\"%s\"", escape_quoted(k))
    if extra.filename then
        tprintf(r, "; filename=\"%s\"", escape_quoted(extra.filename))
        -- Extended parameter (charset'language'value) carrying the filename
        -- escaped with socket.url's url.escape.
        -- NOTE(review): RFC 7578 section 4.2 advises against `filename*` in
        -- multipart/form-data; it is kept for backward compatibility.
        tprintf(
            r, "; filename*=%s'%s'%s",
            _M.CHARSET, _M.LANGUAGE, url.escape(extra.filename)
        )
    end
    if extra.content_type then
        tprintf(r, "\r\ncontent-type: %s", extra.content_type)
    end
    if extra.content_transfer_encoding then
        tprintf(
            r, "\r\ncontent-transfer-encoding: %s",
            extra.content_transfer_encoding
        )
    end
    -- Blank line separating the part headers from the part body.
    tprintf(r, "\r\n\r\n")
end
--- Generate a multipart boundary token.
-- The result is "BOUNDARY-" + 16 random uppercase ASCII letters +
-- "-BOUNDARY" (34 characters).
-- NOTE(review): math.random is not seeded here; on Lua < 5.4, boundaries
-- repeat across runs unless the host program seeds the generator.
local function gen_boundary()
    local letters = {}
    for i = 1, 16 do
        -- 65..90 are the ASCII codes of 'A'..'Z'
        letters[i] = string.char(math.random(65, 90))
    end
    return "BOUNDARY-" .. table.concat(letters) .. "-BOUNDARY"
end
--- Append the boundary delimiter and the part headers for field `k` to `r`.
-- A string `v` yields a plain field. A table `v` yields a file-style part:
-- - filename from `filename` (or `name`);
-- - content type from `content_type` (or `mimetype`), default
--   "application/octet-stream";
-- - transfer encoding from `content_transfer_encoding`, default "binary".
-- @tparam table r sequence of string pieces, appended to in place
-- @param k field name
-- @tparam string|table v field value
-- @tparam string boundary boundary token (written after a leading "--")
-- @raise "invalid input" for a table with neither `data` nor `source`;
--   "unexpected type <t>" for any other kind of value
local function encode_header_to_table(r, k, v, boundary)
    local _t = type(v)
    tprintf(r, "--%s\r\n", boundary)
    if _t == "string" then
        section_header(r, k, {})
    elseif _t == "table" then
        -- get_data_src and data_len both handle a bare ltn12 `source`
        -- (sized through `len`), so accept it as an alternative to `data`.
        assert(v.data or v.source, "invalid input")
        local extra = {
            filename = v.filename or v.name,
            content_type = v.content_type or v.mimetype
                or "application/octet-stream",
            content_transfer_encoding = v.content_transfer_encoding
                or "binary",
        }
        section_header(r, k, extra)
    else
        error(string.format("unexpected type %s", _t))
    end
end
--- Encode the delimiter and headers for field `k` as an ltn12 string source.
-- When `ctx` is given, the byte size of the encoded headers is added to
-- `ctx.headers_length`, so content_length can skip re-encoding them.
local function encode_header_as_source(k, v, boundary, ctx)
    local buf = {}
    encode_header_to_table(buf, k, v, boundary, ctx)
    local encoded = table.concat(buf)
    if ctx then
        ctx.headers_length = ctx.headers_length + string.len(encoded)
    end
    return ltn12.source.string(encoded)
end
--- Return the number of payload bytes a form field will emit.
-- Used by content_length to size the request body.
-- @param d a string value, or a table with `data` and/or `source` and an
--   optional `len`
-- @treturn number byte length of the field's data
-- @raise when the length cannot be determined, or `d` has an unsupported type
local function data_len(d)
    local kind = type(d)
    if kind == "string" then
        return #d
    elseif kind == "table" then
        local data = d.data
        if type(data) == "string" then
            return #data
        end
        if d.len then return d.len end
        -- Without an explicit `len`, a table of strings can still be sized.
        -- get_data_src streams it with ltn12.source.table, which (per
        -- LuaSocket's ltn12) yields entries from index 1 up to the first nil,
        -- the same span ipairs visits. This is skipped when `source` is set,
        -- because `source`, not `data`, then supplies the bytes.
        if type(data) == "table" and not d.source then
            local total = 0
            for _, chunk in ipairs(data) do
                if type(chunk) ~= "string" then
                    total = nil
                    break
                end
                total = total + #chunk
            end
            if total then return total end
        end
        error("must provide data length for non-string datatypes")
    end
    -- Previously this fell through and returned nil, which surfaced later as
    -- an opaque arithmetic error in content_length.
    error(string.format("unexpected type %s", kind))
end
--- Compute the total byte length of the multipart body for form table `t`.
-- The total is each part's headers, plus its data, plus the CRLF after the
-- data, then the closing "--boundary--\r\n" delimiter.
-- @tparam table t form fields
-- @tparam string boundary boundary token
-- @param ctx optional. When truthy, `ctx.headers_length` (default 0) is
--   taken as the already-measured size of every part header, and the
--   headers are not re-encoded.
local function content_length(t, boundary, ctx)
    local total = 0
    if ctx then
        total = ctx.headers_length or 0
    end
    for field, value in pairs(t) do
        if not ctx then
            local scratch = {}
            encode_header_to_table(scratch, field, value, boundary)
            total = total + #table.concat(scratch)
        end
        -- payload plus the CRLF that terminates every part body
        total = total + data_len(value) + 2
    end
    -- "--" .. boundary .. "--\r\n" adds 6 bytes around the boundary
    return total + #boundary + 6
end
--- Return the ltn12 source that produces the data bytes of one field.
-- An explicit `source` entry wins. Otherwise a string value, or the `data`
-- of a table value, is adapted as follows:
-- - string -> string source;
-- - table -> table source;
-- - userdata (assumed to be a file handle) -> file source;
-- - function -> used directly as the source.
-- @raise "invalid input" when no source can be derived
local function get_data_src(v)
    if v.source then
        return v.source
    end
    local kind = type(v)
    if kind == "string" then
        return ltn12.source.string(v)
    end
    if kind ~= "table" then
        error("invalid input")
    end
    local data = v.data
    local data_kind = type(data)
    if data_kind == "function" then
        return data
    elseif data_kind == "string" then
        return ltn12.source.string(data)
    elseif data_kind == "table" then
        return ltn12.source.table(data)
    elseif data_kind == "userdata" then
        return ltn12.source.file(data)
    end
    error("invalid input")
end
--- Set `ltn12.BLOCKSIZE` to `sz`.
-- The ltn12 module table is shared through require, so this changes the
-- setting for every ltn12 user in the process, not just this module.
-- @tparam number sz new block size
local function set_ltn12_blksz(sz)
    local is_number = type(sz) == "number"
    assert(is_number, "set_ltn12_blksz expects a number")
    ltn12.BLOCKSIZE = sz
end
_M["set_ltn12_blksz"] = set_ltn12_blksz
--- Build one ltn12 source that streams the multipart body for form table `t`.
-- Each field contributes three sources in order: its headers, its data and
-- a CRLF. A final source emits the closing "--boundary--\r\n" delimiter.
-- @tparam table t form fields
-- @tparam string boundary boundary token
-- @param ctx optional table whose `headers_length` is increased by the size
--   of every encoded part header
local function source(t, boundary, ctx)
    local parts = {}
    local last = 0
    for field, value in pairs(t) do
        parts[last + 1] = encode_header_as_source(field, value, boundary, ctx)
        parts[last + 2] = get_data_src(value)
        parts[last + 3] = ltn12.source.string("\r\n")
        last = last + 3
    end
    local closing = string.format("--%s--\r\n", boundary)
    parts[last + 1] = ltn12.source.string(closing)
    return ltn12.source.cat(unpack(parts))
end
_M["source"] = source
--- Build a POST request table (method, source, headers) for form table `t`.
-- @tparam table t form fields (name -> string, or table with data/source)
-- @tparam[opt] string boundary boundary token; a random one is generated
--   when omitted (same convention as `_M.encode`)
-- @treturn table `{method = "POST", source = <ltn12 source>, headers = {
--   ["content-length"] = <number>, ["content-type"] = <string>}}`
function _M.gen_request(t, boundary)
    boundary = boundary or gen_boundary()
    -- Part headers are encoded once, while the source is built. Their total
    -- size accumulates in ctx and is reused for content-length, which avoids
    -- encoding the headers twice. The source is built in its own statement so
    -- that ctx is guaranteed to be filled before content_length reads it; the
    -- previous code relied on evaluation order inside a table constructor.
    local ctx = {headers_length = 0}
    local body = source(t, boundary, ctx)
    return {
        method = "POST",
        source = body,
        headers = {
            ["content-length"] = content_length(t, boundary, ctx),
            ["content-type"] = fmt(
                "multipart/form-data; boundary=%s", boundary
            ),
        },
    }
end
--- Encode form table `t` into a complete multipart/form-data body string.
-- @tparam table t form fields
-- @tparam[opt] string boundary boundary token; generated when omitted
-- @treturn string the encoded body
-- @treturn string the boundary that was used (needed for content-type)
function _M.encode(t, boundary)
    local b = boundary or gen_boundary()
    local chunks = {}
    local src = source(t, b)
    local snk = ltn12.sink.table(chunks)
    -- Raise if pumping reports a failure instead of returning a partial body.
    assert(ltn12.pump.all(src, snk))
    return table.concat(chunks), b
end

return _M