Skip to content

Commit

Permalink
flow-client: Refactor to extract out refresh logic into bare function…
Browse files Browse the repository at this point in the history
…, and re-use Postgrest connection pool instead of creating a new one on each call to `pg_client()`
  • Loading branch information
jshearer committed Oct 2, 2024
1 parent c6a27dc commit e78de08
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 80 deletions.
7 changes: 5 additions & 2 deletions crates/dekaf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ pub use api_client::KafkaApiClient;

use aes_siv::{aead::Aead, Aes256SivAead, KeyInit, KeySizeUser};
use connector::DekafConfig;
use flow_client::{client::RefreshToken, DEFAULT_AGENT_URL};
use flow_client::{
client::{refresh_client, RefreshToken},
DEFAULT_AGENT_URL,
};
use percent_encoding::{percent_decode_str, utf8_percent_encode};
use serde::{Deserialize, Serialize};
use std::time::SystemTime;
Expand Down Expand Up @@ -77,7 +80,7 @@ impl App {
Some(refresh),
);

client.refresh().await?;
refresh_client(&mut client).await?;
let claims = client.claims()?;

if models::Materialization::regex().is_match(username.as_ref()) {
Expand Down
154 changes: 77 additions & 77 deletions crates/flow-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ pub struct Client {
agent_endpoint: url::Url,
// HTTP client to use for REST requests.
http_client: reqwest::Client,
// PostgREST URL.
pg_url: url::Url,
// PostgREST access token.
pg_api_token: String,
// User's access token, if authenticated.
user_access_token: Option<String>,
// User's refresh token, if authenticated.
Expand All @@ -23,6 +19,10 @@ pub struct Client {
shard_client: gazette::shard::Client,
// Base journal client which is cloned to build token-specific clients.
journal_client: gazette::journal::Client,
// Keep a single Postgrest and hand out clones of it in order to maintain
// a single connection pool. The clones can have different headers while
// still re-using the same connection pool, so this will work across refreshes.
pg_parent: postgrest::Postgrest,
}

impl Client {
Expand Down Expand Up @@ -54,90 +54,24 @@ impl Client {
Self {
agent_endpoint,
http_client: reqwest::Client::new(),
pg_api_token,
pg_url,
pg_parent: postgrest::Postgrest::new(pg_url.as_str())
.insert_header("apikey", pg_api_token.as_str()),
journal_client,
shard_client,
user_access_token,
user_refresh_token: userrefresh_token,
}
}

pub async fn refresh(&mut self) -> anyhow::Result<()> {
// Clear expired or soon-to-expire access token
if let Some(_) = &self.user_access_token {
let claims = self.claims()?;

let now = time::OffsetDateTime::now_utc();
let exp = time::OffsetDateTime::from_unix_timestamp(claims.exp as i64).unwrap();

// Refresh access tokens with plenty of time to spare if we have a
// refresh token. If not, allow refreshing right until the token expires
match ((now - exp).whole_seconds(), &self.user_refresh_token) {
(exp_seconds, Some(_)) if exp_seconds < 60 => self.user_access_token = None,
(exp_seconds, None) if exp_seconds <= 0 => self.user_access_token = None,
_ => {}
}
}

if self.user_access_token.is_some() && self.user_refresh_token.is_some() {
// Authorization is current: nothing to do.
Ok(())
} else if self.user_access_token.is_some() {
// We have an access token but no refresh token. Create one.
let refresh_token = api_exec::<RefreshToken>(
self.rpc(
"create_refresh_token",
serde_json::json!({"multi_use": true, "valid_for": "90d", "detail": "Created by flowctl"})
.to_string(),
),
)
.await?;

self.user_refresh_token = Some(refresh_token);

tracing::info!("created new refresh token");
Ok(())
} else if let Some(RefreshToken { id, secret }) = &self.user_refresh_token {
// We have a refresh token but no access token. Generate one.

#[derive(serde::Deserialize)]
struct Response {
access_token: String,
refresh_token: Option<RefreshToken>, // Set iff the token was single-use.
}
let Response {
access_token,
refresh_token: next_refresh_token,
} = api_exec::<Response>(self.rpc(
"generate_access_token",
serde_json::json!({"refresh_token_id": id, "secret": secret}).to_string(),
))
.await
.context("failed to obtain access token")?;

if next_refresh_token.is_some() {
self.user_refresh_token = next_refresh_token;
}

self.user_access_token = Some(access_token);

tracing::info!("generated a new access token");
Ok(())
} else {
anyhow::bail!("Client not authenticated");
}
}

pub fn pg_client(&self) -> postgrest::Postgrest {
let pg_client = postgrest::Postgrest::new(self.pg_url.as_str())
.insert_header("apikey", self.pg_api_token.as_str());

if let Some(token) = &self.user_access_token {
return pg_client.insert_header("Authorization", &format!("Bearer {token}"));
return self
.pg_parent
.clone()
.insert_header("Authorization", &format!("Bearer {token}"));
}

pg_client
self.pg_parent.clone()
}

pub fn claims(&self) -> anyhow::Result<ControlClaims> {
Expand Down Expand Up @@ -331,3 +265,69 @@ pub async fn fetch_collection_authorization(

Ok((journal_name_prefix, journal_client))
}

pub async fn refresh_client(client: &mut Client) -> anyhow::Result<()> {
// Clear expired or soon-to-expire access token
if let Some(_) = &client.user_access_token {
let claims = client.claims()?;

let now = time::OffsetDateTime::now_utc();
let exp = time::OffsetDateTime::from_unix_timestamp(claims.exp as i64).unwrap();

// Refresh access tokens with plenty of time to spare if we have a
// refresh token. If not, allow refreshing right until the token expires
match ((now - exp).whole_seconds(), &client.user_refresh_token) {
(exp_seconds, Some(_)) if exp_seconds < 60 => client.user_access_token = None,
(exp_seconds, None) if exp_seconds <= 0 => client.user_access_token = None,
_ => {}
}
}

if client.user_access_token.is_some() && client.user_refresh_token.is_some() {
// Authorization is current: nothing to do.
Ok(())
} else if client.user_access_token.is_some() {
// We have an access token but no refresh token. Create one.
let refresh_token = api_exec::<RefreshToken>(
client.rpc(
"create_refresh_token",
serde_json::json!({"multi_use": true, "valid_for": "90d", "detail": "Created by flowctl"})
.to_string(),
),
)
.await?;

client.user_refresh_token = Some(refresh_token);

tracing::info!("created new refresh token");
Ok(())
} else if let Some(RefreshToken { id, secret }) = &client.user_refresh_token {
// We have a refresh token but no access token. Generate one.

#[derive(serde::Deserialize)]
struct Response {
access_token: String,
refresh_token: Option<RefreshToken>, // Set iff the token was single-use.
}
let Response {
access_token,
refresh_token: next_refresh_token,
} = api_exec::<Response>(client.rpc(
"generate_access_token",
serde_json::json!({"refresh_token_id": id, "secret": secret}).to_string(),
))
.await
.context("failed to obtain access token")?;

if next_refresh_token.is_some() {
client.user_refresh_token = next_refresh_token;
}

client.user_access_token = Some(access_token);

tracing::info!("generated a new access token");
Ok(())
} else {
anyhow::bail!("Client not authenticated");
}
}
3 changes: 2 additions & 1 deletion crates/flowctl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod poll;
mod preview;
mod raw;

use flow_client::client::refresh_client;
pub(crate) use flow_client::client::Client;
pub(crate) use flow_client::{api_exec, api_exec_paginated, parse_jwt_claims};
use output::{Output, OutputType};
Expand Down Expand Up @@ -151,7 +152,7 @@ impl Cli {
let mut client: flow_client::Client = config.build_client();

if config.user_access_token.is_some() || config.user_refresh_token.is_some() {
client.refresh().await?;
refresh_client(&mut client).await?;
} else {
tracing::warn!("You are not authenticated. Run `auth login` to login to Flow.");
}
Expand Down

0 comments on commit e78de08

Please sign in to comment.