diff --git a/README.md b/README.md index d9361cf..dbfa4f4 100644 --- a/README.md +++ b/README.md @@ -3,34 +3,35 @@ [Ourls][1] 是由 [takashiki][2] 实现的一个基于发号和 hashid 的短网址服务。 受这个项目的启发,将此工程移植,使用 [OpenResty][3] 实现。 -### 待移植的功能: - - - url 规格化 - - 前置代理支持 - ### 待增加的特性: + - 工程目录优化,以 OpenResty 目录结构为准 - Cache 支持 ### 安装方法: - 安装 openresty rpm 包(或手动编译,建议使用 --prefix=/usr/local/openresty) + - 安装 libidn-devel 库(yum install libidn-devel) - 将原 openresty/nginx/conf 目录备份 - - 将本工程解压到 openresty/nginx/conf 目录 - - 进入 openresty/nginx/conf/ourl 目录,复制 config.sample.lua 为 config.lua - - 修改 config.lua 中的数据库等配置 + - 将本工程解压到 openresty/nginx/conf 目录,执行 `install.sh` + - 修改 ourl/config.lua 中的数据库等配置 - 恢复 urls.sql 至 mysql/mariadb 数据库 - - 进入 openresty/nginx/conf/vhosts 目录,修改 ourl.conf 中的 server_name - - 进入 openresty/nginx/conf/lib/hashids 目录,执行 make 命令 + - 进入 vhosts 目录,修改 ourl.conf 中的 server_name 为你自己的域名 - 启动 openresty ### 使用到的其他项目 - [leihog/hashids.lua][4] - [APItools/router.lua][5] + - [golgote/neturl][6] + - [mah0x211/lua-idna][7] + - [hamishforbes/lua-resty-iputils][8] [1]: https://github.com/takashiki/Ourls [2]: https://github.com/takashiki [3]: http://openresty.org/ [4]: https://github.com/leihog/hashids.lua - [5]: https://github.com/APItools/router.lua \ No newline at end of file + [5]: https://github.com/APItools/router.lua + [6]: https://github.com/golgote/neturl + [7]: https://github.com/mah0x211/lua-idna + [8]: https://github.com/hamishforbes/lua-resty-iputils \ No newline at end of file diff --git a/install.sh b/install.sh new file mode 100755 index 0000000..5d5a039 --- /dev/null +++ b/install.sh @@ -0,0 +1,7 @@ +#!/bin/sh +BASE=$(dirname "$(readlink -f "${0}")") # absolute directory of this script; quoted so paths with spaces work +cp "${BASE}/ourl/config.sample.lua" "${BASE}/ourl/config.lua" +cd "${BASE}/lib/hashids" && make +cd "${BASE}/lib/idna" && make linux +mv "${BASE}/lib/idna/idna.so" "${BASE}" # idna.so must sit in the lua_package_cpath root +cd "${BASE}" \ No newline at end of file diff --git a/lib/idna/Makefile b/lib/idna/Makefile new file mode 100644 index 0000000..77e9f55 --- /dev/null +++ 
b/lib/idna/Makefile @@ -0,0 +1,63 @@ +# This Makefile is based on lua-zlib's Makefile. Thanks to the lua-zlib developers. +# Inform the location to install the modules +PREFIX ?= /usr/local/openresty +LUAPATH ?= $(PREFIX)/lualib +LUACPATH ?= $(PREFIX)/lualib +INCDIR ?= -I$(PREFIX)/luajit/include/luajit-2.1 +LIBDIR ?= -L$(PREFIX)/luajit/lib + +# For Mac OS X: set the system version (consumed by MAC_ENV below) +MACOSX_VERSION = 10.4 + +CMOD = idna.so +OBJS = idna.o + +LIBS = -lidn -lluajit-5.1 +WARN = -Wall -pedantic + +BSD_CFLAGS = -O2 -fPIC $(WARN) $(INCDIR) $(DEFS) +BSD_LDFLAGS = -O -shared -fPIC $(LIBDIR) + +LNX_CFLAGS = -O2 -fPIC $(WARN) $(INCDIR) $(DEFS) +LNX_LDFLAGS = -O -shared -fPIC $(LIBDIR) + +MAC_ENV = env MACOSX_DEPLOYMENT_TARGET='$(MACOSX_VERSION)' +MAC_CFLAGS = -O2 -fPIC -fno-common $(WARN) $(INCDIR) $(DEFS) +MAC_LDFLAGS = -bundle -undefined dynamic_lookup -fPIC $(LIBDIR) + +CC = gcc +LD = $(MYENV) gcc +CFLAGS = $(MYCFLAGS) +LDFLAGS = $(MYLDFLAGS) + +.PHONY: all clean install uninstall none linux bsd macosx + +all: + @echo "Usage: $(MAKE) <platform>" + @echo " * linux" + @echo " * bsd" + @echo " * macosx" + +install: $(CMOD) + cp $(CMOD) $(LUACPATH) + +uninstall: + rm $(LUACPATH)/$(CMOD) + +linux: + @$(MAKE) $(CMOD) MYCFLAGS="$(LNX_CFLAGS)" MYLDFLAGS="$(LNX_LDFLAGS)" INCDIR="$(INCDIR)" LIBDIR="$(LIBDIR)" DEFS="$(DEFS)" + +bsd: + @$(MAKE) $(CMOD) MYCFLAGS="$(BSD_CFLAGS)" MYLDFLAGS="$(BSD_LDFLAGS)" INCDIR="$(INCDIR)" LIBDIR="$(LIBDIR)" DEFS="$(DEFS)" + +macosx: + @$(MAKE) $(CMOD) MYCFLAGS="$(MAC_CFLAGS)" MYLDFLAGS="$(MAC_LDFLAGS)" MYENV="$(MAC_ENV)" INCDIR="$(INCDIR)" LIBDIR="$(LIBDIR)" DEFS="$(DEFS)" + +clean: + rm -f $(OBJS) $(CMOD) + +.c.o: + $(CC) -c $(CFLAGS) $(DEFS) $(INCDIR) -o $@ $< + +$(CMOD): $(OBJS) + $(LD) $(LDFLAGS) $(LIBDIR) $(OBJS) $(LIBS) -o $@ diff --git a/lib/idna/idna.c b/lib/idna/idna.c new file mode 100644 index 0000000..66b85b5 --- /dev/null +++ b/lib/idna/idna.c @@ -0,0 +1,92 @@ +/* + * Copyright 2014 Masatoshi Teruya. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + * idna.c + * lua-idna + * + * Created by Masatoshi Teruya on 14/12/06. 
+ * + */ + +#include <stdlib.h> +#include <string.h> +#include <idna.h> +#include <lua.h> +#include <lauxlib.h> + +#define lstate_fn2tbl(L,k,v) do{ \ + lua_pushstring(L,k); \ + lua_pushcfunction(L,v); \ + lua_rawset(L,-3); \ +}while(0) + +#define pdealloc(p) free((void*)p) + +static int encode_lua( lua_State *L ) +{ + const char *src = luaL_checkstring( L, 1 ); + char *dest = NULL; + int rc = idna_to_ascii_8z( src, &dest, 0 ); + + if( rc == IDNA_SUCCESS ){ + lua_pushstring( L, dest ); + pdealloc( dest ); + return 1; + } + + // got error + lua_pushnil( L ); + lua_pushstring( L, idna_strerror( rc ) ); + + return 2; +} + + +static int decode_lua( lua_State *L ) +{ + const char *src = luaL_checkstring( L, 1 ); + char *dest = NULL; + int rc = idna_to_unicode_8z8z( src, &dest, 0 ); + + if( rc == IDNA_SUCCESS ){ + lua_pushstring( L, dest ); + pdealloc( dest ); + return 1; + } + + // got error + lua_pushnil( L ); + lua_pushstring( L, idna_strerror( rc ) ); + + return 2; +} + + +LUALIB_API int luaopen_idna( lua_State *L ) +{ + lua_createtable( L, 0, 2 ); + lstate_fn2tbl( L, "encode", encode_lua ); + lstate_fn2tbl( L, "decode", decode_lua ); + + return 1; +} + + diff --git a/lib/net/url.lua b/lib/net/url.lua new file mode 100644 index 0000000..78c66b4 --- /dev/null +++ b/lib/net/url.lua @@ -0,0 +1,443 @@ +-- neturl.lua - a robust url parser and builder +-- +-- Bertrand Mansion, 2011-2013; License MIT +-- @module neturl +-- @alias M + +local M = {} +M.version = "0.9.0" + +--- url options +-- separator is set to `&` by default but could be anything like `&amp;` or `;` +-- @todo Add an option to limit the size of the argument table +M.options = { + separator = '&' +} + +--- list of known and common scheme ports +-- as documented in IANA URI scheme list +M.services = { + acap = 674, + cap = 1026, + dict = 2628, + ftp = 21, + gopher = 70, + http = 80, + https = 443, + iax = 4569, + icap = 1344, + imap = 143, + ipp = 631, + ldap = 389, + mtqp = 1038, + mupdate = 3905, + news = 2009, + nfs = 2049, + nntp = 119, + rtsp = 554, + sip 
= 5060, + snmp = 161, + telnet = 23, + tftp = 69, + vemmi = 575, + afs = 1483, + jms = 5673, + rsync = 873, + prospero = 191, + videotex = 516 +} + +local legal = { + ["-"] = true, ["_"] = true, ["."] = true, ["!"] = true, + ["~"] = true, ["*"] = true, ["'"] = true, ["("] = true, + [")"] = true, [":"] = true, ["@"] = true, ["&"] = true, + ["="] = true, ["+"] = true, ["$"] = true, [","] = true, + [";"] = true -- can be used for parameters in path +} + +local function decode(str) + local str = str:gsub('+', ' ') + return (str:gsub("%%(%x%x)", function(c) + return string.char(tonumber(c, 16)) + end)) +end + +local function encode(str) + return (str:gsub("([^A-Za-z0-9%_%.%-%~])", function(v) + return string.upper(string.format("%%%02x", string.byte(v))) + end)) +end + +-- for query values, prefer + instead of %20 for spaces +local function encodeValue(str) + local str = encode(str) + return str:gsub('%%20', '+') +end + +local function encodeSegment(s) + local legalEncode = function(c) + if legal[c] then + return c + end + return encode(c) + end + return s:gsub('([^a-zA-Z0-9])', legalEncode) +end + +--- builds the url +-- @return a string representing the built url +function M:build() + local url = '' + if self.path then + local path = self.path + path = path:gsub("([^/]+)", function (s) return encodeSegment(s) end) -- strings are immutable: keep gsub's result so each path segment is actually percent-encoded + url = url .. tostring(path) + end + if self.query then + local qstring = tostring(self.query) + if qstring ~= "" then + url = url .. '?' .. qstring + end + end + if self.host then + local authority = self.host + if self.port and self.scheme and M.services[self.scheme] ~= self.port then + authority = authority .. ':' .. self.port + end + local userinfo + if self.user and self.user ~= "" then + userinfo = self.user + if self.password then + userinfo = userinfo .. ':' .. self.password + end + end + if userinfo and userinfo ~= "" then + authority = userinfo .. '@' .. authority + end + if authority then + if url ~= "" then + url = '//' .. authority .. '/' .. 
url:gsub('^/+', '') + else + url = '//' .. authority + end + end + end + if self.scheme then + url = self.scheme .. ':' .. url + end + if self.fragment then + url = url .. '#' .. self.fragment + end + return url +end + +--- builds the querystring +-- @param tab The key/value parameters +-- @param sep The separator to use (optional) +-- @param key The parent key if the value is multi-dimensional (optional) +-- @return a string representing the built querystring +function M.buildQuery(tab, sep, key) + local query = {} + if not sep then + sep = M.options.separator or '&' + end + local keys = {} + for k in pairs(tab) do + keys[#keys+1] = k + end + table.sort(keys) + for _,name in ipairs(keys) do + local value = tab[name] + name = encode(tostring(name)) + if key then + name = string.format('%s[%s]', tostring(key), tostring(name)) + end + if type(value) == 'table' then + query[#query+1] = M.buildQuery(value, sep, name) + else + local value = encodeValue(tostring(value)) + if value ~= "" then + query[#query+1] = string.format('%s=%s', name, value) + else + query[#query+1] = name + end + end + end + return table.concat(query, sep) +end + +--- Parses the querystring to a table +-- This function can parse multidimensional pairs and is mostly compatible +-- with PHP usage of brackets in key names like ?param[key]=value +-- @param str The querystring to parse +-- @param sep The separator between key/value pairs, defaults to `&` +-- @todo limit the max number of parameters with M.options.max_parameters +-- @return a table representing the query key/value pairs +function M.parseQuery(str, sep) + if not sep then + sep = M.options.separator or '&' + end + + local values = {} + for key,val in str:gmatch(string.format('([^%q=]+)(=*[^%q=]*)', sep, sep)) do + local key = decode(key) + local keys = {} + key = key:gsub('%[([^%]]*)%]', function(v) + -- extract keys between balanced brackets + if string.find(v, "^-?%d+$") then + v = tonumber(v) + else + v = decode(v) + end + 
table.insert(keys, v) + return "=" + end) + key = key:gsub('=+.*$', "") + key = key:gsub('%s', "_") -- remove spaces in parameter name + val = val:gsub('^=+', "") + + if not values[key] then + values[key] = {} + end + if #keys > 0 and type(values[key]) ~= 'table' then + values[key] = {} + elseif #keys == 0 and type(values[key]) == 'table' then + values[key] = decode(val) + end + + local t = values[key] + for i,k in ipairs(keys) do + if type(t) ~= 'table' then + t = {} + end + if k == "" then + k = #t+1 + end + if not t[k] then + t[k] = {} + end + if i == #keys then + t[k] = decode(val) + end + t = t[k] + end + end + setmetatable(values, { __tostring = M.buildQuery }) + return values +end + +--- set the url query +-- @param query Can be a string to parse or a table of key/value pairs +-- @return a table representing the query key/value pairs +function M:setQuery(query) + local query = query + if type(query) == 'table' then + query = M.buildQuery(query) + end + self.query = M.parseQuery(query) + return query +end + +--- set the authority part of the url +-- The authority is parsed to find the user, password, port and host if available. 
+-- @param authority The string representing the authority +-- @return a string with what remains after the authority was parsed +function M:setAuthority(authority) + self.authority = authority + self.port = nil + self.host = nil + self.userinfo = nil + self.user = nil + self.password = nil + + authority = authority:gsub('^([^@]*)@', function(v) + self.userinfo = v + return '' + end) + authority = authority:gsub("^%[[^%]]+%]", function(v) + -- ipv6 + self.host = v + return '' + end) + authority = authority:gsub(':([^:]*)$', function(v) + self.port = tonumber(v) + return '' + end) + if authority ~= '' and not self.host then + self.host = authority:lower() + end + if self.userinfo then + local userinfo = self.userinfo + userinfo = userinfo:gsub(':([^:]*)$', function(v) + self.password = v + return '' + end) + self.user = userinfo + end + return authority +end + +--- Parse the url into the designated parts. +-- Depending on the url, the following parts can be available: +-- scheme, userinfo, user, password, authority, host, port, path, +-- query, fragment +-- @param url Url string +-- @return a table with the different parts and a few other functions +function M.parse(url) + local comp = {} + M.setAuthority(comp, "") + M.setQuery(comp, "") + + local url = tostring(url or '') + url = url:gsub('#(.*)$', function(v) + comp.fragment = v + return '' + end) + url =url:gsub('^([%w][%w%+%-%.]*)%:', function(v) + comp.scheme = v:lower() + return '' + end) + url = url:gsub('%?(.*)', function(v) + M.setQuery(comp, v) + return '' + end) + url = url:gsub('^//([^/]*)', function(v) + M.setAuthority(comp, v) + return '' + end) + comp.path = decode(url) + + setmetatable(comp, { + __index = M, + __tostring = M.build} + ) + return comp +end + +--- removes dots and slashes in urls when possible +-- This function will also remove multiple slashes +-- @param path The string representing the path to clean +-- @return a string of the path without unnecessary dots and segments +function 
M.removeDotSegments(path) + local fields = {} + if string.len(path) == 0 then + return "" + end + local startslash = false + local endslash = false + if string.sub(path, 1, 1) == "/" then + startslash = true + end + if (string.len(path) > 1 or startslash == false) and string.sub(path, -1) == "/" then + endslash = true + end + + path:gsub('[^/]+', function(c) table.insert(fields, c) end) + + local new = {} + local j = 0 + + for i,c in ipairs(fields) do + if c == '..' then + if j > 0 then + j = j - 1 + end + elseif c ~= "." then + j = j + 1 + new[j] = c + end + end + local ret = "" + if #new > 0 and j > 0 then + ret = table.concat(new, '/', 1, j) + else + ret = "" + end + if startslash then + ret = '/'..ret + end + if endslash then + ret = ret..'/' + end + return ret +end + +local function absolutePath(base_path, relative_path) + if string.sub(relative_path, 1, 1) == "/" then + return '/' .. string.gsub(relative_path, '^[%./]+', '') + end + local path = base_path + if relative_path ~= "" then + path = '/'..path:gsub("[^/]*$", "") + end + path = path .. relative_path + path = path:gsub("([^/]*%./)", function (s) + if s ~= "./" then return s else return "" end + end) + path = string.gsub(path, "/%.$", "/") + local reduced + while reduced ~= path do + reduced = path + path = string.gsub(reduced, "([^/]*/%.%./)", function (s) + if s ~= "../../" then return "" else return s end + end) + end + path = string.gsub(path, "([^/]*/%.%.?)$", function (s) + if s ~= "../.." then return "" else return s end + end) + local reduced + while reduced ~= path do + reduced = path + path = string.gsub(reduced, '^/?%.%./', '') + end + return '/' .. 
path +end + +--- builds a new url by using the one given as parameter and resolving paths +-- @param other A string or a table representing a url +-- @return a new url table +function M:resolve(other) + if type(self) == "string" then + self = M.parse(self) + end + if type(other) == "string" then + other = M.parse(other) + end + if other.scheme then + return other + else + other.scheme = self.scheme + if not other.authority or other.authority == "" then + other:setAuthority(self.authority) + if not other.path or other.path == "" then + other.path = self.path + local query = other.query + if not query or not next(query) then + other.query = self.query + end + else + other.path = absolutePath(self.path, other.path) + end + end + return other + end +end + +--- normalize a url path following some common normalization rules +-- described on The URL normalization page of Wikipedia +-- @return the normalized path +function M:normalize() + if type(self) == 'string' then + self = M.parse(self) + end + if self.path then + local path = self.path + path = absolutePath(path, "") + -- normalize multiple slashes + path = string.gsub(path, "//+", "/") + self.path = path + end + return self +end + +return M \ No newline at end of file diff --git a/lib/resty/iputils.lua b/lib/resty/iputils.lua new file mode 100644 index 0000000..ef71097 --- /dev/null +++ b/lib/resty/iputils.lua @@ -0,0 +1,207 @@ +local ipairs, tonumber, tostring, type = ipairs, tonumber, tostring, type +local bit = require("bit") +local tobit = bit.tobit +local lshift = bit.lshift +local band = bit.band +local bor = bit.bor +local xor = bit.bxor +local byte = string.byte +local str_find = string.find +local str_sub = string.sub + +local lrucache = nil + +local _M = { + _VERSION = '0.2.1', +} + +local mt = { __index = _M } + + +-- Precompute binary subnet masks... +local bin_masks = {} +for i=1,32 do + bin_masks[tostring(i)] = lshift(tobit((2^i)-1), 32-i) +end +-- ... 
and their inverted counterparts +local bin_inverted_masks = {} +for i=1,32 do + local i = tostring(i) + bin_inverted_masks[i] = xor(bin_masks[i], bin_masks["32"]) +end + +local log_err +if ngx then + log_err = function(...) + ngx.log(ngx.ERR, ...) + end +else + log_err = function(...) + print(...) + end +end + + +local function enable_lrucache(size) + local size = size or 4000 -- Cache the last 4000 IPs (~1MB memory) by default + local lrucache_obj, err = require("resty.lrucache").new(size) + if not lrucache_obj then + return nil, "failed to create the cache: " .. (err or "unknown") + end + lrucache = lrucache_obj + return true +end +_M.enable_lrucache = enable_lrucache + + +local function split_octets(input) + local pos = 0 + local prev = 0 + local octs = {} + + for i=1, 4 do + pos = str_find(input, ".", prev, true) + if pos then + if i == 4 then + -- Should not have a match after 4 octets + return nil, "Invalid IP" + end + octs[i] = str_sub(input, prev, pos-1) + elseif i == 4 then + -- Last octet, get everything to the end + octs[i] = str_sub(input, prev, -1) + break + else + return nil, "Invalid IP" + end + prev = pos +1 + end + + return octs +end + + +local function ip2bin(ip) + if lrucache then + local get = lrucache:get(ip) + if get then + return get[1], get[2] + end + end + + if type(ip) ~= "string" then + return nil, "IP must be a string" + end + + local octets = split_octets(ip) + if not octets or #octets ~= 4 then + return nil, "Invalid IP" + end + + -- Return the binary representation of an IP and a table of binary octets + local bin_octets = {} + local bin_ip = 0 + + for i,octet in ipairs(octets) do + local bin_octet = tonumber(octet) + if not bin_octet or bin_octet < 0 or bin_octet > 255 then -- each octet must lie in 0..255; a negative one would corrupt the packed address + return nil, "Invalid octet: "..tostring(octet) + end + bin_octet = tobit(bin_octet) + bin_octets[i] = bin_octet + bin_ip = bor(lshift(bin_octet, 8*(4-i) ), bin_ip) + end + + if lrucache then + lrucache:set(ip, {bin_ip, bin_octets}) + end + return bin_ip, bin_octets +end 
+_M.ip2bin = ip2bin + + +local function split_cidr(input) + local pos = str_find(input, "/", 0, true) + if not pos then + return {input} + end + return {str_sub(input, 1, pos-1), str_sub(input, pos+1, -1)} +end + + +local function parse_cidr(cidr) + local mask_split = split_cidr(cidr, '/') + local net = mask_split[1] + local mask = mask_split[2] or "32" + local mask_num = tonumber(mask) + if not mask_num or mask_num % 1 ~= 0 or (mask_num > 32 or mask_num < 1) then -- prefix must be a whole number in 1..32 + return nil, "Invalid prefix: /"..tostring(mask) + end + + local bin_net, err = ip2bin(net) -- Convert IP to binary + if not bin_net then + return nil, err + end + local bin_mask = bin_masks[tostring(mask_num)] -- Get masks; key by the canonical number so "/08" or "/8.0" cannot index nil + local bin_inv_mask = bin_inverted_masks[tostring(mask_num)] + + local lower = band(bin_net, bin_mask) -- Network address + local upper = bor(lower, bin_inv_mask) -- Broadcast address + return lower, upper +end +_M.parse_cidr = parse_cidr + + +local function parse_cidrs(cidrs) + local out = {} + local i = 1 + for _,cidr in ipairs(cidrs) do + local lower, upper = parse_cidr(cidr) + if not lower then + log_err("Error parsing '", cidr, "': ", upper) + else + out[i] = {lower, upper} + i = i+1 + end + end + return out +end +_M.parse_cidrs = parse_cidrs + + +local function ip_in_cidrs(ip, cidrs) + local bin_ip, bin_octets = ip2bin(ip) + if not bin_ip then + return nil, bin_octets + end + + for _,cidr in ipairs(cidrs) do + if bin_ip >= cidr[1] and bin_ip <= cidr[2] then + return true + end + end + return false +end +_M.ip_in_cidrs = ip_in_cidrs + + +local function binip_in_cidrs(bin_ip_ngx, cidrs) + if 4 ~= #bin_ip_ngx then + return false, "invalid IP address" + end + + local bin_ip = 0 + for i=1,4 do + bin_ip = bor(lshift(bin_ip, 8), tobit(byte(bin_ip_ngx, i))) + end + + for _,cidr in ipairs(cidrs) do + if bin_ip >= cidr[1] and bin_ip <= cidr[2] then + return true + end + end + return false +end +_M.binip_in_cidrs = binip_in_cidrs + +return _M diff --git a/ourl/config.sample.lua b/ourl/config.sample.lua index 3c66e94..777c4a2 100644 --- 
a/ourl/config.sample.lua +++ b/ourl/config.sample.lua @@ -26,6 +26,13 @@ _M.hash.salt = 'ourl' _M.hash.length = 5 _M.hash.alphabet = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890' +_M.proxies = { + '127.0.0.0/8', + '10.0.0.0/8', + '172.16.0.0/12', + '192.168.0.0/16', +} + _M.db.timeout = 1000 _M.db.keepalive = 10000 _M.db.poolsize = 10 diff --git a/ourl/init.lua b/ourl/init.lua index a117852..2325c3a 100644 --- a/ourl/init.lua +++ b/ourl/init.lua @@ -1,6 +1,10 @@ local config = require 'ourl.config' local router = require 'lib.router' local hashid = require 'lib.hashids.init' +local neturl = require 'lib.net.url' +-- idna.so 必须在 lua_package_cpath 根目录下才能使用 +local idna = require 'idna' +local iputil = require 'lib.resty.iputils' local json = require 'cjson' local mysql = require 'resty.mysql' local r_sha1 = require 'resty.sha1' @@ -13,10 +17,10 @@ end local STATUS_ERR = 0 local STATUS_OK = 1 -local r, h, db_rw, db_ro +local r, h, db_rw, db_ro, base, proxy_whitelist local function log(...) - ngx.log(config.log_level, ...) + ngx.log(config.log_level, json.encode({...})) end local function finish() @@ -48,16 +52,16 @@ end local function db_query(db, query) local ok, err = db:send_query(query) if not ok then - log('failed to send query: ', query, ' : ', err) + log('failed to send query', query, err) die('数据库错误') end local res, err, errcode, sqlstate = db:read_result() if not res then - log('failed to read result of query: ', query, errcode, ': ', err, sqlstate) + log('failed to read result of query', query, errcode, err, sqlstate) die('数据库错误') elseif config.debug then - log('[DEBUG]: ' .. 
json.encode({query, res, errcode, err, sqlstate})) + log('[DEBUG]', query, res, errcode, err, sqlstate) end return res, err, errcode, sqlstate @@ -65,35 +69,68 @@ end local function ip2long(ip) local l = 0 - for v in ip:gmatch([=[[^\.]+]=]) do - l = l * 256 + v + for v in ngx.re.gmatch(ip, [=[[^\.]+]=], 'o') do -- 这条注释是为了修复文本编辑器对 lua 语法的 bug ,如果你看到了,说明这人忘记删了]] + l = l * 256 + tonumber(v[0]) end return l end local function real_remote_addr() - --Todo: 判断代理 - return ip2long(ngx.var.remote_addr) + local ip = ngx.var.remote_addr + local proxy = ngx.req.get_headers()['X-Forwarded-For'] + if proxy then + if 'table' == type(proxy) then + proxy = proxy[1] + end + local pattern = [=[([0-9]{1,3}\.){3}[0-9]{1,3}]=] + local m = ngx.re.match(proxy, pattern, 'o') + if m then + proxy = m[0] + if iputil.ip_in_cidrs(ip, proxy_whitelist) then + ip = proxy + end + end + end + return ip2long(ip) end -local function test(params) - for k, v in pairs(params) do - if '' == type(v) then - ngx.say(k, ': ', table.concat(v, ', ')) - else - ngx.say(k, ': ', v) +local function url_modify(url, scheme) + scheme = scheme or 'http' + url = neturl.parse(url) + if not url.host then + return nil + end + local ok, err = idna.encode(url.host) + if ok then + url.host = ok + else + if config.debug then + log('[DEBUG]', url.host, err) end + die('非法域名') end + if not url.scheme then + url.scheme = scheme + end + return tostring(url:normalize()) end local function shorten(params) local url = params.url if not url then - die('请传入正确的 url') + die('请输入正确的 url') end if 'table' == type(url) then url = url[1] end + url = url_modify(url) + if not url then + die('请输入正确的 url') + end + local pattern = ("^https?://%s/"):format(ngx.var.host) + if ngx.re.match(url, pattern, 'o') then + die('该地址不能被缩短') + end local sha1 = r_sha1:new() sha1:update(url) local digest = r_str.to_hex(sha1:final()) @@ -120,7 +157,7 @@ local function shorten(params) local s = h:encode(id) json_api({ status = STATUS_OK, - s_url = ngx.var.scheme .. 
[[://]] .. ngx.var.host .. [[/]] .. s + s_url = base .. s }) end @@ -132,7 +169,11 @@ local function expand(params) if 'table' == type(s) then s = s[1] end - local m = ngx.re.match(s, [[^https?://]] .. ngx.var.host .. [=[/([]=] .. config.hash.alphabet .. [=[]{]=] .. config.hash.length .. [=[})]=]) + local pattern = ("^https?://%s/([%s]+)$"):format( + ngx.var.host, + config.hash.alphabet + ) + local m = ngx.re.match(s, pattern, 'o') if m then s = m[1] local id = h:decode(s) @@ -185,19 +226,19 @@ local function prepare() local err db_rw, err = mysql:new() if not db_rw then - log('failed to init mysql master: ', err) + log('failed to init mysql master', err) die('数据库错误') end db_ro, err = mysql:new() if not db_ro then - log('failed to init mysql slave: ', err) + log('failed to init mysql slave', err) die('数据库错误') end for name, db in pairs({db_rw = db_rw, db_ro = db_ro}) do db:set_timeout(config.db.timeout) local ok, err, errcode, sqlstate = db:connect(config[name]) if not ok then - log('failed to connect to mysql: ', name, err, errcode, sqlstate) + log('failed to connect to mysql', name, errcode, err, sqlstate) die('数据库错误') end local count @@ -205,14 +246,13 @@ local function prepare() if 0 == count then db_query(db, [[SET NAMES 'utf8';]]) elseif err then - log('failed to get reused times: ', name, err) + log('failed to get reused times', name, err) die('数据库错误') end end if not r then r = router.new() - r:get('/test', test) r:get('/shorten', shorten) r:get('/expand', expand) r:get('/:hash', redirect) @@ -220,6 +260,12 @@ local function prepare() if not h then h = hashid.new(config.hash.salt, config.hash.length, config.hash.alphabet) end + if not base then + base = ("%s://%s/"):format(ngx.var.scheme, ngx.var.host) + end + if not proxy_whitelist then + proxy_whitelist = iputil.parse_cidrs(config.proxies) + end ngx.req.read_body() end @@ -233,7 +279,7 @@ function _M.run() ) if not ok then if config.debug then - die('服务器错误' .. err) + die('服务器错误: ' .. 
err) else finish() ngx.status = ngx.HTTP_NOT_FOUND