Skip to content

Commit

Permalink
feat(websocket): add support for merge update by user id (#630)
Browse files Browse the repository at this point in the history
  • Loading branch information
kasugamirai authored Nov 15, 2024
1 parent ea69b8a commit 0793550
Show file tree
Hide file tree
Showing 14 changed files with 346 additions and 337 deletions.
Binary file modified websocket/.DS_Store
Binary file not shown.
58 changes: 33 additions & 25 deletions websocket/app/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ struct FlowMessage {
pub struct WebSocketQuery {
token: String,
user_id: String,
user_email: String,
user_name: String,
tenant_id: String,
project_id: Option<String>,
}

pub async fn handle_upgrade(
Expand All @@ -49,12 +47,7 @@ pub async fn handle_upgrade(
) -> impl IntoResponse {
debug!("{:?}", query);

let user = User {
id: query.user_id.clone(),
email: query.user_email.clone(),
name: query.user_name.clone(),
tenant_id: query.tenant_id.clone(),
};
let user = User::new(query.user_id.clone(), None, None);

ws.on_upgrade(move |socket| {
handle_socket(
Expand All @@ -63,7 +56,7 @@ pub async fn handle_upgrade(
query.token.to_string(),
room_id,
state,
None,
query.project_id.clone(),
user,
)
})
Expand All @@ -78,21 +71,11 @@ async fn handle_socket(
project_id: Option<String>,
user: User,
) {
if socket.send(Message::Ping(vec![4])).await.is_ok() {
println!("pinned to {addr}");
} else {
println!("couldn't ping to {addr}");
return;
}

// TODO: authentication
if token != "nyaan" {
if !verify_connection(&mut socket, &addr, &token).await {
return;
}

debug!("{:?}", state.make_room(room_id.clone()));
if let Err(e) = state.join(&room_id, &user.id).await {
debug!("Failed to join room: {:?}", e);
if !initialize_room(&state, &room_id, &user).await {
return;
}

Expand Down Expand Up @@ -137,6 +120,31 @@ async fn handle_socket(
}
}

async fn verify_connection(socket: &mut WebSocket, addr: &SocketAddr, token: &str) -> bool {
if socket.send(Message::Ping(vec![4])).await.is_err() || token != "nyaan" {
debug!("Connection failed for {addr}: ping failed or invalid token");
return false;
}
true
}

async fn initialize_room(state: &Arc<AppState>, room_id: &str, user: &User) -> bool {
match state.make_room(room_id.to_string()) {
Ok(_) => debug!("Room created/exists: {}", room_id),
Err(e) => {
debug!("Failed to create room: {:?}", e);
return false;
}
}

if let Err(e) = state.join(room_id, &user.id).await {
debug!("Failed to join room: {:?}", e);
return false;
}

true
}

async fn handle_message(
msg: Message,
addr: SocketAddr,
Expand Down Expand Up @@ -181,10 +189,10 @@ async fn handle_message(
if let Some(project_id) = project_id {
state
.command_tx
.send(SessionCommand::PushUpdate {
.send(SessionCommand::MergeUpdates {
project_id,
update: d,
updated_by: Some(user.name.clone()),
data: d,
updated_by: Some(user.id.clone()),
})
.await?;
}
Expand Down
34 changes: 26 additions & 8 deletions websocket/app/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,25 @@ type SessionService = ManageEditSessionService<
FlowProjectRedisDataManager,
>;

const DEFAULT_REDIS_URL: &str = "redis://localhost:6379/0";
const CHANNEL_BUFFER_SIZE: usize = 32;
#[cfg(feature = "local-storage")]
const DEFAULT_LOCAL_STORAGE_PATH: &str = "./local_storage";

#[derive(Clone)]
pub struct AppState {
pub rooms: Arc<Mutex<HashMap<String, Room>>>,
pub redis_pool: Pool<RedisConnectionManager>,
pub storage: Arc<ProjectStorageRepository>,
pub session_repo: Arc<ProjectRedisRepository>,
pub service: Arc<SessionService>,
pub redis_url: String,
pub command_tx: mpsc::Sender<SessionCommand>,
}

impl AppState {
pub async fn new(redis_url: Option<String>) -> Result<Self, WsError> {
let redis_url = redis_url.unwrap_or_else(|| {
std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://localhost:6379/0".to_string())
std::env::var("REDIS_URL").unwrap_or_else(|_| DEFAULT_REDIS_URL.to_string())
});

// Initialize Redis connection pool
Expand All @@ -55,10 +59,18 @@ impl AppState {
// Initialize storage based on feature
#[cfg(feature = "local-storage")]
#[allow(unused_variables)]
let storage = Arc::new(ProjectStorageRepository::new("./local_storage".into()).await?);
let storage =
Arc::new(ProjectStorageRepository::new(DEFAULT_LOCAL_STORAGE_PATH.into()).await?);

#[cfg(feature = "gcs-storage")]
#[cfg(not(feature = "local-storage"))]
let gcs_bucket =
std::env::var("GCS_BUCKET_NAME").expect("GCS_BUCKET_NAME must be provided");

#[cfg(feature = "gcs-storage")]
#[cfg(not(feature = "local-storage"))]
#[allow(unused_variables)]
let storage = Arc::new(ProjectStorageRepository::new("your-gcs-bucket".into()).await?);
let storage = Arc::new(ProjectStorageRepository::new(gcs_bucket).await?);

let session_repo = Arc::new(ProjectRedisRepository::new(redis_pool.clone()));

Expand All @@ -70,12 +82,12 @@ impl AppState {
Arc::new(redis_data_manager),
));

let (tx, rx) = mpsc::channel(32);
let (tx, rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);

let service_clone = service.clone();
tokio::spawn(async move {
if let Err(e) = service_clone.process(rx).await {
error!("Service processing error: {:?}", e);
error!("Service processing error: {}", e);
}
});

Expand All @@ -85,18 +97,24 @@ impl AppState {
storage,
session_repo,
service,
redis_url,
command_tx: tx,
})
}

// Room related methods
/// Creates a new room with the given ID.
///
/// # Errors
/// Returns `TryLockError` if the rooms mutex is poisoned or locked.
pub fn make_room(&self, room_id: String) -> Result<(), tokio::sync::TryLockError> {
let mut rooms = self.rooms.try_lock()?;
rooms.insert(room_id, Room::new());
Ok(())
}

/// Deletes a room with the given ID.
///
/// # Errors
/// Returns `TryLockError` if the rooms mutex is poisoned or locked.
pub fn delete_room(&self, id: String) -> Result<(), tokio::sync::TryLockError> {
let mut rooms = self.rooms.try_lock()?;
rooms.remove(&id);
Expand Down
40 changes: 15 additions & 25 deletions websocket/crates/infra/src/persistence/editing_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl ProjectEditingSession {
if let Some(snapshot) = snapshot_repo.get_latest_snapshot(&self.project_id).await? {
debug!("Found existing snapshot for project: {}", self.project_id);
redis_manager
.push_update(&self.project_id, snapshot.data, Some(user.name.clone()))
.merge_updates(&self.project_id, snapshot.data, Some(user.id.clone()))
.await?;
} else {
debug!(
Expand All @@ -150,14 +150,16 @@ impl ProjectEditingSession {
pub async fn merge_updates<R>(
&self,
redis_data_manager: &R,
update_data: Vec<u8>,
updated_by: Option<String>,
) -> Result<(Vec<u8>, Vec<String>), ProjectEditingSessionError>
where
R: RedisDataManagerImpl<Error = FlowProjectRedisDataManagerError>,
{
self.check_session_setup()?;
let _guard = self.acquire_lock().await;
redis_data_manager
.merge_updates(&self.project_id, false)
.merge_updates(&self.project_id, update_data.clone(), updated_by)
.await
.map_err(Into::into)
}
Expand All @@ -172,7 +174,7 @@ impl ProjectEditingSession {
self.check_session_setup()?;

let current_state = redis_data_manager
.get_current_state(&self.project_id, self.session_id.as_deref())
.get_current_state(&self.project_id)
.await?;

match current_state {
Expand All @@ -181,23 +183,6 @@ impl ProjectEditingSession {
}
}

pub async fn push_update<R>(
&self,
update: Vec<u8>,
updated_by: String,
redis_data_manager: &R,
) -> Result<(), ProjectEditingSessionError>
where
R: RedisDataManagerImpl<Error = FlowProjectRedisDataManagerError>,
{
self.check_session_setup()?;
let _guard = self.acquire_lock().await;
redis_data_manager
.push_update(&self.project_id, update, Some(updated_by))
.await
.map_err(Into::into)
}

pub async fn create_snapshot<S>(
&self,
user: &User,
Expand All @@ -219,7 +204,7 @@ impl ProjectEditingSession {

let snapshot = ProjectSnapshot::builder()
.project_id(self.project_id.clone())
.created_by(user.name.clone())
.created_by(user.id.clone())
.data(data)
.snapshot_type(SnapshotType::Manual)
.version(version)
Expand Down Expand Up @@ -268,19 +253,24 @@ impl ProjectEditingSession {
{
let _guard = self.acquire_lock().await;

let (state, edits) = redis_data_manager
.merge_updates(&self.project_id, true)
let state = redis_data_manager
.get_current_state(&self.project_id)
.await?;
let edits = redis_data_manager
.get_state_updates_by(&self.project_id)
.await?;

debug!("state: {:?}", state);
debug!("edits: {:?}", edits);

if save_changes {
let snapshot = snapshot_repo.get_latest_snapshot(&self.project_id).await?;
debug!("snapshot: {:?}", snapshot);

if let Some(mut snapshot) = snapshot {
snapshot.data = state;
snapshot.info.changes_by = edits;
snapshot.data = state.unwrap_or_default();
snapshot.info.changes_by =
vec![edits.unwrap_or_else(|| "anonymous".to_string())];
snapshot.metadata.name = Some(snapshot_name);
snapshot_repo.update_latest_snapshot(snapshot).await?;
}
Expand Down
2 changes: 0 additions & 2 deletions websocket/crates/infra/src/persistence/project_repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ pub enum ProjectRepositoryError {
Io(#[from] io::Error),
#[error("Session ID not found")]
SessionIdNotFound,
#[error("{0}")]
Custom(String),
#[error(transparent)]
Redis(#[from] redis::RedisError),
#[error(transparent)]
Expand Down
Loading

0 comments on commit 0793550

Please sign in to comment.