feat: add uid field check, kick old ws connection

Signed-off-by: Pakin <pakin.t@forth.co.th>
This commit is contained in:
Pakin 2026-04-24 17:11:36 +07:00
parent fecdb94841
commit 819bd08bc3

View file

@ -1,5 +1,11 @@
use std::{ use std::{
collections::VecDeque, env, fs::File, io::Read, path::PathBuf, sync::Arc, time::Duration, collections::{HashMap, VecDeque},
env,
fs::File,
io::Read,
path::PathBuf,
sync::Arc,
time::Duration,
}; };
use async_compression::tokio::bufread::BrotliDecoder; use async_compression::tokio::bufread::BrotliDecoder;
@ -9,6 +15,7 @@ use axum::{
State, WebSocketUpgrade, State, WebSocketUpgrade,
ws::{CloseFrame, Message, WebSocket}, ws::{CloseFrame, Message, WebSocket},
}, },
http::HeaderMap,
response::IntoResponse, response::IntoResponse,
routing::{get, post}, routing::{get, post},
serve::ListenerExt, serve::ListenerExt,
@ -135,11 +142,7 @@ async fn invoke_checkout_request(
} }
pub async fn create_recipe_repo_router() -> Router<Arc<AppState>> { pub async fn create_recipe_repo_router() -> Router<Arc<AppState>> {
Router::new() Router::new().route("/ws", get(websocket_handler))
// .route("/", get(get_root_files))
.route("/ws", get(websocket_handler))
// .route("/edit", post())
// .route("/{country}/", method_router)
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -149,6 +152,15 @@ struct SysMessage {
payload: serde_json::Value, payload: serde_json::Value,
} }
/// Request body for `POST /regas` — asks the server to mint a rotated
/// API session key for a logged-in user (handled by
/// `request_api_session_key`).
#[derive(Debug, Serialize, Deserialize)]
struct ApiSessionRequest {
    // uid from user login; kept as a free-form `serde_json::Value` so a
    // non-string JSON uid doesn't fail deserialization — the handler
    // extracts it via `.as_str()` and ignores non-string values
    uid: serde_json::Value,
    // client-supplied request timestamp; not read by the handler in this
    // view — format/units unspecified (NOTE(review): confirm with callers)
    timestamp: serde_json::Value,
}
async fn post_from_other_system( async fn post_from_other_system(
State(mut state): State<Arc<AppState>>, State(mut state): State<Arc<AppState>>,
Json(msg): Json<SysMessage>, Json(msg): Json<SysMessage>,
@ -169,22 +181,72 @@ async fn post_from_other_system(
} }
} }
async fn request_api_session_key(
State(mut state): State<Arc<AppState>>,
Json(msg): Json<ApiSessionRequest>,
) -> impl IntoResponse {
let mut rcl = state.redis_cli.clone();
// gen key
let generated = uuid::Uuid::new_v4().to_string().replace("-", "");
if let Some(uid_n) = msg.uid.as_str() {
let saved_key = format!("{uid_n}-ak");
let _ = rcl.set_ex(saved_key, generated.clone(), 3600);
}
(axum::http::StatusCode::OK, generated).into_response()
}
async fn websocket_handler( async fn websocket_handler(
State(mut state): State<Arc<AppState>>, State(mut state): State<Arc<AppState>>,
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
headers: HeaderMap,
) -> impl IntoResponse { ) -> impl IntoResponse {
let state_clone = Arc::clone(&state); let state_clone = Arc::clone(&state);
let hub_clone = Arc::clone(&state_clone.connectors_mapping);
let mut uid_n = String::new();
if let Some(uid) = headers.get("x-auth-uid") {
let uid_from_web = uid.to_str().unwrap_or_default().to_string();
uid_n = uid_from_web;
info!("user connect {uid_n}");
}
if uid_n.is_empty() {
return (axum::http::StatusCode::BAD_REQUEST, "").into_response();
}
ws.on_failed_upgrade(|error| println!("Error upgrading websocket: {}", error)) ws.on_failed_upgrade(|error| println!("Error upgrading websocket: {}", error))
.on_upgrade(async |s| handle_socket(s, state_clone).await.unwrap_or(())) .on_upgrade(async |s| {
handle_socket(s, state_clone, uid_n, hub_clone)
.await
.unwrap_or(())
})
} }
async fn handle_socket( async fn handle_socket(
mut socket: WebSocket, mut socket: WebSocket,
mut state: Arc<AppState>, mut state: Arc<AppState>,
uid: String,
hub: Arc<Mutex<Hub>>,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
let (mut sender, mut receiver) = socket.split(); let (mut sender, mut receiver) = socket.split();
// internal channel // internal channel
let (tx, mut rx) = mpsc::channel::<serde_json::Value>(2); let (tx, mut rx) = mpsc::channel::<TxControlMessage>(2);
{
// register & kick old tx
let mut h = hub.try_lock().unwrap();
if let Some(old_tx) = h.clients.insert(uid, tx.clone()) {
warn!("disconnect old connection");
let _ = old_tx.send(TxControlMessage::CloseExist);
}
}
let user_sys_rx = state.system_tx.subscribe(); let user_sys_rx = state.system_tx.subscribe();
let last_seen = Arc::new(Mutex::new(Instant::now())); let last_seen = Arc::new(Mutex::new(Instant::now()));
@ -208,9 +270,9 @@ async fn handle_socket(
if last.elapsed() > TIMEOUT { if last.elapsed() > TIMEOUT {
warn!("Timeout close connection"); warn!("Timeout close connection");
let _ = tx let _ = tx
.send(serde_json::json!({ .send(TxControlMessage::Payload(serde_json::json!({
"timeout": "watchdog" "timeout": "watchdog"
})) })))
.await; .await;
break; break;
} }
@ -331,7 +393,7 @@ async fn read(
// redis: redis::Client, // redis: redis::Client,
mut state: Arc<AppState>, mut state: Arc<AppState>,
mut receiver: SplitStream<WebSocket>, mut receiver: SplitStream<WebSocket>,
tx: Sender<serde_json::Value>, tx: Sender<TxControlMessage>,
mut system_rx: tokio::sync::broadcast::Receiver<serde_json::Value>, mut system_rx: tokio::sync::broadcast::Receiver<serde_json::Value>,
last_seen: Arc<Mutex<Instant>>, // cmd_atom: crossbeam_queue::ArrayQueue<CommandRequestPayload>, last_seen: Arc<Mutex<Instant>>, // cmd_atom: crossbeam_queue::ArrayQueue<CommandRequestPayload>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
@ -343,7 +405,10 @@ async fn read(
// Send back to client from services // Send back to client from services
while let Ok(s_msg) = system_rx.recv().await { while let Ok(s_msg) = system_rx.recv().await {
if convert_sys_msg_command(&s_msg).is_some() if convert_sys_msg_command(&s_msg).is_some()
&& let Some(err) = tx_to_client.send(s_msg).await.err() && let Some(err) = tx_to_client
.send(TxControlMessage::Payload(s_msg))
.await
.err()
{ {
println!("[SYS] failed to send back to client: {err}"); println!("[SYS] failed to send back to client: {err}");
} }
@ -355,6 +420,7 @@ async fn read(
Message::Text(t) => { Message::Text(t) => {
let req: WebsocketMessageRequest = serde_json::from_str(t.as_str())?; let req: WebsocketMessageRequest = serde_json::from_str(t.as_str())?;
let req_clone = req.clone(); let req_clone = req.clone();
info!("get msg: {}", req.type_w);
match req.type_w.as_str() { match req.type_w.as_str() {
"recipe" if req.payload.is_some() => { "recipe" if req.payload.is_some() => {
// guard expect value // guard expect value
@ -438,7 +504,7 @@ async fn read(
error!("File corrupted, invalid json format"); error!("File corrupted, invalid json format");
} }
let _ = tx.send(serde_json::json!({ let _ = tx.send(TxControlMessage::Payload(serde_json::json!({
"type": "notify", "type": "notify",
"payload": { "payload": {
"from": "system_tx", "from": "system_tx",
@ -446,7 +512,7 @@ async fn read(
"msg": format!("Some requested file on cache is corrupt, {} version {}", recipe_param.country, latest_version), "msg": format!("Some requested file on cache is corrupt, {} version {}", recipe_param.country, latest_version),
"to": "" "to": ""
} }
})).await; }))).await;
return Err(e.into()); return Err(e.into());
} }
@ -549,7 +615,7 @@ async fn read(
if let Some(_) = state.system_tx.send(p).err() { if let Some(_) = state.system_tx.send(p).err() {
info!("failed to send command request"); info!("failed to send command request");
let _ = tx let _ = tx
.send(serde_json::json!({ .send(TxControlMessage::Payload(serde_json::json!({
"type": "notify", "type": "notify",
"payload": { "payload": {
"from": "system_tx", "from": "system_tx",
@ -557,7 +623,7 @@ async fn read(
"msg": "send request fail", "msg": "send request fail",
"to": "" "to": ""
} }
})) })))
.await; .await;
} }
} }
@ -652,9 +718,11 @@ async fn read(
async fn write( async fn write(
mut sender: SplitSink<WebSocket, Message>, mut sender: SplitSink<WebSocket, Message>,
mut rx: Receiver<serde_json::Value>, mut rx: Receiver<TxControlMessage>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
while let Some(res) = rx.recv().await { while let Some(res) = rx.recv().await {
match res {
TxControlMessage::Payload(res) => {
// force close // force close
if let Some(force_timeout_by) = res.get("timeout") if let Some(force_timeout_by) = res.get("timeout")
&& let Some(from_who) = force_timeout_by.as_str() && let Some(from_who) = force_timeout_by.as_str()
@ -665,6 +733,13 @@ async fn write(
break; break;
} }
// if let Some(res_n) = res.as_object()
// && let Some(res_payload) = res_n.get("payload")
// && let Some(res_payload_val) = res_payload.as_object()
// && let Some(recv_ident) = res_payload_val.get("to")
// && let Some(recv_ident_str) = recv_ident.as_str()
// {}
let payload_size = res.to_string().len(); let payload_size = res.to_string().len();
if payload_size >= 100000 { if payload_size >= 100000 {
@ -676,6 +751,12 @@ async fn write(
} }
let _ = sender.send(res.to_string().into()).await; let _ = sender.send(res.to_string().into()).await;
}
TxControlMessage::CloseExist => {
let _ = sender.close().await;
break;
}
}
// リミットブレく - limito breaku!! (uncomment to slow down messages) // リミットブレく - limito breaku!! (uncomment to slow down messages)
// let _ = tokio::time::sleep(Duration::from_millis(100)).await; // let _ = tokio::time::sleep(Duration::from_millis(100)).await;
@ -721,7 +802,7 @@ fn get_key_cache(country: String, version: String, is_patch: bool, retry_cnt: i3
async fn throttle_send_recipe( async fn throttle_send_recipe(
recipe: &Recipe, recipe: &Recipe,
tx: &Sender<serde_json::Value>, tx: &Sender<TxControlMessage>,
country: String, country: String,
version: String, version: String,
) { ) {
@ -748,7 +829,7 @@ async fn throttle_send_recipe(
let sid = ss.get_id(); let sid = ss.get_id();
info!("starting {sid}"); info!("starting {sid}");
if let Some(err) = tx.send(ss.as_msg()).await.err() { if let Some(err) = tx.send(TxControlMessage::Payload(ss.as_msg())).await.err() {
println!("ERR: send tx error, {err:?}"); println!("ERR: send tx error, {err:?}");
} }
@ -758,7 +839,7 @@ async fn throttle_send_recipe(
let sda = StreamDataChunk::new(&sid, index * CHUNK_SIZE, chunk.to_vec()); let sda = StreamDataChunk::new(&sid, index * CHUNK_SIZE, chunk.to_vec());
// no validate // no validate
if let Some(err) = tx.send(sda.as_msg()).await.err() { if let Some(err) = tx.send(TxControlMessage::Payload(sda.as_msg())).await.err() {
println!("ERR: send tx error, {err:?}"); println!("ERR: send tx error, {err:?}");
} }
} }
@ -770,7 +851,11 @@ async fn throttle_send_recipe(
let extra_matset = StreamDataExtra::new(&curr_ch_id, &extp, chunk.to_vec()); let extra_matset = StreamDataExtra::new(&curr_ch_id, &extp, chunk.to_vec());
if let Some(err) = tx.send(extra_matset.as_msg()).await.err() { if let Some(err) = tx
.send(TxControlMessage::Payload(extra_matset.as_msg()))
.await
.err()
{
println!("ERR: send tx extra error: {err:?}"); println!("ERR: send tx extra error: {err:?}");
} }
} }
@ -779,7 +864,11 @@ async fn throttle_send_recipe(
for (index, chunk) in recipe.Topping.ToppingList.chunks(CHUNK_SIZE).enumerate() { for (index, chunk) in recipe.Topping.ToppingList.chunks(CHUNK_SIZE).enumerate() {
let curr_ch_id = format!("{mat_exid}_tl{index}"); let curr_ch_id = format!("{mat_exid}_tl{index}");
let extra_topplist = StreamDataExtra::new(&curr_ch_id, &extl, chunk.to_vec()); let extra_topplist = StreamDataExtra::new(&curr_ch_id, &extl, chunk.to_vec());
if let Some(err) = tx.send(extra_topplist.as_msg()).await.err() { if let Some(err) = tx
.send(TxControlMessage::Payload(extra_topplist.as_msg()))
.await
.err()
{
println!("ERR: send tx extra2 error: {err:?}"); println!("ERR: send tx extra2 error: {err:?}");
} }
} }
@ -788,7 +877,11 @@ async fn throttle_send_recipe(
for (index, chunk) in recipe.Topping.ToppingGroup.chunks(CHUNK_SIZE).enumerate() { for (index, chunk) in recipe.Topping.ToppingGroup.chunks(CHUNK_SIZE).enumerate() {
let curr_ch_id = format!("{mat_exid}_tg{index}"); let curr_ch_id = format!("{mat_exid}_tg{index}");
let extra_toppgrp = StreamDataExtra::new(&curr_ch_id, &extg, chunk.to_vec()); let extra_toppgrp = StreamDataExtra::new(&curr_ch_id, &extg, chunk.to_vec());
if let Some(err) = tx.send(extra_toppgrp.as_msg()).await.err() { if let Some(err) = tx
.send(TxControlMessage::Payload(extra_toppgrp.as_msg()))
.await
.err()
{
println!("ERR: send tx extra2 error: {err:?}"); println!("ERR: send tx extra2 error: {err:?}");
} }
} }
@ -803,15 +896,32 @@ async fn throttle_send_recipe(
// return sid; // return sid;
let end_msg = StreamDataEnd::new(&sid); let end_msg = StreamDataEnd::new(&sid);
if let Some(err) = tx.send(end_msg.as_msg()).await.err() { if let Some(err) = tx
.send(TxControlMessage::Payload(end_msg.as_msg()))
.await
.err()
{
println!("ERR: send tx error, {err:?}"); println!("ERR: send tx error, {err:?}");
} }
} }
/// Registry of live websocket clients, keyed by the user's uid (the value
/// of the `x-auth-uid` header checked in `websocket_handler`). Each entry
/// holds the mpsc sender feeding that connection's write task, so a new
/// connection for the same uid can evict the old one by sending
/// `TxControlMessage::CloseExist` (see the register-and-kick step in
/// `handle_socket`).
#[derive(Clone)]
pub struct Hub {
    // uid -> sender into that connection's internal write channel
    pub clients: HashMap<String, Sender<TxControlMessage>>,
}
/// Control messages carried over each connection's internal channel from
/// producers (`read`, watchdog, `throttle_send_recipe`) to the websocket
/// `write` task.
#[derive(Clone)]
pub enum TxControlMessage {
    /// Forward this JSON payload to the connected client.
    Payload(serde_json::Value),
    /// A newer connection registered for the same uid; close this socket
    /// and exit the write loop.
    CloseExist,
}
pub struct AppState { pub struct AppState {
dev_config: DevConfig, dev_config: DevConfig,
redis_cli: redis::Client, redis_cli: redis::Client,
system_tx: tokio::sync::broadcast::Sender<serde_json::Value>, system_tx: tokio::sync::broadcast::Sender<serde_json::Value>,
// saved client uid:client uuid
connectors_mapping: Arc<Mutex<Hub>>,
} }
impl AppState { impl AppState {
@ -831,6 +941,9 @@ impl AppState {
dev_config, dev_config,
redis_cli, redis_cli,
system_tx, system_tx,
connectors_mapping: Arc::new(Mutex::new(Hub {
clients: HashMap::new(),
})),
}); });
tokio::spawn(async move { tokio::spawn(async move {
@ -988,6 +1101,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let app = Router::new() let app = Router::new()
// .route("/sessionLogin", post(session_login)) // .route("/sessionLogin", post(session_login))
.route("/syscb", post(post_from_other_system)) .route("/syscb", post(post_from_other_system))
.route("/regas", post(request_api_session_key))
.nest("/recipe", rp_router) .nest("/recipe", rp_router)
.nest("/docs", doc_router) .nest("/docs", doc_router)
.with_state(app_state); .with_state(app_state);