Skip to content

Commit

Permalink
Add support for authenticated media (#290)
Browse files Browse the repository at this point in the history
Setup instructions:

1. Set up a reverse proxy to pass `/_heisenbridge/media/*` to heisenbridge
2. Configure `heisenbridge` -> `media_url` in the registration file with the public URL that the reverse proxy handles

Optionally, you can run another heisenbridge instance with the `--media-proxy` flag to have it in a separate process
  • Loading branch information
tulir authored Aug 9, 2024
1 parent 498d79c commit b4da6e5
Showing 1 changed file with 89 additions and 22 deletions.
111 changes: 89 additions & 22 deletions heisenbridge/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import argparse
import asyncio
import base64
import grp
import hashlib
import hmac
import logging
import os
import pwd
Expand All @@ -14,6 +17,7 @@
from typing import List
from typing import Tuple

from aiohttp import web
from mautrix.api import HTTPAPI
from mautrix.api import Method
from mautrix.api import Path
Expand Down Expand Up @@ -59,7 +63,7 @@ class BridgeAppService(AppService):
_rooms: Dict[str, Room]
_users: Dict[str, str]

DEFAULT_MEDIA_PATH = "/_matrix/media/v3/download/{netloc}{path}{filename}"
DEFAULT_MEDIA_PATH = "/_heisenbridge/media/{server}/{media_id}/{checksum}{filename}"

async def push_bridge_state(
self,
Expand Down Expand Up @@ -332,17 +336,70 @@ async def detect_public_endpoint(self):
logging.warning("Using internal URL for homeserver, media links are likely broken!")
return str(self.api.base_url)

def mxc_to_url(self, mxc, filename=None):
mxc = urllib.parse.urlparse(mxc)
def mxc_checksum(self, server: str, media_id: str) -> str:
# Add trailing slash to prevent length extension attacks
checksum_raw = hmac.new(self.media_key, f"mxc://{server}/{media_id}/".encode("utf-8"), hashlib.sha256).digest()
return base64.urlsafe_b64encode(checksum_raw[:8]).decode("utf-8").rstrip("=")

async def proxy_media(self, req: web.Request) -> web.StreamResponse | web.Response:
server = req.match_info["server"]
media_id = req.match_info["media_id"]
checksum = req.match_info["checksum"]
if self.mxc_checksum(server, media_id) != checksum:
return web.Response(status=403, text="Invalid checksum")
download_url = self.api.base_url / "_matrix/client/v1/media/download" / server / media_id
filename = req.match_info.get("filename", "")
if filename:
download_url /= filename
query_params: dict[str, str] = {"allow_redirect": "true", "user_id": self.az.bot_mxid}
headers: dict[str, str] = {"Authorization": f"Bearer {self.az.as_token}"}
resp_headers = {
"Content-Security-Policy": (
"sandbox; default-src 'none'; script-src 'none'; style-src 'none'; object-src 'none';"
),
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, HEAD, OPTIONS",
"Content-Disposition": "attachment",
}
started_writing = False
try:
async with self.api.session.get(download_url, params=query_params, headers=headers) as dl_resp:
resp = web.StreamResponse(status=dl_resp.status, headers=resp_headers)
if dl_resp.content_length:
resp.content_length = dl_resp.content_length
resp.content_type = dl_resp.content_type
if "Content-Disposition" in dl_resp.headers:
resp.headers["Content-Disposition"] = dl_resp.headers["Content-Disposition"]
elif resp.status >= 300:
del resp.headers["Content-Disposition"]
started_writing = True
await resp.prepare(req)
async for chunk, end_of_chunk in dl_resp.content.iter_chunks():
await resp.write(chunk)
return resp
except Exception:
if not started_writing:
logging.exception("Failed to fetch media")
return web.Response(status=502, text="Failed to fetch media")

def mxc_to_url(self, mxc: str, filename=None):
if not self.media_endpoint:
return "<media unavailable>"
try:
server, media_id = self.api.parse_mxc_uri(mxc)
except ValueError:
return "<invalid mxc URI>"

if filename is None:
filename = ""
else:
filename = "/" + urllib.parse.quote(filename)

media_path = self.media_path.format(netloc=mxc.netloc, path=mxc.path, filename=filename)
media_path = self.media_path.format(
server=server, media_id=media_id, checksum=self.mxc_checksum(server, media_id), filename=filename
)

return "{}{}".format(self.endpoint, media_path)
return "{}{}".format(self.media_endpoint, media_path)

async def reset(self, config_file, homeserver_url):
with open(config_file) as f:
Expand Down Expand Up @@ -448,7 +505,7 @@ async def ensure_hidden_room(self):

return use_hidden_room

async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mode):
async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mode, media_proxy):
if "sender_localpart" not in self.registration:
print("Missing sender_localpart from registration file.")
sys.exit(1)
Expand Down Expand Up @@ -485,6 +542,8 @@ async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mod
print(f"Heisenbridge v{__version__}", flush=True)
if safe_mode:
print("Safe mode is enabled.", flush=True)
if media_proxy:
print("Media proxy only mode.", flush=True)

