From 819bd08bc3ad87141d1a45f01e09bdd4a6d0c5b4 Mon Sep 17 00:00:00 2001 From: Pakin Date: Fri, 24 Apr 2026 17:11:36 +0700 Subject: [PATCH] feat: add uid field check, kick old ws connection Signed-off-by: Pakin --- src/main.rs | 202 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 158 insertions(+), 44 deletions(-) diff --git a/src/main.rs b/src/main.rs index f871596..feee590 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,11 @@ use std::{ - collections::VecDeque, env, fs::File, io::Read, path::PathBuf, sync::Arc, time::Duration, + collections::{HashMap, VecDeque}, + env, + fs::File, + io::Read, + path::PathBuf, + sync::Arc, + time::Duration, }; use async_compression::tokio::bufread::BrotliDecoder; @@ -9,6 +15,7 @@ use axum::{ State, WebSocketUpgrade, ws::{CloseFrame, Message, WebSocket}, }, + http::HeaderMap, response::IntoResponse, routing::{get, post}, serve::ListenerExt, @@ -135,11 +142,7 @@ async fn invoke_checkout_request( } pub async fn create_recipe_repo_router() -> Router> { - Router::new() - // .route("/", get(get_root_files)) - .route("/ws", get(websocket_handler)) - // .route("/edit", post()) - // .route("/{country}/", method_router) + Router::new().route("/ws", get(websocket_handler)) } #[derive(Debug, Serialize, Deserialize)] @@ -149,6 +152,15 @@ struct SysMessage { payload: serde_json::Value, } +/// Request to generate api rotated key +#[derive(Debug, Serialize, Deserialize)] +struct ApiSessionRequest { + // uid from user login + uid: serde_json::Value, + // + timestamp: serde_json::Value, +} + async fn post_from_other_system( State(mut state): State>, Json(msg): Json, @@ -169,22 +181,72 @@ async fn post_from_other_system( } } +async fn request_api_session_key( + State(mut state): State>, + Json(msg): Json, +) -> impl IntoResponse { + let mut rcl = state.redis_cli.clone(); + + // gen key + let generated = uuid::Uuid::new_v4().to_string().replace("-", ""); + + if let Some(uid_n) = msg.uid.as_str() { + let saved_key = format!("{uid_n}-ak"); + let _ = rcl.set_ex(saved_key, generated.clone(), 3600); + } + + (axum::http::StatusCode::OK, generated).into_response() +} + async fn websocket_handler( State(mut state): State>, ws: WebSocketUpgrade, + headers: HeaderMap, ) -> impl IntoResponse { let state_clone = Arc::clone(&state); + let hub_clone = Arc::clone(&state_clone.connectors_mapping); + + let mut uid_n = String::new(); + + if let Some(uid) = headers.get("x-auth-uid") { + let uid_from_web = uid.to_str().unwrap_or_default().to_string(); + + uid_n = uid_from_web; + info!("user connect {uid_n}"); + } + + if uid_n.is_empty() { + return (axum::http::StatusCode::BAD_REQUEST, "").into_response(); + } + ws.on_failed_upgrade(|error| println!("Error upgrading websocket: {}", error)) - .on_upgrade(async |s| handle_socket(s, state_clone).await.unwrap_or(())) + .on_upgrade(async |s| { + handle_socket(s, state_clone, uid_n, hub_clone) + .await + .unwrap_or(()) + }) } async fn handle_socket( mut socket: WebSocket, mut state: Arc, + uid: String, + hub: Arc>, ) -> Result<(), Box> { let (mut sender, mut receiver) = socket.split(); // internal channel - let (tx, mut rx) = mpsc::channel::(2); + let (tx, mut rx) = mpsc::channel::(2); + + { + // register & kick old tx + + let mut h = hub.try_lock().unwrap(); + if let Some(old_tx) = h.clients.insert(uid, tx.clone()) { + warn!("disconnect old connection"); + let _ = old_tx.send(TxControlMessage::CloseExist); + } + } + let user_sys_rx = state.system_tx.subscribe(); let last_seen = Arc::new(Mutex::new(Instant::now())); @@ -208,9 +270,9 @@ async fn handle_socket( if last.elapsed() > TIMEOUT { warn!("Timeout close connection"); let _ = tx - .send(serde_json::json!({ + .send(TxControlMessage::Payload(serde_json::json!({ "timeout": "watchdog" - })) + }))) .await; break; } @@ -331,7 +393,7 @@ async fn read( // redis: redis::Client, mut state: Arc, mut receiver: SplitStream, - tx: Sender, + tx: Sender, mut system_rx: tokio::sync::broadcast::Receiver, last_seen: Arc>, // cmd_atom: crossbeam_queue::ArrayQueue, ) -> Result<(), Box> { @@ -343,7 +405,10 @@ async fn read( // Send back to client from services while let Ok(s_msg) = system_rx.recv().await { if convert_sys_msg_command(&s_msg).is_some() - && let Some(err) = tx_to_client.send(s_msg).await.err() + && let Some(err) = tx_to_client + .send(TxControlMessage::Payload(s_msg)) + .await + .err() { println!("[SYS] failed to send back to client: {err}"); } @@ -355,6 +420,7 @@ async fn read( Message::Text(t) => { let req: WebsocketMessageRequest = serde_json::from_str(t.as_str())?; let req_clone = req.clone(); + info!("get msg: {}", req.type_w); match req.type_w.as_str() { "recipe" if req.payload.is_some() => { // guard expect value @@ -438,7 +504,7 @@ async fn read( error!("File corrupted, invalid json format"); } - let _ = tx.send(serde_json::json!({ + let _ = tx.send(TxControlMessage::Payload(serde_json::json!({ "type": "notify", "payload": { "from": "system_tx", @@ -446,7 +512,7 @@ async fn read( "msg": format!("Some requested file on cache is corrupt, {} version {}", recipe_param.country, latest_version), "to": "" } - })).await; + }))).await; return Err(e.into()); } @@ -549,7 +615,7 @@ async fn read( if let Some(_) = state.system_tx.send(p).err() { info!("failed to send command request"); let _ = tx - .send(serde_json::json!({ + .send(TxControlMessage::Payload(serde_json::json!({ "type": "notify", "payload": { "from": "system_tx", @@ -557,7 +623,7 @@ async fn read( "msg": "send request fail", "to": "" } - })) + }))) .await; } } @@ -652,31 +718,46 @@ async fn read( async fn write( mut sender: SplitSink, - mut rx: Receiver, + mut rx: Receiver, ) -> Result<(), Box> { while let Some(res) = rx.recv().await { - // force close - if let Some(force_timeout_by) = res.get("timeout") - && let Some(from_who) = force_timeout_by.as_str() - && from_who.eq("watchdog") - { - warn!("receive close from watchdog"); - let _ = sender.send(Message::Close(None)).await; - break; + match res { + TxControlMessage::Payload(res) => { + // force close + if let Some(force_timeout_by) = res.get("timeout") + && let Some(from_who) = force_timeout_by.as_str() + && from_who.eq("watchdog") + { + warn!("receive close from watchdog"); + let _ = sender.send(Message::Close(None)).await; + break; + } + + // if let Some(res_n) = res.as_object() + // && let Some(res_payload) = res_n.get("payload") + // && let Some(res_payload_val) = res_payload.as_object() + // && let Some(recv_ident) = res_payload_val.get("to") + // && let Some(recv_ident_str) = recv_ident.as_str() + // {} + + let payload_size = res.to_string().len(); + + if payload_size >= 100000 { + // large payload + warn!( + "sending large payload to client ... ({})", + res.to_string().len() + ); + } + + let _ = sender.send(res.to_string().into()).await; + } + TxControlMessage::CloseExist => { + let _ = sender.close().await; + break; + } } - let payload_size = res.to_string().len(); - - if payload_size >= 100000 { - // large payload - warn!( - "sending large payload to client ... ({})", - res.to_string().len() - ); - } - - let _ = sender.send(res.to_string().into()).await; - // リミットブレく - limito breaku!! (uncomment to slow down messages) // let _ = tokio::time::sleep(Duration::from_millis(100)).await; } @@ -721,7 +802,7 @@ fn get_key_cache(country: String, version: String, is_patch: bool, retry_cnt: i3 async fn throttle_send_recipe( recipe: &Recipe, - tx: &Sender, + tx: &Sender, country: String, version: String, ) { @@ -748,7 +829,7 @@ async fn throttle_send_recipe( let sid = ss.get_id(); info!("starting {sid}"); - if let Some(err) = tx.send(ss.as_msg()).await.err() { + if let Some(err) = tx.send(TxControlMessage::Payload(ss.as_msg())).await.err() { println!("ERR: send tx error, {err:?}"); } @@ -758,7 +839,7 @@ async fn throttle_send_recipe( let sda = StreamDataChunk::new(&sid, index * CHUNK_SIZE, chunk.to_vec()); // no validate - if let Some(err) = tx.send(sda.as_msg()).await.err() { + if let Some(err) = tx.send(TxControlMessage::Payload(sda.as_msg())).await.err() { println!("ERR: send tx error, {err:?}"); } } @@ -770,7 +851,11 @@ async fn throttle_send_recipe( let extra_matset = StreamDataExtra::new(&curr_ch_id, &extp, chunk.to_vec()); - if let Some(err) = tx.send(extra_matset.as_msg()).await.err() { + if let Some(err) = tx + .send(TxControlMessage::Payload(extra_matset.as_msg())) + .await + .err() + { println!("ERR: send tx extra error: {err:?}"); } } @@ -779,7 +864,11 @@ async fn throttle_send_recipe( for (index, chunk) in recipe.Topping.ToppingList.chunks(CHUNK_SIZE).enumerate() { let curr_ch_id = format!("{mat_exid}_tl{index}"); let extra_topplist = StreamDataExtra::new(&curr_ch_id, &extl, chunk.to_vec()); - if let Some(err) = tx.send(extra_topplist.as_msg()).await.err() { + if let Some(err) = tx + .send(TxControlMessage::Payload(extra_topplist.as_msg())) + .await + .err() + { println!("ERR: send tx extra2 error: {err:?}"); } } @@ -788,7 +877,11 @@ async fn throttle_send_recipe( for (index, chunk) in recipe.Topping.ToppingGroup.chunks(CHUNK_SIZE).enumerate() { let curr_ch_id = format!("{mat_exid}_tg{index}"); let extra_toppgrp = StreamDataExtra::new(&curr_ch_id, &extg, chunk.to_vec()); - if let Some(err) = tx.send(extra_toppgrp.as_msg()).await.err() { + if let Some(err) = tx + .send(TxControlMessage::Payload(extra_toppgrp.as_msg())) + .await + .err() + { println!("ERR: send tx extra2 error: {err:?}"); } } @@ -803,15 +896,32 @@ async fn throttle_send_recipe( // return sid; let end_msg = StreamDataEnd::new(&sid); - if let Some(err) = tx.send(end_msg.as_msg()).await.err() { + if let Some(err) = tx + .send(TxControlMessage::Payload(end_msg.as_msg())) + .await + .err() + { println!("ERR: send tx error, {err:?}"); } } +#[derive(Clone)] +pub struct Hub { + pub clients: HashMap>, +} + +#[derive(Clone)] +pub enum TxControlMessage { + Payload(serde_json::Value), + CloseExist, +} + pub struct AppState { dev_config: DevConfig, redis_cli: redis::Client, system_tx: tokio::sync::broadcast::Sender, + // saved client uid:client uuid + connectors_mapping: Arc>, } impl AppState { @@ -831,6 +941,9 @@ impl AppState { dev_config, redis_cli, system_tx, + connectors_mapping: Arc::new(Mutex::new(Hub { + clients: HashMap::new(), + })), }); tokio::spawn(async move { @@ -988,6 +1101,7 @@ async fn main() -> Result<(), Box> { let app = Router::new() // .route("/sessionLogin", post(session_login)) .route("/syscb", post(post_from_other_system)) + .route("/regas", post(request_api_session_key)) .nest("/recipe", rp_router) .nest("/docs", doc_router) .with_state(app_state);