#! /usr/bin/env python
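#caching proxy for http and ftp: clients requesting the same resource at the
#same time share one upstream download through a common Cache entry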
import sys
assert sys.version_info >= (3, 6) #notably: f-strings, asyncio, aiohttp
import asyncio, hashlib, logging, os, weakref
from ipaddress import ip_address
from aiohttp import web
from replicator.Params import OPTS
from replicator.Cache import Cache
from replicator.FtpProtocol import FtpProtocol
from replicator.HttpProtocol import HttpProtocol, blind_transfer
from replicator.Utils import daemonize, header_summary
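
#active downloads keyed by cacheid; weak values let a Cache object disappear
#as soon as its last writer/reader task finishes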
DOWNLOADS = weakref.WeakValueDictionary()

class InboundRequest:

    def __init__(self, request):
        #extract the "interesting" parts of the request, and decide which proto will handle request
        self.method, self.content = request.method.upper(), request.content
        self.url = request.url
        self.path = request.headers.get('x-unique-cache-name', request.url.path)
        if self.url.scheme == 'http':
            self.host, self.port = self.url.host, self.url.port
            self.proto = HttpProtocol if self.method == 'GET' else None
        elif self.url.scheme == '': #assumed to be a transparent http proxy request
            self.host, self.port = request.headers['host'], 80
            if ':' in self.host:
                self.host, self.port = self.host.rsplit(':', 1) #maxsplit=1: a bare rsplit fails to unpack when the host part contains colons
            self.proto = HttpProtocol if self.method == 'GET' else None
        elif self.url.scheme == 'ftp':
            assert self.method == 'GET', f'{self.method} request unsupported for ftp'
            self.port = self.url.port or 21
            self.proto, self.host = FtpProtocol, self.url.host
        else:
            raise AssertionError(f'invalid url: {self.url}')
        assert self.host.find('/') == -1, f'Request for invalid host name: {self.host}'
        assert 0 < int(self.port) < 65536, f'Request for invalid port number: {self.port}'
        self.cacheid = f'{self.host}:{self.port}{self._normalize_path(self.path)}'
        if OPTS.flat:
            self.cacheid = os.path.basename(self.cacheid)
        self.range = request.http_range or slice(0, None) #fallback as a slice, not a tuple, since .start/.stop are read below
        self.headers = request.headers.copy()
        self.headers.update({'host': self.host})
        for k in 'proxy-connection', 'proxy-authorization', 'keep-alive':
            self.headers.pop(k, None) #remove headers we don't want to propagate upstream
        self.header_summary = header_summary(request.headers, heading='Request headers:')
        logging.debug('%s', self.header_summary)
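
    #map a request url onto a safe cache-file path: escape '/' inside the query
    #string, strip gentoo's '.__download__' suffix, cap overlong path components
    #(OPTS.maxfilelen) with a hash, and apply OPTS.aliasmap prefix rewrites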
    def _normalize_path(self, origpath):
        path = os.sep + origpath
        sep = path.find('?')
        if sep != -1:
            path = path[:sep] + path[sep:].replace('/', '%2F')
        if path[-13:] == '.__download__': #gentoo adds this to x-unique-cache-name
            path = path[:-13]
        path = os.path.normpath(path)
        maxlen = OPTS.maxfilelen
        if 0 < maxlen:
            path_parts = []
            idxlimit = max(30, maxlen - 42) #2 for '..' + 40 for hexdigest
            for item in path.split(os.sep):
                if maxlen < len(item):
                    itemhash = hashlib.shake_128(item.encode())
                    item = item[:idxlimit] + '..' + itemhash.hexdigest(20)
                path_parts.append(item)
            newpath = os.sep.join(path_parts)
            if newpath != path:
                logging.info('Shortened path to %s characters',
                             '/'.join(str(len(w)) for w in path_parts))
                path = newpath
        for pair in OPTS.aliasmap:
            if os.path.commonprefix((path, pair[0])) == pair[0]:
                path = pair[1] + path[len(pair[0]):]
                break #stopping after first match seems the least surprising way to do this
        return path
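
#true when ip falls inside any of the given networks (cidr_list holds ipaddress network objects)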
def allowed_remote(ip, cidr_list):
    ip = ip_address(ip)
    for cidr in cidr_list:
        if ip in cidr:
            return True
    return False
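
#per-request handler: reject peers outside the allowed networks, relay requests
#the cache cannot handle, and otherwise serve through the shared download cache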
async def serve_request(downstream): #downstream is aiohttp.web.[Base]Request
    rhost, rport = downstream.transport.get_extra_info('peername')[:2]
    logging.debug('')
    if OPTS.allowed_CIDRs and not allowed_remote(rhost, OPTS.allowed_CIDRs):
        logging.info('Rejecting request from [%s]:%d due to --ip restriction', rhost, rport)
        response = web.Response(status=403)
        await response.prepare(downstream)
        await response.write(f'access from {rhost} is prohibited'.encode())
        return
    logging.info('Accepted request from [%s]:%d for %s', rhost, rport, downstream.url.human_repr())
    inquest = InboundRequest(downstream)
    try:
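        #no cache protocol for this request (e.g. a non-GET http method): relay it upstream verbatim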
        if inquest.proto is None:
            await blind_transfer(inquest, web.StreamResponse(), downstream)
            return
        futs = []
        cache = DOWNLOADS.get(inquest.cacheid, None)
        if cache: #re-use cache entry of an active download, if available
            logging.debug('Joined running download')
        else:
            cache = DOWNLOADS[inquest.cacheid] = Cache(inquest.cacheid)
            futs.append(asyncio.ensure_future(cache.writer(inquest.proto(inquest))))
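        #each client gets its own reader for its requested byte range; the writer
        #scheduled above (or one started by an earlier request) fills the cache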
        outresp = web.StreamResponse()
        rtask = cache.reader(outresp, downstream, inquest.range.start, inquest.range.stop)
        futs.append(asyncio.ensure_future(rtask))
        await asyncio.gather(*futs)
    except Exception as msg:
        show_backtrace = not isinstance(msg, AssertionError)
        logging.warning('Error: %s', msg, exc_info=show_backtrace)
        logging.warning('%s', inquest.header_summary)
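        #drain whatever remains of the request body so the error response below can be sent cleanly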
        blen, chunk = 0, True
        while chunk:
            chunk = await downstream.content.read(OPTS.maxchunk)
            blen += len(chunk)
        if blen:
            logging.warning('+ Body of %d bytes', blen)
        response = web.Response(content_type='text/plain')
        response.set_status(500, 'Internal Server Error')
        await response.prepare(downstream)
        await response.write(inquest.header_summary.encode())
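
#bind one listening site per address in OPTS.bind; runs before daemonize() so
#startup errors are still reported on the controlling terminal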
async def startup():
    runner = web.ServerRunner(web.Server(serve_request), access_log=None)
    await runner.setup()
    try:
        for addr in OPTS.bind:
            site = web.TCPSite(runner, addr, OPTS.port)
            await site.start() #await each bind so failures land in the except below, before daemonize()
    except Exception as e:
        sys.exit(f'error: failed to create socket: {e}')

# main - setup aiohttp and run its event loop:
loop = asyncio.get_event_loop()
loop.run_until_complete(startup())
daemonize() #note that this has been deferred until all startup has completed successfully
try:
    logging.info('Replicator started at %s, port %d', OPTS.bind, OPTS.port)
    loop.run_forever()
except KeyboardInterrupt:
    logging.info('Replicator terminated')
    sys.exit(0)
except Exception:
    logging.exception('Replicator crashed')
    sys.exit('Replicator crashed')