Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly implement LNURL errors #17

Merged
merged 1 commit into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 27 additions & 30 deletions src/lnurlp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use nostr::{Event, JsonUtil, Kind};

use crate::routes::{LnurlStatus, LnurlType, LnurlWellKnownResponse};

const INVALID_AMT_ERR: &str = "Invalid amount. Make sure the amount is within the range.";

fn calc_metadata(name: &str, domain: &str) -> String {
format!("[[\"text/identifier\",\"{name}@{domain}\"],[\"text/plain\",\"Sats for {name}\"]]")
}
Expand All @@ -27,7 +29,7 @@ pub async fn well_known_lnurlp(
) -> anyhow::Result<LnurlWellKnownResponse> {
let user = state.db.get_user_by_name(name.clone())?;
if user.is_none() {
return Err(anyhow!("NotFound"));
return Err(anyhow!("Not Found"));
}

let res = LnurlWellKnownResponse {
Expand Down Expand Up @@ -55,22 +57,17 @@ pub async fn lnurl_callback(
) -> anyhow::Result<LnurlCallbackResponse> {
let user = state.db.get_user_and_increment_counter(&name)?;
if user.is_none() {
return Err(anyhow!("NotFound"));
return Err(anyhow!("Not Found"));
}
let user = user.expect("just checked");

if params.amount < MIN_AMOUNT {
return Err(anyhow::anyhow!(
"Amount ({}) < MIN_AMOUNT ({MIN_AMOUNT})",
params.amount
));
}
let amount_msats = match params.amount {
Some(amt) => amt,
None => return Err(anyhow!(INVALID_AMT_ERR)),
};

if params.amount > MAX_AMOUNT {
return Err(anyhow::anyhow!(
"Amount ({}) < MAX_AMOUNT ({MAX_AMOUNT})",
params.amount
));
if !(MIN_AMOUNT..=MAX_AMOUNT).contains(&amount_msats) {
return Err(anyhow::anyhow!(INVALID_AMT_ERR));
}

// verify nostr param is a zap request if we have one
Expand All @@ -84,13 +81,13 @@ pub async fn lnurl_callback(
}

let federation_id = FederationId::from_str(&user.federation_id)
.map_err(|e| anyhow::anyhow!("Invalid federation_id: {e}"))?;
.map_err(|e| anyhow::anyhow!("Internal error: Invalid federation_id: {e}"))?;

let client = state
.mm
.get_federation_client(federation_id)
.await
.ok_or(anyhow!("NotFound"))?;
.ok_or(anyhow!("Internal error: No federation client"))?;

let ln = client.get_first_module::<LightningClientModule>();

Expand All @@ -107,11 +104,11 @@ pub async fn lnurl_callback(

let gateway = select_gateway(&client)
.await
.ok_or(anyhow!("No gateway found for federation"))?;
.ok_or(anyhow!("Internal error: No gateway found for federation"))?;

let (op_id, pr, preimage) = ln
.create_bolt11_invoice_for_user_tweaked(
Amount::from_msats(params.amount),
Amount::from_msats(amount_msats),
Bolt11InvoiceDescription::Hash(&desc_hash),
Some(86_400), // 1 day expiry
user.pubkey().public_key(Parity::Even),
Expand All @@ -129,7 +126,7 @@ pub async fn lnurl_callback(
app_user_id: user.id,
user_invoice_index: invoice_index,
bolt11: pr.to_string(),
amount: params.amount as i64,
amount: amount_msats as i64,
state: InvoiceState::Pending as i32,
};

Expand All @@ -156,7 +153,7 @@ pub async fn lnurl_callback(
let verify_url = format!("{}/lnurlp/{}/verify/{}", state.domain, user.name, op_id);

Ok(LnurlCallbackResponse {
pr: pr.to_string(),
pr,
success_action: None,
status: LnurlStatus::Ok,
reason: None,
Expand All @@ -173,15 +170,15 @@ pub async fn verify(
let invoice = state
.db
.get_invoice_by_op_id(op_id)?
.ok_or(anyhow::anyhow!("NotFound"))?;
.ok_or(anyhow::anyhow!("Not Found"))?;

let user = state
.db
.get_user_by_name(name)?
.ok_or(anyhow::anyhow!("NotFound"))?;
.ok_or(anyhow::anyhow!("Not Found"))?;

if invoice.app_user_id != user.id {
return Err(anyhow::anyhow!("NotFound"));
return Err(anyhow::anyhow!("Not Found"));
}

let verify_response = LnurlVerifyResponse {
Expand Down Expand Up @@ -309,23 +306,23 @@ mod tests_integration {
state.db.insert_new_user(user).unwrap();

let params = LnurlCallbackParams {
amount: 1,
amount: Some(1),
..Default::default()
};

match lnurl_callback(&state, username.clone(), params).await {
Ok(_) => panic!("unexpected ok"),
Err(e) => assert!(e.to_string().contains("MIN_AMOUNT")),
Err(e) => assert_eq!(e.to_string(), INVALID_AMT_ERR),
}

let params = LnurlCallbackParams {
amount: u64::MAX,
amount: Some(u64::MAX),
..Default::default()
};

match lnurl_callback(&state, username, params).await {
Ok(_) => panic!("unexpected ok"),
Err(e) => assert!(e.to_string().contains("MAX_AMOUNT")),
Err(e) => assert_eq!(e.to_string(), INVALID_AMT_ERR),
}
}

Expand Down Expand Up @@ -378,7 +375,7 @@ mod tests_integration {
state.db.insert_new_user(user).unwrap();

let params = LnurlCallbackParams {
amount: 10_000,
amount: Some(10_000),
nonce: None,
comment: None,
proofofpayer: None,
Expand All @@ -388,7 +385,7 @@ mod tests_integration {
match lnurl_callback(&state, username, params).await {
Ok(result) => {
assert_eq!(result.status, LnurlStatus::Ok);
assert!(!result.pr.is_empty());
assert!(!result.pr.is_expired());
}
Err(e) => panic!("shouldn't error: {e}"),
}
Expand Down Expand Up @@ -450,7 +447,7 @@ mod tests_integration {
};

let params = LnurlCallbackParams {
amount: 10_000,
amount: Some(10_000),
nonce: None,
comment: None,
proofofpayer: None,
Expand All @@ -460,7 +457,7 @@ mod tests_integration {
match lnurl_callback(&state, username, params).await {
Ok(result) => {
assert_eq!(result.status, LnurlStatus::Ok);
assert!(!result.pr.is_empty());
assert!(!result.pr.is_expired());
}
Err(e) => panic!("shouldn't error: {e}"),
}
Expand Down
24 changes: 20 additions & 4 deletions src/nostr.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
use axum::http::StatusCode;
use axum::Json;
use nostr::prelude::XOnlyPublicKey;
use serde_json::{json, Value};
use std::{collections::HashMap, str::FromStr};

use crate::State;

pub fn well_known_nip5(
state: &State,
name: String,
) -> anyhow::Result<HashMap<String, XOnlyPublicKey>> {
let user = state.db.get_user_by_name(name)?;
) -> Result<HashMap<String, XOnlyPublicKey>, (StatusCode, Json<Value>)> {
let user = state.db.get_user_by_name(name).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"status": "ERROR", "error": e.to_string()})),
)
})?;

let mut names = HashMap::new();
if let Some(user) = user {
names.insert(user.name, XOnlyPublicKey::from_str(&user.pubkey)?);
names.insert(
user.name,
XOnlyPublicKey::from_str(&user.pubkey).expect("valid npub"),
);
} else {
return Err((
StatusCode::NOT_FOUND,
Json(json!({"status": "ERROR", "error": "Not Found"})),
));
}

Ok(names)
Expand Down Expand Up @@ -78,7 +94,7 @@ mod tests_integration {
Ok(result) => {
assert_eq!(result.get(&username).unwrap().to_string(), pk1.to_string());
}
Err(e) => panic!("shouldn't error: {e}"),
Err((_code, json)) => panic!("shouldn't error: {json:?}"),
}
}
}
58 changes: 44 additions & 14 deletions src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,32 @@ use crate::{
use axum::extract::{Path, Query};
use axum::headers::Origin;
use axum::http::StatusCode;
use axum::response::Redirect;
use axum::response::{IntoResponse, Redirect, Response};
use axum::Extension;
use axum::{Json, TypedHeader};
use fedimint_core::Amount;
use log::{debug, error};
use nostr::prelude::XOnlyPublicKey;
use serde::{de, Deserialize, Deserializer, Serialize};
use serde_json::{json, Value};
use std::{collections::HashMap, fmt::Display, str::FromStr};
use fedimint_ln_common::lightning_invoice::Bolt11Invoice;
use tbs::AggregatePublicKey;
use url::Url;

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LnUrlErrorResponse {
pub status: LnurlStatus,
pub reason: String,
}

impl IntoResponse for LnUrlErrorResponse {
fn into_response(self) -> Response {
let body = serde_json::to_value(self).expect("valid json");
(StatusCode::OK, Json(body)).into_response()
}
}

pub async fn check_username(
origin: Option<TypedHeader<Origin>>,
Extension(state): Extension<State>,
Expand Down Expand Up @@ -85,7 +100,7 @@ pub async fn register_route(

#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct UserWellKnownNip5Req {
pub name: String,
pub name: Option<String>,
}

#[derive(Deserialize, Serialize, Debug, Clone)]
Expand All @@ -96,11 +111,17 @@ pub struct UserWellKnownNip5Resp {
pub async fn well_known_nip5_route(
Extension(state): Extension<State>,
Query(params): Query<UserWellKnownNip5Req>,
) -> Result<Json<UserWellKnownNip5Resp>, (StatusCode, String)> {
) -> Result<Json<UserWellKnownNip5Resp>, (StatusCode, Json<Value>)> {
debug!("well_known_nip5_route");
match well_known_nip5(&state, params.name) {
Ok(res) => Ok(Json(UserWellKnownNip5Resp { names: res })),
Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
match params.name {
Some(name) => {
let names = well_known_nip5(&state, name)?;
Ok(Json(UserWellKnownNip5Resp { names }))
}
None => Err((
StatusCode::NOT_FOUND,
Json(json!({"status": "ERROR", "error": "Not Found"})),
)),
}
}

Expand Down Expand Up @@ -136,18 +157,21 @@ pub struct LnurlWellKnownResponse {
pub async fn well_known_lnurlp_route(
Extension(state): Extension<State>,
Path(username): Path<String>,
) -> Result<Json<LnurlWellKnownResponse>, (StatusCode, String)> {
) -> Result<Json<LnurlWellKnownResponse>, LnUrlErrorResponse> {
debug!("well_known_lnurlp_route");
match well_known_lnurlp(&state, username).await {
Ok(res) => Ok(Json(res)),
Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
Err(e) => Err(LnUrlErrorResponse {
status: LnurlStatus::Error,
reason: e.to_string(),
}),
}
}

#[derive(Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct LnurlCallbackParams {
pub amount: u64, // User specified amount in MilliSatoshi
pub amount: Option<u64>, // User specified amount in MilliSatoshi
#[serde(default, deserialize_with = "empty_string_as_none")]
pub nonce: Option<String>, // Optional parameter used to prevent server response caching
#[serde(default, deserialize_with = "empty_string_as_none")]
Expand All @@ -164,7 +188,7 @@ pub struct LnurlCallbackResponse {
pub status: LnurlStatus,
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
pub pr: String, // BOLT11 invoice
pub pr: Bolt11Invoice,
pub verify: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub success_action: Option<LnurlCallbackSuccessAction>,
Expand All @@ -183,11 +207,14 @@ pub async fn lnurl_callback_route(
Extension(state): Extension<State>,
Query(params): Query<LnurlCallbackParams>,
Path(username): Path<String>,
) -> Result<Json<LnurlCallbackResponse>, (StatusCode, String)> {
) -> Result<Json<LnurlCallbackResponse>, LnUrlErrorResponse> {
debug!("lnurl_callback_route");
match lnurl_callback(&state, username, params).await {
Ok(res) => Ok(Json(res)),
Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
Err(e) => Err(LnUrlErrorResponse {
status: LnurlStatus::Error,
reason: e.to_string(),
}),
}
}

Expand All @@ -203,11 +230,14 @@ pub struct LnurlVerifyResponse {
pub async fn lnurl_verify_route(
Extension(state): Extension<State>,
Path((username, op_id)): Path<(String, String)>,
) -> Result<Json<LnurlVerifyResponse>, (StatusCode, String)> {
) -> Result<Json<LnurlVerifyResponse>, LnUrlErrorResponse> {
debug!("lnurl_callback_route");
match verify(&state, username, op_id).await {
Ok(res) => Ok(Json(res)),
Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
Err(e) => Err(LnUrlErrorResponse {
status: LnurlStatus::Error,
reason: e.to_string(),
}),
}
}

Expand Down