Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: more metrics tests #22

Merged
merged 2 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,27 +102,65 @@ impl Metrics {
self.prom_request_duration.with_label_values(&[method, &status.to_string()]).observe(duration);
self.prom_requests_in_flight.with_label_values(&[method]).dec();
}

/// Returns a vector of metric families from the Prometheus registry
pub fn get_metrics(&self) -> Vec<prometheus::proto::MetricFamily> {
self.registry.gather()
}

/// Returns an iterator over metrics with helper methods to find specific metrics
pub fn metrics_iter(&self) -> MetricsIterator {
MetricsIterator {
metrics: self.get_metrics()
}
}
}

async fn metrics_handler(metrics: Arc<Metrics>) -> std::result::Result<Response<Body>, Infallible> {
let encoder = TextEncoder::new();
let metric_families = metrics.registry.gather();
let mut buffer = Vec::new();
encoder.encode(&metric_families, &mut buffer).unwrap();
/// Helper struct to iterate and find metrics easily
pub struct MetricsIterator {
metrics: Vec<prometheus::proto::MetricFamily>
}

Ok(Response::builder()
.header("Content-Type", "text/plain")
.body(Body::from(buffer))
.unwrap())
impl MetricsIterator {
/// Find a metric by name
pub fn find_metric(&self, name: &str) -> Option<&prometheus::proto::MetricFamily> {
self.metrics.iter().find(|m| m.get_name() == name)
}

/// Get all metrics
pub fn all(&self) -> &[prometheus::proto::MetricFamily] {
&self.metrics
}
}

async fn metrics_handler(req: Request<Body>, metrics: Arc<Metrics>) -> std::result::Result<Response<Body>, Infallible> {
// Only respond to /metrics path
match req.uri().path() {
"/metrics" => {
let encoder = TextEncoder::new();
let metric_families = metrics.registry.gather();
let mut buffer = Vec::new();
encoder.encode(&metric_families, &mut buffer).unwrap();

Ok(Response::builder()
.header("Content-Type", "text/plain")
.body(Body::from(buffer))
.unwrap())
}
_ => Ok(Response::builder()
.status(404)
.body(Body::from("Not Found"))
.unwrap())
}
}

pub async fn run_metrics_server(metrics: Arc<Metrics>, addr: SocketAddr) -> std::result::Result<(), Box<dyn std::error::Error>> {
// Create the service
let make_svc = make_service_fn(move |_conn| {
let metrics = metrics.clone();
async move {
Ok::<_, Infallible>(service_fn(move |_req: Request<Body>| {
metrics_handler(metrics.clone())
Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
metrics_handler(req, metrics.clone())
}))
}
});
Expand Down
44 changes: 31 additions & 13 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,22 @@ async fn test_server_basic_functionality() -> Result<(), Box<dyn std::error::Err

// Start server in background task
let test_port = 3001;
let metrics_port = 13001;
let addr = format!("127.0.0.1:{}", test_port);
let server_handle = tokio::spawn(async move {
let args = Args {
index_path: temp_file.path().to_str().unwrap().to_string(),
port: test_port,
addr: "127.0.0.1".to_string(),
metrics_port: 13001,
};
let metrics_addr = format!("127.0.0.1:{}", metrics_port).parse()?;

let html_content = fs::read_to_string(&args.index_path).unwrap();
let state = Arc::new(AppState::new(html_content));
let metrics = Arc::new(metrics::Metrics::new());

let html_content = fs::read_to_string(&temp_file.path().to_str().unwrap())?;
let state = Arc::new(AppState::new(html_content));
let metrics = Arc::new(metrics::Metrics::new());

// Start metrics server
let metrics_clone = metrics.clone();
let metrics_handle = tokio::spawn(async move {
metrics::run_metrics_server(metrics_clone, metrics_addr).await.unwrap();
});

// Start main server
let server_handle = tokio::spawn(async move {
let addr: SocketAddr = addr.parse().unwrap();
let make_svc = make_service_fn(move |_conn| {
let state = state.clone();
Expand All @@ -95,7 +98,7 @@ async fn test_server_basic_functionality() -> Result<(), Box<dyn std::error::Err
server.await.unwrap();
});

// Give the server a moment to start
// Give the servers a moment to start
sleep(Duration::from_millis(100)).await;

// Create a client
Expand Down Expand Up @@ -131,7 +134,22 @@ async fn test_server_basic_functionality() -> Result<(), Box<dyn std::error::Err
// Verify content matches
assert_eq!(body_string, test_content);

// Clean up
// Verify metrics
let metrics_response = client
.get(format!("http://127.0.0.1:{}/metrics", metrics_port).parse()?)
.await?;

assert_eq!(metrics_response.status(), 200);
let metrics_body = hyper::body::to_bytes(metrics_response.into_body()).await?;
let metrics_str = String::from_utf8(metrics_body.to_vec())?;

// Verify request was counted in metrics
assert!(metrics_str.contains("http_requests_total{method=\"GET\"} 1"));
assert!(metrics_str.contains("method=\"GET\""));
assert!(metrics_str.contains("http_request_duration_seconds"));

// Clean up both servers
metrics_handle.abort();
server_handle.abort();

Ok(())
Expand Down
141 changes: 141 additions & 0 deletions tests/metrics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use std::sync::Arc;

use single_page_web_server_rs::metrics::Metrics;

#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use std::thread;

#[test]
fn test_metrics_recording() {
let metrics = Arc::new(Metrics::new());

// Simulate multiple requests
for _ in 0..5 {
metrics.record_request("GET");
thread::sleep(Duration::from_millis(10)); // Simulate some work
metrics.record_response("GET", 200, std::time::Instant::now());
}

for _ in 0..3 {
metrics.record_request("POST");
thread::sleep(Duration::from_millis(10));
metrics.record_response("POST", 404, std::time::Instant::now());
}

// Gather Prometheus metrics
let metric_families = metrics.get_metrics();

// Helper function to find metric by name
let find_metric = |name: &str| {
metric_families.iter()
.find(|m| m.get_name() == name)
.expect(&format!("Metric {} not found", name))
};

// Verify request counts
let requests_total = find_metric("http_requests_total");
let get_requests = requests_total.get_metric().iter()
.find(|m| m.get_label().iter().any(|l| l.get_value() == "GET"))
.expect("GET requests not found");
let post_requests = requests_total.get_metric().iter()
.find(|m| m.get_label().iter().any(|l| l.get_value() == "POST"))
.expect("POST requests not found");

assert_eq!(get_requests.get_counter().get_value() as i64, 5);
assert_eq!(post_requests.get_counter().get_value() as i64, 3);

// Verify in-flight requests (should be 0 after all requests completed)
let requests_in_flight = find_metric("http_requests_in_flight");
for metric in requests_in_flight.get_metric() {
assert_eq!(metric.get_gauge().get_value() as i64, 0);
}

// Verify duration histogram
let duration = find_metric("http_request_duration_seconds");
let get_200_duration = duration.get_metric().iter()
.find(|m| m.get_label().iter().any(|l| l.get_value() == "GET") &&
m.get_label().iter().any(|l| l.get_value() == "200"))
.expect("GET 200 duration not found");
let post_404_duration = duration.get_metric().iter()
.find(|m| m.get_label().iter().any(|l| l.get_value() == "POST") &&
m.get_label().iter().any(|l| l.get_value() == "404"))
.expect("POST 404 duration not found");

assert!(get_200_duration.get_histogram().get_sample_count() == 5);
assert!(post_404_duration.get_histogram().get_sample_count() == 3);
}

#[test]
fn test_concurrent_requests() {
use std::thread;

let metrics = Arc::new(Metrics::new());
let mut handles = vec![];

// Spawn 10 threads making concurrent requests
for i in 0..10 {
let metrics = metrics.clone();
let handle = thread::spawn(move || {
let method = if i % 2 == 0 { "GET" } else { "POST" };
metrics.record_request(method);
thread::sleep(Duration::from_millis(5));
metrics.record_response(method, 200, std::time::Instant::now());
});
handles.push(handle);
}

// Wait for all threads to complete
for handle in handles {
handle.join().unwrap();
}

// Verify total request count
let metric_families = metrics.get_metrics();
let requests_total = metric_families.iter()
.find(|m| m.get_name() == "http_requests_total")
.unwrap();

let total_requests: u64 = requests_total.get_metric().iter()
.map(|m| m.get_counter().get_value() as u64)
.sum();

assert_eq!(total_requests, 10);

// Verify no requests are in flight
let requests_in_flight = metric_families.iter()
.find(|m| m.get_name() == "http_requests_in_flight")
.unwrap();

let total_in_flight: i64 = requests_in_flight.get_metric().iter()
.map(|m| m.get_gauge().get_value() as i64)
.sum();

assert_eq!(total_in_flight, 0);
}
}

#[test]
fn test_metrics_iterator() {
let metrics = Arc::new(Metrics::new());

// Record some data
metrics.record_request("GET");
metrics.record_response("GET", 200, std::time::Instant::now());

// Use the iterator
let iter = metrics.metrics_iter();

// Find specific metric
let requests_total = iter.find_metric("http_requests_total")
.expect("http_requests_total metric should exist");

// Verify the metric
let get_requests = requests_total.get_metric().iter()
.find(|m| m.get_label().iter().any(|l| l.get_value() == "GET"))
.expect("GET requests not found");

assert_eq!(get_requests.get_counter().get_value() as i64, 1);
}
Loading