diff --git a/src/gen_server.rs b/src/gen_server.rs new file mode 100644 index 0000000..ff72696 --- /dev/null +++ b/src/gen_server.rs @@ -0,0 +1,69 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use tokio::sync::mpsc:: {Sender, Receiver}; +use crate::ws_client::Line; + +#[derive(Debug, Clone)] +pub enum GSMsg { + NewClient((SocketAddr, Sender)), + NewLine(Line), + DeleteClient(SocketAddr), + Clear +} + +pub struct State { + pub gs_tx: Sender +} + +pub async fn gen_server(mut rx: Receiver) { + let mut clients : HashMap> = + HashMap::new(); + + let mut lines : Vec = vec![]; + + while let Some(msg) = rx.recv().await { + match msg { + GSMsg::NewClient((addr, c_tx)) => { + for line in &lines { + c_tx.send(GSMsg::NewLine(line.clone())) + .await.unwrap(); + } + clients.insert(addr, c_tx); + tracing::info!("NewClient {addr}"); + }, + GSMsg::NewLine(line) => { + send_all(&mut clients, &GSMsg::NewLine(line.clone())).await; + lines.push(line); + }, + GSMsg::DeleteClient(addr) => { + tracing::info!("Client {addr} removed"); + clients.remove(&addr); + }, + GSMsg::Clear => { + send_all(&mut clients, &GSMsg::Clear).await; + lines.clear(); + } + } + } +} + +async fn send_all( + clients: &mut HashMap>, + msg: &GSMsg +) { + let mut to_remove : Vec = vec![]; + + for (addr, ref mut tx) in &mut *clients { + let ret = tx + .send(msg.clone()) + .await; + if ret.is_err() { + tracing::warn!("Client {addr} abruptly disconnected"); + to_remove.push(*addr); + } + } + + for addr in to_remove { + clients.remove(&addr); + } +} diff --git a/src/main.rs b/src/main.rs index c91643a..068f1c7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,9 @@ +mod gen_server; +mod ws_client; + use axum::{ extract::{ - ws::{ Message, Message::Text, Message::Close, - WebSocket, WebSocketUpgrade}, + ws::WebSocketUpgrade, TypedHeader, }, response::IntoResponse, @@ -9,133 +11,23 @@ use axum::{ Router, Extension }; - use std::{net::SocketAddr, path::PathBuf}; use tower_http::{ services::ServeDir, trace::{DefaultMakeSpan, TraceLayer}, }; - use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use axum::extract::connect_info::ConnectInfo; -//allows to extract the IP of connecting user -use axum::extract::{ - connect_info::ConnectInfo, - //ws::CloseFrame -}; - -//allows to split the websocket stream into separate TX and RX branches -use futures::sink::SinkExt; -use futures::stream::{SplitSink,StreamExt}; use std::sync::Arc; use tokio::sync::{ mpsc:: { self, Sender, Receiver }, Mutex }; -use serde::{Serialize,Deserialize}; -use geo::Simplify; - -use std::collections::HashMap; +use gen_server::{State,GSMsg,gen_server}; const LISTEN_ON : &str = "0.0.0.0:3000"; -#[derive(Serialize, Deserialize, Debug)] -#[serde(tag = "t")] -enum JMsg { - #[serde(rename = "clear")] - Clear, - #[serde(rename = "moveTo")] - MoveTo { x: f32, y: f32, color: String }, - #[serde(rename = "lineTo")] - LineTo { x: f32, y: f32, color: String }, - #[serde(rename = "stroke")] - Stroke, - #[serde(rename = "line")] - Line { line: Vec<(f32,f32,String)> } -} - -type Line = Vec<(f32,f32,u32)>; - -#[derive(Debug)] -enum GSMsg { - NewClient((SocketAddr,SplitSink)), - NewLine(Line), - DeleteClient(SocketAddr), - Clear -} - - -struct State { - gs_tx: Sender -} - -async fn gen_server(mut rx: Receiver) { - let mut clients : HashMap> = - HashMap::new(); - - let mut lines : Vec = vec![]; - - while let Some(msg) = rx.recv().await { - match msg { - GSMsg::NewClient((addr, mut tx)) => { - for line in &lines { - tx - .send(Message::Text(line_to_json(&line))) - .await - .unwrap(); - } - clients.insert(addr, tx); - tracing::info!("NewClient {addr}"); - }, - GSMsg::NewLine(line) => { - let msg = line_to_json(&line); - send_all(&mut clients, msg).await; - lines.push(line); - }, - GSMsg::DeleteClient(addr) => { - tracing::info!("Client {addr} removed"); - clients.remove(&addr); - }, - GSMsg::Clear => { - let msg = serde_json::to_string(&JMsg::Clear).unwrap(); - send_all(&mut clients, msg).await; - lines.clear(); - } - } - } -} - -async fn send_all( - clients: &mut HashMap>, - msg: String -) { - let mut to_remove : Vec = vec![]; - - for (addr, ref mut tx) in &mut *clients { - let ret = tx - .send(Message::Text(msg.clone())) - .await; - if ret.is_err() { - tracing::warn!("Client {addr} abruptly disconnected"); - to_remove.push(*addr); - } - } - - for addr in to_remove { - clients.remove(&addr); - } -} - -fn line_to_json(line: &Line) -> String { - let line = line.iter() - .map(| (x, y, c) | { - (*x, *y, format!("#{:06x}", c)) - }) - .collect(); - serde_json::to_string(&JMsg::Line{ line }).unwrap() -} - - #[tokio::main] async fn main() { tracing_subscriber::registry() @@ -149,29 +41,23 @@ async fn main() { let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets"); - let (tx, rx) : (Sender, Receiver) = mpsc::channel(32); + let (gs_tx, gs_rx) : (Sender, Receiver) = mpsc::channel(32); - let state = Arc::new(Mutex::new(State { - gs_tx: tx - })); + let state = Arc::new(Mutex::new(State { gs_tx })); - tokio::spawn(gen_server(rx)); + tokio::spawn(gen_server(gs_rx)); - // build our application with some routes let app = Router::new() .fallback_service(ServeDir::new(assets_dir) .append_index_html_on_directories(true)) - .route("/ws", get(ws_handler)) - + .route("/ws", get(ws_handler)) .layer(Extension(state)) - // logging so we can see whats going on .layer( TraceLayer::new_for_http() .make_span_with(DefaultMakeSpan::default() .include_headers(false)), ); - - + let addr : SocketAddr = LISTEN_ON.parse().unwrap(); tracing::info!("listening on {}", addr); @@ -195,84 +81,6 @@ async fn ws_handler( tracing::info!("`{user_agent}` at {addr} connected."); // finalize the upgrade process by returning upgrade callback. // we can customize the callback by sending additional info such as address. - ws.on_upgrade(move |socket| handle_socket(socket, addr, state)) + ws.on_upgrade(move |socket| ws_client::handle_socket(socket, addr, state)) } -async fn handle_socket( - socket: WebSocket, - who: SocketAddr, - state: Arc> -) { - let (tx, mut rx) = socket.split(); - { - let st = state.lock().await; - (*st).gs_tx.send(GSMsg::NewClient((who.clone(), tx))).await.unwrap(); - } - let mut line : Line = vec![]; - - while let Some(msg) = rx.next().await { - match msg { - Ok(Text(msg)) => { - let Ok(msg) : Result = serde_json::from_str(&msg) else { - tracing::warn!("{who}: Can't parse JSON: {:?}", msg); - continue; - }; - tracing::debug!("{who}: '{:?}'", msg); - match msg { - JMsg::Clear => { - let st = state.lock().await; - (*st).gs_tx.send(GSMsg::Clear) - .await.unwrap(); - line.clear(); - }, - JMsg::MoveTo { x, y, color } => { - line = vec![ (x, y, parse_color(color)) ]; - }, - JMsg::LineTo { x, y, color } => { - line.push( (x, y, parse_color(color)) ); - }, - JMsg::Stroke => { - if line.len() > 1 { - let line = simplify_line(&line); - - let st = state.lock().await; - (*st).gs_tx.send(GSMsg::NewLine(line)) - .await.unwrap(); - } - line = vec![]; - }, - JMsg::Line{..} => { panic!("recieved a line message :/"); } - } - }, - Ok(Close(close)) => { - tracing::info!("{who}: closing: {:?}", close); - let st = state.lock().await; - (*st).gs_tx.send(GSMsg::DeleteClient(who)) - .await.unwrap(); - break; - }, - _ => { - tracing::warn!("{who}: Can't handle message: {:?}", msg); - } - } - } -} - -fn simplify_line(line: &Line) -> Line { - if line.len() < 2 { - return line.to_vec(); - } - let color = line[0].2; - let linestring : geo::LineString = - line.iter() - .map(| (x, y, _) | (*x as f64, *y as f64 )) - .collect(); - let linestring = linestring.simplify(&4.0); - linestring.0.iter() - .map(| c | (c.x as f32, c.y as f32, color)) - .collect() -} - -fn parse_color(s: String) -> u32 { - u32::from_str_radix(&s[1..], 16).unwrap() -} diff --git a/src/ws_client.rs b/src/ws_client.rs new file mode 100644 index 0000000..fb663fb --- /dev/null +++ b/src/ws_client.rs @@ -0,0 +1,161 @@ +use crate::gen_server::{State,GSMsg}; + +use axum::extract::ws::{ Message, Message::Text, Message::Close, WebSocket }; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::{ + mpsc:: { self, Sender, Receiver }, + Mutex +}; +use serde::{Serialize,Deserialize}; +use geo::Simplify; + +use core::ops::ControlFlow; + +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "t")] +pub enum JMsg { + #[serde(rename = "clear")] + Clear, + #[serde(rename = "moveTo")] + MoveTo { x: f32, y: f32, color: String }, + #[serde(rename = "lineTo")] + LineTo { x: f32, y: f32, color: String }, + #[serde(rename = "stroke")] + Stroke, + #[serde(rename = "line")] + Line { line: Vec<(f32,f32,String)> } +} + +pub type Line = Vec<(f32,f32,u32)>; + +pub async fn handle_socket( + mut socket: WebSocket, + who: SocketAddr, + state: Arc> +) { + + let (c_tx, mut c_rx) : (Sender, Receiver) = mpsc::channel(32); + + { + state.lock() + .await + .gs_tx.send(GSMsg::NewClient((who, c_tx))) + .await.unwrap(); + } + let mut line : Line = vec![]; + + loop { + tokio::select! { + Some(msg) = socket.recv() => { + match process_ws_msg(&state, &who, &mut line, msg).await { + ControlFlow::Break(()) => { return; }, + ControlFlow::Continue(()) => {} + } + }, + Some(msg) = c_rx.recv() => { + match msg { + GSMsg::NewLine(line) => { + socket.send(Message::Text(line_to_json(&line))) + .await.unwrap(); + }, + GSMsg::Clear => { + let msg = serde_json::to_string(&JMsg::Clear).unwrap(); + socket.send(Message::Text(msg)) + .await.unwrap(); + }, + msg => { + tracing::info!("{who} should not get this: {:?}", msg) + } + } + }, + else => { + tracing::warn!("{who}: Connection lost unexpectedly."); + return; + } + } + } +} + +async fn process_ws_msg( + state: &Arc>, + who: &SocketAddr, + line: &mut Line, + msg: Result +) -> ControlFlow<(),()> { + match msg { + Ok(Text(msg)) => { + let Ok(msg) : Result = serde_json::from_str(&msg) else { + tracing::warn!("{who}: Can't parse JSON: {:?}", msg); + return ControlFlow::Continue(()); + }; + tracing::debug!("{who}: '{:?}'", msg); + match msg { + JMsg::Clear => { + state.lock() + .await + .gs_tx.send(GSMsg::Clear) + .await.unwrap(); + line.clear(); + }, + JMsg::MoveTo { x, y, color } => { + *line = vec![ (x, y, parse_color(color)) ]; + }, + JMsg::LineTo { x, y, color } => { + line.push( (x, y, parse_color(color)) ); + }, + JMsg::Stroke => { + if line.len() > 1 { + state.lock() + .await + .gs_tx.send(GSMsg::NewLine(simplify_line(line))) + .await.unwrap(); + } + *line = vec![]; + }, + JMsg::Line{..} => { panic!("recieved a line message :/"); } + } + }, + Ok(Close(close)) => { + tracing::info!("{who}: closing: {:?}", close); + state.lock() + .await + .gs_tx.send(GSMsg::DeleteClient(*who)) + .await.unwrap(); + return ControlFlow::Break(()); + }, + _ => { + tracing::warn!("{who}: Can't handle message: {:?}", msg); + } + } + ControlFlow::Continue(()) +} + +fn simplify_line(line: &Line) -> Line { + if line.len() < 2 { + return line.to_vec(); + } + let color = line[0].2; + let linestring : geo::LineString = + line.iter() + .map(| (x, y, _) | (*x as f64, *y as f64 )) + .collect(); + let linestring = linestring.simplify(&4.0); + linestring.0.iter() + .map(| c | (c.x as f32, c.y as f32, color)) + .collect() +} + + +fn line_to_json(line: &Line) -> String { + let line = line.iter() + .map(| (x, y, c) | { + (*x, *y, format!("#{:06x}", c)) + }) + .collect(); + serde_json::to_string(&JMsg::Line{ line }).unwrap() +} + +fn parse_color(s: String) -> u32 { + u32::from_str_radix(&s[1..], 16).unwrap() +}