Skip to content

Commit

Permalink
Improve reload experience when using with site generator such as cargo
Browse files Browse the repository at this point in the history
doc
  • Loading branch information
Pistonight committed Jun 4, 2024
1 parent bcda1df commit f6e77f8
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 36 deletions.
14 changes: 14 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use watcher::create_watcher;

static ADDR: OnceCell<String> = OnceCell::const_new();
static ROOT: OnceCell<PathBuf> = OnceCell::const_new();
static HARD: OnceCell<bool> = OnceCell::const_new();
static TX: OnceCell<broadcast::Sender<()>> = OnceCell::const_new();

pub struct Listener {
Expand Down Expand Up @@ -120,3 +121,16 @@ pub async fn listen<A: Into<String>, R: Into<PathBuf>>(
rx,
})
}

/// Configure live-server to always hard reload the page instead of hot-reload
/// ```
/// use live_server::{listen, hard_reload};
///
/// async fn serve_hard() -> Result<(), Box<dyn std::error::Error>> {
/// hard_reload();
/// listen("127.0.0.1:8080", "./").await?.start().await
/// }
/// ```
pub fn hard_reload() {
let _ = HARD.set(true);
}
12 changes: 11 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use clap::Parser;
use env_logger::Env;
use live_server::listen;
use live_server::{listen, hard_reload};

/// Launch a local network server with live reload feature for static pages.
#[derive(Parser)]
Expand All @@ -18,6 +18,11 @@ struct Args {
/// Open the page in browser automatically
#[clap(short, long)]
open: bool,
/// Hard reload the page on update instead of hot reload
///
/// Try using this if the reload is not working as expected
#[clap(long)]
hard: bool,
}

#[tokio::main]
Expand All @@ -30,6 +35,7 @@ async fn main() {
port,
root,
open,
hard,
} = Args::parse();

let addr = format!("{}:{}", host, port);
Expand All @@ -40,5 +46,9 @@ async fn main() {
open::that(link).unwrap();
}

if hard {
hard_reload();
}

listener.start().await.unwrap();
}
90 changes: 66 additions & 24 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io::ErrorKind;
use std::{fs, net::IpAddr};

use axum::extract::ws::WebSocket;
use axum::{
body::Body,
extract::{ws::Message, Request, WebSocketUpgrade},
Expand All @@ -12,7 +13,7 @@ use futures::{sink::SinkExt, stream::StreamExt};
use local_ip_address::local_ip;
use tokio::net::TcpListener;

use crate::{ADDR, ROOT, TX};
use crate::{ADDR, HARD, ROOT, TX};

pub(crate) async fn serve(tcp_listener: TcpListener, router: Router) {
axum::serve(tcp_listener, router).await.unwrap();
Expand Down Expand Up @@ -64,40 +65,52 @@ pub(crate) fn create_server() -> Router {
ws.on_failed_upgrade(|error| {
log::error!("Failed to upgrade websocket: {}", error);
})
.on_upgrade(|socket| async move {
let (mut sender, mut receiver) = socket.split();
let tx = TX.get().unwrap();
let mut rx = tx.subscribe();
let mut send_task = tokio::spawn(async move {
while rx.recv().await.is_ok() {
sender.send(Message::Text(String::new())).await.unwrap();
}
});
let mut recv_task =
tokio::spawn(
async move { while let Some(Ok(_)) = receiver.next().await {} },
);
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
})
.on_upgrade(on_websocket_upgrade)
}),
)
}

async fn on_websocket_upgrade(socket: WebSocket) {
let (mut sender, mut receiver) = socket.split();
let tx = TX.get().unwrap();
let mut rx = tx.subscribe();
let mut send_task = tokio::spawn(async move {
while rx.recv().await.is_ok() {
sender.send(Message::Text(String::new())).await.unwrap();
}
});
let mut recv_task =
tokio::spawn(
async move { while let Some(Ok(_)) = receiver.next().await {} },
);
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
}