url = urllib.parse.urlparse(homeserver_url)
ws = None
Expand Down Expand Up @@ -542,6 +601,8 @@ async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mod
state_store=MemoryBridgeStateStore(),
)
self.az.matrix_event_handler(self._on_mx_event)
self.az.app.router.add_get("/_heisenbridge/media/{server}/{media_id}/{checksum}/{filename}", self.proxy_media)
self.az.app.router.add_get("/_heisenbridge/media/{server}/{media_id}/{checksum}", self.proxy_media)

try:
await self.az.start(host=listen_address, port=listen_port)
Expand Down Expand Up @@ -578,6 +639,7 @@ async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mod
"use_reacts": True,
"media_url": None,
"media_path": None,
"media_key": None,
"namespace": self.puppet_prefix,
}
logging.debug(f"Default config: {self.config}")
Expand All @@ -594,27 +656,21 @@ async def run(self, listen_address, listen_port, homeserver_url, owner, safe_mod
# load config from HS
await self.load()

async def _resolve_media_endpoint():
endpoint = await self.detect_public_endpoint()

# only rewrite it if it wasn't changed
if self.endpoint == str(self.api.base_url):
self.endpoint = endpoint

print("Homeserver is publicly available at " + self.endpoint, flush=True)
if "heisenbridge" in self.registration and "media_key" in self.registration["heisenbridge"]:
self.media_key = self.registration["heisenbridge"]["media_key"].encode("utf-8")
elif self.config["media_key"]:
self.media_key = self.config["media_key"].encode("utf-8")
else:
self.media_key = self.registration["hs_token"].encode("utf-8")

# use configured media_url for endpoint if we have it
if "heisenbridge" in self.registration and "media_url" in self.registration["heisenbridge"]:
logging.debug(
f"Overriding media URL from registration file to {self.registration['heisenbridge']['media_url']}"
)
self.endpoint = self.registration["heisenbridge"]["media_url"]
self.media_endpoint = self.registration["heisenbridge"]["media_url"]
elif self.config["media_url"]:
self.endpoint = self.config["media_url"]
else:
print("Trying to detect homeserver public endpoint, this might take a while...", flush=True)
self.endpoint = str(self.api.base_url)
asyncio.ensure_future(_resolve_media_endpoint())
self.media_endpoint = self.config["media_url"]

# use configured media_path for media_path if we have it
if "heisenbridge" in self.registration and "media_path" in self.registration["heisenbridge"]:
Expand All @@ -627,6 +683,11 @@ async def _resolve_media_endpoint():
else:
self.media_path = self.DEFAULT_MEDIA_PATH

if media_proxy:
logging.info("Media proxy mode startup complete")
await asyncio.Event().wait()
return

logging.info("Starting presence loop")
self._keepalive()

Expand Down Expand Up @@ -854,6 +915,12 @@ async def async_main():
help="reset ALL bridge configuration from homeserver and exit",
default=argparse.SUPPRESS,
)
parser.add_argument(
"--media-proxy",
action="store_true",
help="run in media proxy mode",
default=False,
)
parser.add_argument(
"--safe-mode",
action="store_true",
Expand Down Expand Up @@ -924,7 +991,7 @@ async def async_main():

service.load_reg(args.config)

if args.identd:
if args.identd and not args.media_proxy:
identd = Identd()
await identd.start_listening(service, args.identd_port)

Expand Down Expand Up @@ -963,7 +1030,7 @@ async def async_main():
except Exception:
pass

await service.run(listen_address, listen_port, args.homeserver, args.owner, args.safe_mode)
await service.run(listen_address, listen_port, args.homeserver, args.owner, args.safe_mode, args.media_proxy)


def main():
Expand Down

0 comments on commit b4da6e5

Please sign in to comment.