Skip to content

Commit

Permalink
client split in sub-modules
Browse files Browse the repository at this point in the history
Signed-off-by: Leo Valais <[email protected]>
  • Loading branch information
leovalais committed Feb 5, 2025
1 parent 28052f1 commit 0ad30c3
Show file tree
Hide file tree
Showing 5 changed files with 343 additions and 299 deletions.
356 changes: 57 additions & 299 deletions fga/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
mod authorization_models;
mod queries;
mod stores;
mod tuples;

pub use authorization_models::AuthorizationModel;
pub use authorization_models::StoreAuthorizationModel;
pub use stores::Store;

use tuples::RawTuple;

use std::future::{self, Future};

use futures::{stream, TryStreamExt as _};
Expand Down Expand Up @@ -28,6 +39,8 @@ pub enum InitializationError {
}

// Public API of the client
// -------------------------

impl Client {
pub async fn try_init(
store_name: String,
Expand Down Expand Up @@ -227,8 +240,31 @@ impl Client {
}
}

pub type AuthorizationModel = serde_json::Value;
// Mapping of OpenFGA HTTP API
// ---------------------------
//
// Client functions are implemented for each OpenFGA endpoint. The implementations are
// scattered across different sub-modules, which are defined according to the sections
// of the OpenFGA API documentation: https://openfga.dev/api/service

impl Client {
fn base_url(&self) -> url::Url {
url::Url::parse(
format!("http://{}:{}/", self.settings.address, self.settings.port).as_str(),
)
.unwrap()
}
}

/// Convenience trait to query OpenFGA from [crate::model] query types directly
///
/// For example:
///
/// ```ignore
/// Object::relation().check(&user, &object).fetch(&client).await.unwrap();
/// // instead of
/// client.check(Object::relation().check(&user, &object)).await.unwrap();
/// ```
pub trait Request {
type Response;
type Error: std::error::Error;
Expand Down Expand Up @@ -263,304 +299,6 @@ impl From<reqwest::Error> for RequestFailure {
}
}

#[derive(Debug, Default, Clone, serde::Deserialize)]
pub struct Store {
pub id: String,
pub name: String,
pub created_at: String,
pub updated_at: String,
pub deleted_at: Option<String>,
}

#[derive(Debug, serde::Serialize)]
struct RawTuple {
user: String,
relation: String,
object: String,
}

impl<'a, R: Relation, U: AsUser<User = R::User>> From<&Tuple<'a, R, U>> for RawTuple {
fn from(tuple: &Tuple<'a, R, U>) -> Self {
RawTuple {
user: tuple.user.fga_ident(),
relation: R::NAME.to_string(),
object: tuple.object.fga_ident(),
}
}
}

#[derive(Debug, serde::Serialize)]
struct ContextualTuples {
tuple_keys: Vec<RawTuple>,
}

impl<'a, R: Relation, U: AsUser<User = R::User>> FromIterator<&'a Tuple<'a, R, U>>
for ContextualTuples
{
fn from_iter<I: IntoIterator<Item = &'a Tuple<'a, R, U>>>(iter: I) -> Self {
Self {
tuple_keys: iter.into_iter().map(RawTuple::from).collect(),
}
}
}

#[derive(Debug, serde::Deserialize)]
pub struct StoreAuthorizationModel {
pub id: String,
pub type_definitions: AuthorizationModel,
}

// Almost 1-to-1 mapping of the HTTP API
impl Client {
fn base_url(&self) -> url::Url {
url::Url::parse(
format!("http://{}:{}/", self.settings.address, self.settings.port).as_str(),
)
.unwrap()
}

async fn get_stores(
&self,
page_size: Option<usize>,
continuation: Option<&str>,
) -> Result<(Vec<Store>, String), RequestFailure> {
#[derive(serde::Deserialize)]
struct Response {
stores: Vec<Store>,
#[serde(default)]
continuation_token: String,
}

let mut url = self.base_url().join("stores").unwrap();
if let Some(continuation) = continuation {
url.query_pairs_mut()
.append_pair("continuation_token", continuation);
}
if let Some(page_size) = page_size {
url.query_pairs_mut()
.append_pair("page_size", page_size.to_string().as_str());
}
let response = self.inner.get(url).send().await?;

let Response {
stores,
continuation_token,
} = response.error_for_status()?.json::<Response>().await?;

Ok((stores, continuation_token))
}

async fn post_stores(&self, name: &str) -> Result<Store, RequestFailure> {
#[derive(serde::Serialize)]
struct Request {
name: String,
}

let request = Request {
name: name.to_owned(),
};

let url = self.base_url().join("stores").unwrap();
let response = self.inner.post(url).json(&request).send().await?;

let store = response.error_for_status()?.json().await?;
Ok(store)
}

async fn delete_stores(&self, store_id: &str) -> Result<(), RequestFailure> {
let url = self
.base_url()
.join(format!("stores/{store_id}").as_str())
.unwrap();
self.inner.delete(url).send().await?.error_for_status()?;
Ok(())
}

// It's fine to request tuples to be mapped into `RawTuple` as OpenFGA
// doesn't support more than 100 tuples in the request. So mapping 100 objects
// max is fine—we'll always be bounded by the network call.
async fn post_stores_write<'a>(
&self,
store_id: &str,
writes: &[RawTuple],
deletes: &[RawTuple],
authorization_model_id: Option<String>,
) -> Result<(), RequestFailure> {
#[derive(serde::Serialize)]
struct Request<'a> {
#[serde(skip_serializing_if = "Writes::is_empty")]
writes: Writes<'a>,
#[serde(skip_serializing_if = "Deletes::is_empty")]
deletes: Deletes<'a>,
#[serde(skip_serializing_if = "Option::is_none")]
authorization_model_id: Option<String>,
}

#[derive(serde::Serialize)]
struct Writes<'a> {
tuple_keys: &'a [RawTuple],
}

