From fca6c54d76180e2793f64ddc1794f4508bf6a4e8 Mon Sep 17 00:00:00 2001 From: Marc Planard Date: Mon, 7 Aug 2023 20:28:11 +0200 Subject: [PATCH] initial commit --- Cargo.toml | 18 ++++ src/main.rs | 278 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 296 insertions(+) create mode 100644 Cargo.toml create mode 100644 src/main.rs diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..80909de --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "lj_sketch" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +axum = { version = "0.6.20", features = ["ws", "headers"] } +futures = "0.3.28" +geo = "0.26.0" +serde = { version = "1.0.182", features = ["derive"] } +serde_json = "1.0.104" +tokio = { version = "1.29.1", features = ["macros", "rt-multi-thread"] } +tower = { version = "0.4", features = ["util"] } +tower-http = { version = "0.4.0", features = ["fs", "trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..c91643a --- /dev/null +++ b/src/main.rs @@ -0,0 +1,278 @@ +use axum::{ + extract::{ + ws::{ Message, Message::Text, Message::Close, + WebSocket, WebSocketUpgrade}, + TypedHeader, + }, + response::IntoResponse, + routing::get, + Router, + Extension +}; + +use std::{net::SocketAddr, path::PathBuf}; +use tower_http::{ + services::ServeDir, + trace::{DefaultMakeSpan, TraceLayer}, +}; + +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +//allows to extract the IP of connecting user +use axum::extract::{ + connect_info::ConnectInfo, + //ws::CloseFrame +}; + +//allows to split the websocket stream into separate TX and RX branches +use futures::sink::SinkExt; +use futures::stream::{SplitSink,StreamExt}; +use std::sync::Arc; +use tokio::sync::{ + mpsc:: { self, Sender, Receiver }, + Mutex +}; +use serde::{Serialize,Deserialize}; +use geo::Simplify; + +use std::collections::HashMap; + +const LISTEN_ON : &str = "0.0.0.0:3000"; + +#[derive(Serialize, Deserialize, Debug)] +#[serde(tag = "t")] +enum JMsg { + #[serde(rename = "clear")] + Clear, + #[serde(rename = "moveTo")] + MoveTo { x: f32, y: f32, color: String }, + #[serde(rename = "lineTo")] + LineTo { x: f32, y: f32, color: String }, + #[serde(rename = "stroke")] + Stroke, + #[serde(rename = "line")] + Line { line: Vec<(f32,f32,String)> } +} + +type Line = Vec<(f32,f32,u32)>; + +#[derive(Debug)] +enum GSMsg { + NewClient((SocketAddr,SplitSink)), + NewLine(Line), + DeleteClient(SocketAddr), + Clear +} + + +struct State { + gs_tx: Sender +} + +async fn gen_server(mut rx: Receiver) { + let mut clients : HashMap> = + HashMap::new(); + + let mut lines : Vec = vec![]; + + while let Some(msg) = rx.recv().await { + match msg { + GSMsg::NewClient((addr, mut tx)) => { + for line in &lines { + tx + .send(Message::Text(line_to_json(&line))) + .await + .unwrap(); + } + clients.insert(addr, tx); + tracing::info!("NewClient {addr}"); + }, + GSMsg::NewLine(line) => { + let msg = line_to_json(&line); + send_all(&mut clients, msg).await; + lines.push(line); + }, + GSMsg::DeleteClient(addr) => { + tracing::info!("Client {addr} removed"); + clients.remove(&addr); + }, + GSMsg::Clear => { + let msg = serde_json::to_string(&JMsg::Clear).unwrap(); + send_all(&mut clients, msg).await; + lines.clear(); + } + } + } +} + +async fn send_all( + clients: &mut HashMap>, + msg: String +) { + let mut to_remove : Vec = vec![]; + + for (addr, ref mut tx) in &mut *clients { + let ret = tx + .send(Message::Text(msg.clone())) + .await; + if ret.is_err() { + tracing::warn!("Client {addr} abruptly disconnected"); + to_remove.push(*addr); + } + } + + for addr in to_remove { + clients.remove(&addr); + } +} + +fn line_to_json(line: &Line) -> String { + let line = line.iter() + .map(| (x, y, c) | { + (*x, *y, format!("#{:06x}", c)) + }) + .collect(); + serde_json::to_string(&JMsg::Line{ line }).unwrap() +} + + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "lj_sketch=info,tower_http=info" + .into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets"); + + let (tx, rx) : (Sender, Receiver) = mpsc::channel(32); + + let state = Arc::new(Mutex::new(State { + gs_tx: tx + })); + + tokio::spawn(gen_server(rx)); + + // build our application with some routes + let app = Router::new() + .fallback_service(ServeDir::new(assets_dir) + .append_index_html_on_directories(true)) + .route("/ws", get(ws_handler)) + + .layer(Extension(state)) + // logging so we can see whats going on + .layer( + TraceLayer::new_for_http() + .make_span_with(DefaultMakeSpan::default() + .include_headers(false)), + ); + + + let addr : SocketAddr = LISTEN_ON.parse().unwrap(); + + tracing::info!("listening on {}", addr); + axum::Server::bind(&addr) + .serve(app.into_make_service_with_connect_info::()) + .await + .unwrap(); +} + +async fn ws_handler( + ws: WebSocketUpgrade, + Extension(state): Extension>>, + user_agent: Option>, + ConnectInfo(addr): ConnectInfo, +) -> impl IntoResponse { + let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { + user_agent.to_string() + } else { + String::from("Unknown browser") + }; + tracing::info!("`{user_agent}` at {addr} connected."); + // finalize the upgrade process by returning upgrade callback. + // we can customize the callback by sending additional info such as address. + ws.on_upgrade(move |socket| handle_socket(socket, addr, state)) +} + +async fn handle_socket( + socket: WebSocket, + who: SocketAddr, + state: Arc> +) { + let (tx, mut rx) = socket.split(); + { + let st = state.lock().await; + (*st).gs_tx.send(GSMsg::NewClient((who.clone(), tx))).await.unwrap(); + } + let mut line : Line = vec![]; + + while let Some(msg) = rx.next().await { + match msg { + Ok(Text(msg)) => { + let Ok(msg) : Result = serde_json::from_str(&msg) else { + tracing::warn!("{who}: Can't parse JSON: {:?}", msg); + continue; + }; + tracing::debug!("{who}: '{:?}'", msg); + match msg { + JMsg::Clear => { + let st = state.lock().await; + (*st).gs_tx.send(GSMsg::Clear) + .await.unwrap(); + line.clear(); + }, + JMsg::MoveTo { x, y, color } => { + line = vec![ (x, y, parse_color(color)) ]; + }, + JMsg::LineTo { x, y, color } => { + line.push( (x, y, parse_color(color)) ); + }, + JMsg::Stroke => { + if line.len() > 1 { + let line = simplify_line(&line); + + let st = state.lock().await; + (*st).gs_tx.send(GSMsg::NewLine(line)) + .await.unwrap(); + } + line = vec![]; + }, + JMsg::Line{..} => { panic!("recieved a line message :/"); } + } + }, + Ok(Close(close)) => { + tracing::info!("{who}: closing: {:?}", close); + let st = state.lock().await; + (*st).gs_tx.send(GSMsg::DeleteClient(who)) + .await.unwrap(); + break; + }, + _ => { + tracing::warn!("{who}: Can't handle message: {:?}", msg); + } + } + } +} + +fn simplify_line(line: &Line) -> Line { + if line.len() < 2 { + return line.to_vec(); + } + let color = line[0].2; + let linestring : geo::LineString = + line.iter() + .map(| (x, y, _) | (*x as f64, *y as f64 )) + .collect(); + let linestring = linestring.simplify(&4.0); + linestring.0.iter() + .map(| c | (c.x as f32, c.y as f32, color)) + .collect() +} + +fn parse_color(s: String) -> u32 { + u32::from_str_radix(&s[1..], 16).unwrap() +}