async fn static_assets(req: Request<Body>) -> (StatusCode, HeaderMap, Body) {
let addr = ADDR.get().unwrap();
let root = ROOT.get().unwrap();

let is_reload = req.uri().query().map(|x| x=="reload").unwrap_or(false);

// Get the path and mime of the static file.
let mut path = req.uri().path().to_string();
path.remove(0);
let mut path = root.join(path);
let uri_path = req.uri().path();
let mut path = root.join(&uri_path[1..]);
if path.is_dir() {
if !uri_path.ends_with('/') {
// redirect so parent links work correctly
let mut redirect = uri_path.to_string();
redirect.push('/');
let mut headers = HeaderMap::new();
headers.append(header::LOCATION, HeaderValue::from_str(&redirect).unwrap());
return (StatusCode::TEMPORARY_REDIRECT, headers, Body::empty());
}
path.push("index.html");
}
let mime = mime_guess::from_path(&path).first_or_text_plain();

let mut headers = HeaderMap::new();
headers.append(
header::CONTENT_TYPE,
Expand All @@ -117,7 +130,7 @@ async fn static_assets(req: Request<Body>) -> (StatusCode, HeaderMap, Body) {
_ => StatusCode::INTERNAL_SERVER_ERROR,
};
if mime == "text/html" {
let script = format!(include_str!("templates/websocket.html"), addr);
let script = format_script(addr, is_reload, true);
let html = format!(include_str!("templates/error.html"), script, err);
let body = Body::from(html);

Expand All @@ -137,9 +150,38 @@ async fn static_assets(req: Request<Body>) -> (StatusCode, HeaderMap, Body) {
return (StatusCode::INTERNAL_SERVER_ERROR, headers, body);
}
};
let script = format!(include_str!("templates/websocket.html"), addr);
let script = format_script(addr, is_reload, false);
file = format!("{text}{script}").into_bytes();
} else {
if !HARD.get().copied().unwrap_or(false) {
// allow client to cache assets for a smoother reload.
// client handles preloading to refresh cache before reloading.
headers.append(header::CACHE_CONTROL, HeaderValue::from_str("max-age=30").unwrap());
}
}


(StatusCode::OK, headers, Body::from(file))
}

/// JS script containg a function that takes in the address and connects to the websocket.
const WEBSOCKET_FUNCTION: &str = include_str!("templates/websocket.js");

/// JS script to inject to the HTML on reload so the client
/// knows it's a successful reload.
const RELOAD_PAYLOAD: &str = include_str!("templates/reload.js");