impl<'a> Writes<'a> {
fn is_empty(&self) -> bool {
self.tuple_keys.is_empty()
}
}

#[derive(serde::Serialize)]
struct Deletes<'a> {
tuple_keys: &'a [RawTuple],
}

impl<'a> Deletes<'a> {
fn is_empty(&self) -> bool {
self.tuple_keys.is_empty()
}
}

let url = self
.base_url()
.join(format!("stores/{store_id}/write").as_str())
.unwrap();
self.inner
.post(url)
.json(&Request {
writes: Writes { tuple_keys: writes },
deletes: Deletes {
tuple_keys: deletes,
},
authorization_model_id,
})
.send()
.await?
.error_for_status()?;
Ok(())
}

async fn get_stores_authorization_models(
&self,
store_id: &str,
page_size: Option<usize>,
continuation: Option<&str>,
) -> Result<(Vec<StoreAuthorizationModel>, String), RequestFailure> {
#[derive(serde::Deserialize)]
struct Response {
authorization_models: Vec<StoreAuthorizationModel>,
#[serde(default)]
continuation_token: String,
}

let mut url = self
.base_url()
.join(format!("/stores/{store_id}/authorization-models").as_str())
.unwrap();
if let Some(continuation) = continuation {
url.query_pairs_mut()
.append_pair("continuation_token", continuation);
}
if let Some(page_size) = page_size {
url.query_pairs_mut()
.append_pair("page_size", page_size.to_string().as_str());
}

let response = self.inner.get(url).send().await?.error_for_status()?;
let Response {
authorization_models,
continuation_token,
} = response.json().await?;

Ok((authorization_models, continuation_token))
}

async fn post_stores_authorization_models(
&self,
store_id: &str,
authorization_model: &AuthorizationModel,
) -> Result<String, RequestFailure> {
let url = self
.base_url()
.join(format!("stores/{store_id}/authorization-models").as_str())
.unwrap();
let response = self
.inner
.post(url)
.json(authorization_model)
.send()
.await?
.error_for_status()?;

#[derive(serde::Deserialize)]
struct Response {
authorization_model_id: String,
}
let Response {
authorization_model_id,
} = response.json().await?;

Ok(authorization_model_id)
}

async fn post_stores_check(
&self,
store_id: &str,
tuple: RawTuple,
contextual_tuples: Option<ContextualTuples>,
authorization_model_id: Option<String>,
) -> Result<bool, RequestFailure> {
#[derive(serde::Serialize)]
struct Request {
tuple_key: RawTuple,
#[serde(skip_serializing_if = "Option::is_none")]
contextual_tuples: Option<ContextualTuples>,
#[serde(skip_serializing_if = "Option::is_none")]
authorization_model_id: Option<String>,
}

let request = Request {
tuple_key: tuple,
contextual_tuples,
authorization_model_id,
};

let url = self
.base_url()
.join(format!("stores/{store_id}/check").as_str())
.unwrap();
let response = self.inner.post(url).json(&request).send().await?;

#[derive(serde::Deserialize)]
struct Response {
allowed: bool,
#[expect(dead_code)]
resolution: String,
}

let Response { allowed, .. } = response.error_for_status()?.json::<Response>().await?;

Ok(allowed)
}
}

/// Models the three states of a continuation while unfolding paginated API calls
enum Continuation {
/// Initial state, no calls have been made yet
None,
/// A call response has provided a continuation token
Continue(String),
/// A call response has provided no continuation token (an empty string) meaning that the pagination ends there
Stop,
}

impl<S: AsRef<str>> From<S> for Continuation {
fn from(s: S) -> Self {
if s.as_ref().is_empty() {
Continuation::Stop
} else {
Continuation::Continue(s.as_ref().to_owned())
}
}
}

/// Allows transforming continuation-based paginated endpoint calls into a [stream::TryStream]
///
/// The [ContinuationUnfolder::stream] function takes a closure that will be called repeatedly until
Expand Down Expand Up @@ -589,6 +327,26 @@ struct UnfoldNextState<C> {
continuation: String,
}

/// Models the three states of a continuation while unfolding paginated API calls
enum Continuation {
/// Initial state, no calls have been made yet
None,
/// A call response has provided a continuation token
Continue(String),
/// A call response has provided no continuation token (an empty string) meaning that the pagination ends there
Stop,
}

impl<S: AsRef<str>> From<S> for Continuation {
fn from(s: S) -> Self {
if s.as_ref().is_empty() {
Continuation::Stop
} else {
Continuation::Continue(s.as_ref().to_owned())
}
}
}

impl<C> ContinuationUnfolder<C> {
fn new(client: Client, ctx: C) -> Self {
Self {
Expand Down
Loading

0 comments on commit 0ad30c3

Please sign in to comment.