/// Inject the address into the websocket script and wrap it in a script tag
fn format_script(addr: &str, is_reload: bool, is_error: bool) -> String {
match (is_reload, is_error) {
// successful reload, inject the reload payload
(true, false) => format!("<script>{}</script>", RELOAD_PAYLOAD),
// failed reload, don't inject anything so the client polls again
(true, true) => String::new(),
// normal connection, inject the websocket client
_ => {
let hard = if HARD.get().copied().unwrap_or(false) { "true" } else { "false" };
format!(r#"<script>{}("{}", {})</script>"#, WEBSOCKET_FUNCTION, addr, hard)
}
}
}
5 changes: 5 additions & 0 deletions src/templates/reload.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
const meta = document.createElement("meta");
meta.name = "live-server";
meta.content = "reload";
document.head.appendChild(meta);

6 changes: 0 additions & 6 deletions src/templates/websocket.html

This file was deleted.

87 changes: 87 additions & 0 deletions src/templates/websocket.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
(async (addr, hard) => {
addr = `ws://${addr}/live-server-ws`;
const sleep = (x) => new Promise(r => setTimeout(r, x));
const preload = async (url, requireSuccess) => {
const resp = await fetch(url, {cache: "reload"}); // reset cache
if (requireSuccess && (!resp.ok || resp.status !== 200)) { throw new Error(); }
}
/** Reset cache in link.href and strip scripts */
const preloadNode = (n, ps) => {
if (n.tagName === "SCRIPT" && n.src) { ps.push(preload(n.src, false)); return; }
if (n.tagName === "LINK" && n.href) { ps.push(preload(n.href, false)); return; }
let c = n.firstChild; while (c) { let nc = c.nextSibling; preloadNode(c, ps); c = nc; }
};
let reloading = false;
let scheduled = false;
async function reload() {
if(reloading) {
scheduled = true;
return;
}
let ifr;
let interval = 0;
reloading = true;
while(true) {
scheduled = false;
try {
const url = location.origin + location.pathname;
const ps = [];
preloadNode(document.head, ps);
preloadNode(document.body, ps);
await Promise.allSettled(ps);
await new Promise((resolve) => {
ifr = document.createElement("iframe");
ifr.src = url + "?reload";
ifr.style.display = "none";
ifr.onload = resolve;
document.body.appendChild(ifr);
});
const meta = ifr.contentDocument.head.lastChild;
if (meta.tagName !== "META" || meta.name !== "live-server" || meta.content !== "reload") {
throw new Error();
}
if (hard) {
location.reload();
}
document.head.replaceWith(ifr.contentDocument.head);
document.body.replaceWith(ifr.contentDocument.body);
ifr.remove();
if (!scheduled) {
reloading = false;
console.log("[Live Server] Reloaded");
return;
}
} catch (e) {
if (e.message) { console.error(e); }
}
if (ifr) { ifr.remove(); }
interval += 100;
await sleep(interval);
}
}
// connection
let isFirst = true;
let interval = 0;
while (true) {
try {
await new Promise((resolve) => {
const ws = new WebSocket(addr);
ws.onopen = ()=>{
interval = 0;
console.log("[Live Server] Connection Established");
if (!isFirst) {
reload();
}
};
ws.onmessage = reload;
ws.onerror = () =>ws.close();
ws.onclose = resolve;
});
} catch (e) {
}
isFirst = false;
interval += 500;
await sleep(interval);
console.log("[Live Server] Reconnecting...");
}
})
37 changes: 32 additions & 5 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ async fn request() {

let text = response.text().await.unwrap().replace("\r\n", "\n");
let target_text = format!(
"{}{}",
r#"{}<script>{}("{}", false)</script>"#,
include_str!("./page/index.html"),
format!(
include_str!("../src/templates/websocket.html"),
"127.0.0.1:8000"
)
include_str!("../src/templates/websocket.js"),
"127.0.0.1:8000"
)
.replace("\r\n", "\n");
assert_eq!(text, target_text);
assert!(text.contains("<script>"));

// Test requesting index.js
let response = reqwest::get("http://127.0.0.1:8000/index.js")
Expand Down Expand Up @@ -64,4 +63,32 @@ async fn request() {

let content_type = response.headers().get("content-type").unwrap();
assert_eq!(content_type, "image/x-icon");

// Test requesting with reload query
let response = reqwest::get("http://127.0.0.1:8000?reload").await.unwrap();

assert_eq!(response.status(), StatusCode::OK);

let content_type = response.headers().get("content-type").unwrap();
assert_eq!(content_type, "text/html");

let text = response.text().await.unwrap().replace("\r\n", "\n");
let target_text = format!(
r#"{}<script>{}</script>"#,
include_str!("./page/index.html"),
include_str!("../src/templates/reload.js"),
)
.replace("\r\n", "\n");
assert_eq!(text, target_text);

// Test requesting non-existent html file with reload query does not inject script
let response = reqwest::get("http://127.0.0.1:8000/404.html?reload").await.unwrap();

assert_eq!(response.status(), StatusCode::NOT_FOUND);

let content_type = response.headers().get("content-type").unwrap();
assert_eq!(content_type, "text/html");

let text = response.text().await.unwrap();
assert!(!text.contains("<script>"));
}

0 comments on commit f6e77f8

Please sign in to comment.