diff --git a/src/app.rs b/src/app.rs index 374fcb0..d84ef74 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,24 +1,20 @@ use crate::websocket::{core::*, helper::read_sheet_config, model::*}; use axum::{ Router, + extract::DefaultBodyLimit, routing::{get, post}, serve::ListenerExt, }; -use log::{error, info, warn}; +use log::{error, info}; use redis::TypedCommands; use reqwest::{StatusCode, multipart}; use std::{ collections::{HashMap, VecDeque}, env, - fs::{self, File}, - io::BufReader, - sync::Arc, + sync::{Arc, RwLock}, time::Duration, }; -use tokio::{ - fs::read_dir, - sync::{Mutex, mpsc::Sender}, -}; +use tokio::sync::{Mutex, mpsc::Sender}; #[derive(Clone)] pub struct Hub { @@ -106,7 +102,7 @@ pub struct AppState { pub redis_cli: redis::Client, pub system_tx: tokio::sync::broadcast::Sender, // saved client uid:client uuid - pub connectors_mapping: Arc>, + pub connectors_mapping: Arc>, } impl AppState { @@ -126,7 +122,7 @@ impl AppState { dev_config: dev_config.clone(), redis_cli, system_tx, - connectors_mapping: Arc::new(Mutex::new(Hub { + connectors_mapping: Arc::new(RwLock::new(Hub { clients: HashMap::new(), })), }); @@ -439,10 +435,12 @@ pub async fn initialize() -> Result<(), Box> { "/syscb", post(crate::websocket::handler::post_from_other_system), ) + .route("/users", get(crate::websocket::handler::get_online_users)) .route("/load-config", post(crate::websocket::handler::post_config)) // .route("/regas", post(request_api_session_key)) .nest("/recipe", rp_router) // .nest("/docs", doc_router) + .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) .with_state(app_state); // feature: no delay, full throttle diff --git a/src/websocket/core.rs b/src/websocket/core.rs index ec8c8fa..b15f4c0 100644 --- a/src/websocket/core.rs +++ b/src/websocket/core.rs @@ -13,6 +13,9 @@ pub const TIMEOUT: Duration = Duration::from_secs(60 * 5); /// CONFIG: date format for using in recipe pub const LAST_CHANGE_DATE_FORMAT: &str = "%v %T"; +/// CONFIG: websocket size limit +pub const WEBSOCKET_MAX_BYTES: usize = 2 * 1024 * 1024; + #[derive(Clone)] pub enum TxControlMessage { Payload(serde_json::Value), diff --git a/src/websocket/handler.rs b/src/websocket/handler.rs index 2cd917b..4a4e0e9 100644 --- a/src/websocket/handler.rs +++ b/src/websocket/handler.rs @@ -1,13 +1,21 @@ use axum::{ Json, + body::Bytes, extract::{Request, State, WebSocketUpgrade, ws::WebSocket}, response::IntoResponse, }; use futures::StreamExt; -use log::{info, warn}; +use log::{error, info, warn}; use redis::TypedCommands; -use std::{fs::File, io::BufWriter, sync::Arc}; -use tokio::{sync::Mutex, sync::mpsc, time::Instant}; +use std::{ + fs::File, + io::BufWriter, + sync::{Arc, RwLock}, +}; +use tokio::{ + sync::{Mutex, mpsc}, + time::Instant, +}; use uuid::Uuid; use super::{core::*, model::*}; @@ -20,19 +28,97 @@ pub async fn post_from_other_system( State(state): State>, Json(msg): Json, ) -> impl IntoResponse { - let sys_payload = match serde_json::to_value(&msg) { - Ok(s) => s, - Err(_) => { + info!("triggering post callback"); + let target_receiver = if let Some(to) = msg.payload.get("to") { + to.as_str().unwrap_or_default().to_string() + } else { + "".to_string() + }; + + let from_service = if let Some(from) = msg.payload.get("from") { + from.as_str().unwrap_or_default().to_string() + } else { + "".to_string() + }; + + info!("posting from {from_service} to {target_receiver}"); + + match target_receiver.as_str() { + "*" => { + // send all + + info!("sending all receivers ..."); + let clients = { + let lock = state.connectors_mapping.read().unwrap(); + lock.clients.clone() + }; + + info!("acquired read lock"); + let mut send_success_count = 0; + let mut send_fail_count = 0; + let mut fail_cause = String::new(); + for (uid, tx) in clients.iter() { + if let Err(e) = tx + .send(TxControlMessage::Payload(serde_json::json!(msg))) + .await + { + send_fail_count += 1; + error!("send to {uid} fail: {e}"); + fail_cause.push_str(format!("{uid}:{e}\n").as_str()); + } else { + send_success_count += 1; + info!("send to {uid} success!"); + } + } + + info!("[send-all] success: {send_success_count}, fail: {send_fail_count}"); + return ( - axum::http::StatusCode::INTERNAL_SERVER_ERROR, - "cannot create payload", + axum::http::StatusCode::OK, + serde_json::json!({ + "success": send_success_count, + "fail": send_fail_count, + "cause": fail_cause + }) + .to_string(), + ) + .into_response(); + } + recv if !target_receiver.is_empty() => { + let recv_sender = { + let lock = state.connectors_mapping.read().unwrap(); + lock.clients.get(recv).cloned() + }; + info!("[send-single] acquired client"); + + if let Some(recv_tx) = recv_sender { + info!("[send-single] acquired client ok, sending ..."); + if let Err(e) = recv_tx + .send(TxControlMessage::Payload(serde_json::json!(msg))) + .await + { + return ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + format!("send fail: {e}"), + ) + .into_response(); + } else { + info!("[send-single] send success"); + return (axum::http::StatusCode::OK, "send success").into_response(); + } + } else { + error!("target user is not connected, user may be offline or disconnected!"); + return (axum::http::StatusCode::BAD_REQUEST, "user not found").into_response(); + } + } + _ => { + warn!("payload is incorrect from {from_service}, sender was not provided receiver"); + return ( + axum::http::StatusCode::BAD_REQUEST, + "receiver is empty or wrong type", ) .into_response(); } - }; - match state.system_tx.send(sys_payload) { - Ok(_) => (axum::http::StatusCode::OK, "").into_response(), - Err(_) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "send fail").into_response(), } } @@ -91,6 +177,22 @@ pub async fn request_api_session_key( (axum::http::StatusCode::OK, generated).into_response() } +pub async fn get_online_users(State(state): State>) -> impl IntoResponse { + let on_connected_clients: Vec = { + let lock = state.connectors_mapping.read().unwrap(); + lock.clients.keys().map(|x| x.to_string()).collect() + }; + + ( + axum::http::StatusCode::OK, + serde_json::json!({ + "online": on_connected_clients + }) + .to_string(), + ) + .into_response() +} + /// Main websocket handler pub async fn websocket_handler( State(state): State>, @@ -115,27 +217,18 @@ pub async fn websocket_handler( return (axum::http::StatusCode::FORBIDDEN, "".to_string()).into_response(); } - // let mut uid_n = String::new(); + // TODO: Add more headers? - // if let Some(uid) = headers.get("x-auth-uid") { - // let uid_from_web = uid.to_str().unwrap_or_default().to_string(); - - // uid_n = uid_from_web; - // info!("user connect {uid_n}"); - // } - - // if uid_n.is_empty() { - // return (axum::http::StatusCode::BAD_REQUEST, "").into_response(); - // } - - ws.on_failed_upgrade(|error| println!("Error upgrading websocket: {}", error)) + ws.max_frame_size(WEBSOCKET_MAX_BYTES) + .max_message_size(WEBSOCKET_MAX_BYTES) + .on_failed_upgrade(|error| println!("Error upgrading websocket: {}", error)) .on_upgrade(async |s| handle_socket(s, state_clone, hub_clone).await.unwrap_or(())) } async fn handle_socket( socket: WebSocket, state: Arc, - hub: Arc>, + hub: Arc>, ) -> Result<(), Box> { let (sender, receiver) = socket.split(); // internal channel @@ -163,27 +256,32 @@ async fn handle_socket( info!("{} connected", temp_session); { - let mut h = hub.lock().await; + let mut h = hub.write().unwrap(); h.clients.insert(temp_session.clone(), tx.clone()); } - let user_sys_rx = state.system_tx.subscribe(); + // NOTE: disable from cause system tx could directly send to client rx + // without sending to system rx. + // let user_sys_rx = state.system_tx.subscribe(); let last_seen = Arc::new(Mutex::new(Instant::now())); let reader_last_seen = last_seen.clone(); let watchdog_last_seen = last_seen.clone(); - let sender = tokio::spawn(super::rw::write(sender, rx, user.clone(), hub.clone())); + let hub_for_write = hub.clone(); + let hub_for_read = hub.clone(); + + let sender = tokio::spawn(super::rw::write(sender, rx, user.clone(), hub_for_write)); let reader = tokio::spawn(super::rw::read( state, receiver, tx.clone(), reader_last_seen, user.clone(), - hub.clone(), + hub_for_read, )); - let callback_to_client = super::rw::recv_sys_msg_send_back_client(tx.clone(), user_sys_rx); + // let callback_to_client = super::rw::recv_sys_msg_send_back_client(tx.clone(), user_sys_rx); let watchdog = super::tasks::watchdog::get_watchdog_task( tx, @@ -192,27 +290,40 @@ async fn handle_socket( hub.clone(), ); - let (rf, sf, cbc, wds) = tokio::join!(reader, sender, callback_to_client, watchdog); + let (rf, sf, wds) = tokio::join!(reader, sender, watchdog); if let Ok(rf_js) = rf && let Ok(sf_js) = sf { + let user = user.clone().lock().await.to_string(); info!( "read end ok: {}, write end ok: {} [{}]", rf_js.is_ok(), sf_js.is_ok(), - user.clone().lock().await.to_string() + user.clone() ); - if !cbc.is_finished() { - info!("sys rx still running"); - cbc.abort(); - - if cbc.await.unwrap_err().is_cancelled() { - info!("sys rx force stop ..."); - } - } if !wds.is_finished() { info!("watchdog still existed"); + + wds.abort(); + + if wds.await.unwrap_err().is_cancelled() { + info!("watchdog force stop"); + } + } + + { + let mut lock = hub.write().unwrap(); + + if lock.clients.contains_key(&user) { + warn!("user still existed! {user}, removing key ..."); + lock.clients.remove(&user); + // check again + warn!( + "after remove user, exist: {}", + lock.clients.contains_key(&user) + ); + } } } diff --git a/src/websocket/rw.rs b/src/websocket/rw.rs index f368bfa..edd134b 100644 --- a/src/websocket/rw.rs +++ b/src/websocket/rw.rs @@ -4,7 +4,8 @@ use crate::{ websocket::{plugins::call_plugin_if_existed, tasks}, }; use std::{ - sync::{Arc, atomic::AtomicBool}, + collections::HashMap, + sync::{Arc, RwLock}, time::Duration, }; @@ -32,7 +33,7 @@ pub async fn read( tx: Sender, last_seen: Arc>, // cmd_atom: crossbeam_queue::ArrayQueue, uid: Arc>, - hub: Arc>, + hub: Arc>, ) -> Result<(), Box> { let redis = state.redis_cli.clone(); let config = state.dev_config.clone(); @@ -127,9 +128,15 @@ pub async fn read( .await; } "sheet" if req.payload.is_some() => { - if tasks::sheet::handle_sheet_request(config.clone(), redis.clone(), req) - .await - .is_err() + if tasks::sheet::handle_sheet_request( + config.clone(), + redis.clone(), + tx.clone(), + req, + uid_clone.clone(), + ) + .await + .is_err() { continue; } @@ -199,7 +206,7 @@ pub async fn read( // remove current uid { - let mut h = hub.try_lock().unwrap(); + let mut h = hub.write().unwrap(); let curr_user = uid.try_lock().unwrap().to_string(); if let Some(ent) = h.clients.remove_entry(&curr_user) { @@ -228,7 +235,10 @@ pub async fn read( } } - info!("[read] canceling sys rx ..."); + info!( + "[read][{}] canceling sys rx ...", + uid_clone.lock().await.to_string() + ); Ok(()) } @@ -237,8 +247,11 @@ pub async fn write( mut sender: SplitSink, mut rx: Receiver, uid: Arc>, - hub: Arc>, + hub: Arc>, ) -> Result<(), Box> { + // only allow each stream type for 1 request + let pending_stream_requests = Arc::new(RwLock::new(HashMap::new())); + while let Some(res) = rx.recv().await { match res { TxControlMessage::Payload(res) => { @@ -266,10 +279,50 @@ pub async fn write( && let Some(recv_ident) = res_payload_val.get("to") && let Some(recv_ident_str) = recv_ident.as_str() && (current_uid.to_string().eq(recv_ident_str) - || current_uid.to_string().eq("*")) + || recv_ident_str.to_string().eq("*")) { let payload_size = res.to_string().len(); + if tasks::sheet::is_tx_stream_type(res.clone()) { + let rid = res_payload_val + .get("request_id") + .unwrap_or_default() + .as_str(); + let rtype = res_n + .get("type") + .unwrap_or_default() + .as_str() + .unwrap_or_default(); + + let mut pending = { + let lock = pending_stream_requests.write().unwrap(); + lock + }; + + info!( + "register stream type: {rtype}, with request id: {} and allow add to pending: {}", + rid.unwrap().to_string(), + !pending.contains_key(rtype) + ); + + if !rtype.is_empty() && !pending.contains_key(rtype) { + // allow set pending + pending.insert(rtype.to_string(), rid.unwrap().to_string()); + info!( + "add pending stream: {} ---> {}", + rtype, + rid.unwrap().to_string() + ); + } else if rtype.is_empty() { + warn!("request stream type is empty!"); + } else { + // blocking + warn!("request more than once, please wait until current finish!"); + } + + continue; + } + if payload_size >= 100000 { // large payload warn!( @@ -278,16 +331,101 @@ pub async fn write( ); } - let _ = sender.send(res.to_string().into()).await; + // handle check if response has been set as pending before + + let stream_ref = res_payload_val.get("ref").unwrap_or_default(); + + let pending_has_key = { + let lock = pending_stream_requests.read().unwrap(); + lock.contains_key( + format!("stream-{}", stream_ref.as_str().unwrap_or_default()).as_str(), + ) + }; + + let stream_chunk_id = { + let lock = pending_stream_requests.read().unwrap(); + lock.get( + format!("stream-{}", stream_ref.as_str().unwrap_or_default()).as_str(), + ) + .cloned() + .unwrap_or_default() + }; + + if pending_has_key { + // has set, do iterate now + + // gen payload size + let size_per_payload = 10000; + + let chars: Vec = res.to_string().chars().collect(); + let split = &chars + .chunks(size_per_payload) + .map(|chunk| chunk.iter().collect::()) + .collect::>(); + + let header = serde_json::json!({ + "type": format!("raw_stream_{}", stream_ref.as_str().unwrap_or_default()), + "payload": { + "to": recv_ident_str, + "total_chunks": split.len(), + "size_per_chunk": size_per_payload, + "request_id": stream_chunk_id + } + }); + + let footer = serde_json::json!({ + "type": format!("raw_stream_end_{}", stream_ref.as_str().unwrap_or_default()), + "payload": { + "to": recv_ident_str, + "request_id": stream_chunk_id + } + }); + + let _ = sender.send(header.to_string().into()).await?; + + for (idx, raw_payload) in split.iter().enumerate() { + let raw_chunk_payload = serde_json::json!({ + "type": format!("raw_stream_chunk_{}", stream_ref.as_str().unwrap_or_default()), + "payload": { + "to": recv_ident_str, + "raw": raw_payload.clone(), + "idx": idx, + "request_id": stream_chunk_id + } + }); + let _ = sender.feed(raw_chunk_payload.to_string().into()).await; + } + + if let Err(e) = sender.flush().await { + error!("flushing stream failed: {e}"); + continue; + } + + { + // end stream + let mut lock = pending_stream_requests.write().unwrap(); + + lock.remove( + format!("stream-{}", stream_ref.as_str().unwrap_or_default()) + .as_str(), + ); + } + + let _ = sender.send(footer.to_string().into()).await; + + continue; + } else { + if let Err(e) = sender.send(res.to_string().into()).await { + error!("[write] send payload fail; len={payload_size}, reason: {e}"); + } + } } else { // show error by case - let clients: Vec = hub - .lock() - .await - .clients - .keys() - .map(|x| x.to_string()) - .collect(); + let clients: Vec = { + let lock = hub.read().unwrap(); + + lock.clients.keys().map(|x| x.to_string()).collect() + }; // step errors if let Some(res_n) = res.as_object() @@ -301,7 +439,7 @@ pub async fn write( if clients.contains(&recv_ident_str.to_string()) && current_uid.ne(&recv_ident_str.to_string()) { - warn!("oops! receiving other receiver's messages. Ignore this"); + // warn!("oops! receiving other receiver's messages. Ignore this"); } else { error!("receiver not existed or already went offline"); } diff --git a/src/websocket/tasks/auth.rs b/src/websocket/tasks/auth.rs index b98ba25..3e6da20 100644 --- a/src/websocket/tasks/auth.rs +++ b/src/websocket/tasks/auth.rs @@ -1,7 +1,8 @@ use crate::app::*; use crate::websocket::{core::*, model::*}; use log::{error, info, warn}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; + use tokio::sync::{Mutex, mpsc::Sender}; /// Handle request of command type from websocket (read) @@ -9,7 +10,7 @@ pub async fn handle_auth_request( state: Arc, tx: Sender, req: WebsocketMessageRequest, - hub: Arc>, + hub: Arc>, curr_uid: Arc>, ) -> WebsocketMessageResult { // do command send to other services @@ -28,7 +29,10 @@ pub async fn handle_auth_request( if !new_uid.is_empty() { let old_uid = curr_uid.try_lock().unwrap().to_string(); - let mut h = hub.try_lock().unwrap(); + let mut h = { + let lock = hub.write().unwrap(); + lock + }; if let Some(ent) = h.clients.remove_entry(&old_uid) { let curr_connection = ent.1; @@ -36,7 +40,7 @@ pub async fn handle_auth_request( // case auth success but already have some tx left if let Some(old_tx) = h.clients.insert(new_uid.clone(), curr_connection) { - warn!("disconnecting old connection"); + warn!("[auth][{}] disconnecting old connection", old_uid.clone()); let _ = old_tx.send(TxControlMessage::CloseExist); } info!("update re-new auth successful ---> {}", new_uid.clone()); diff --git a/src/websocket/tasks/sheet.rs b/src/websocket/tasks/sheet.rs index f62b2a3..8d72277 100644 --- a/src/websocket/tasks/sheet.rs +++ b/src/websocket/tasks/sheet.rs @@ -1,15 +1,20 @@ +use std::sync::Arc; + use crate::{ app::DevConfig, websocket::{core::*, model::*}, }; use log::{error, info}; use redis::TypedCommands; +use tokio::sync::{Mutex, mpsc::Sender}; /// Handle request of sheet type from websocket (read) pub async fn handle_sheet_request( config: DevConfig, redis: redis::Client, + tx: Sender, req: WebsocketMessageRequest, + uid_clone: Arc>, ) -> WebsocketMessageResult { // CommandRequestPayload struct-like @@ -26,17 +31,28 @@ pub async fn handle_sheet_request( } }; - info!( - "get sheet request: {}, {:?}", - payload_sheet_request.srv_name, payload_sheet_request - ); + // info!( + // "get sheet request: {}, {:?}", + // payload_sheet_request.srv_name, payload_sheet_request + // ); let parameters = payload_sheet_request .values .get("param") .unwrap_or_default(); - // TODO: will be changed to config from yaml file + let stream_mode = payload_sheet_request + .values + .get("stream") + .unwrap_or_default(); + + let request_id = payload_sheet_request + .values + .get("request_id") + .unwrap_or_default(); + + let uidd = uid_clone.clone().lock().await.to_string(); + let ch_target = if let Some(pm) = parameters.as_str() && config.check_sheet_endpoints(pm) { @@ -58,5 +74,39 @@ pub async fn handle_sheet_request( error!("error on publish result cmd: {e:?}"); } + if let Some(stream_flag) = stream_mode.as_bool() + && let Some(request_id) = request_id.as_str() + && stream_flag + && !request_id.is_empty() + { + let _ = tx + .send(TxControlMessage::Payload(serde_json::json!({ + "type": format!("stream-{ch_target}"), + "payload": { + "request_id": request_id.to_string(), + "to": uidd + } + }))) + .await; + } + Ok(()) } + +pub fn is_tx_stream_type(raw: serde_json::Value) -> bool { + // expect request id + // type must start with stream + let tx_type = raw.get("type").unwrap_or_default(); + let payload = raw.get("payload").unwrap_or_default(); + + if let Some(tx_t) = tx_type.as_str() + && let Some(request_id) = payload.get("request_id") + && let Some(rid) = request_id.as_str() + && tx_t.starts_with("stream-") + && !rid.is_empty() + { + true + } else { + false + } +} diff --git a/src/websocket/tasks/watchdog.rs b/src/websocket/tasks/watchdog.rs index ddc7108..a5ac6f3 100644 --- a/src/websocket/tasks/watchdog.rs +++ b/src/websocket/tasks/watchdog.rs @@ -1,6 +1,10 @@ use crate::{app::Hub, websocket::core::*}; use log::{info, warn}; -use std::{ops::Sub, sync::Arc, time::Duration}; +use std::{ + ops::Sub, + sync::{Arc, RwLock}, + time::Duration, +}; use tokio::{ sync::{Mutex, mpsc::Sender}, task::JoinHandle, @@ -11,16 +15,21 @@ pub async fn get_watchdog_task( tx: Sender, watchdog_last_seen: Arc>, user: Arc>, - hub: Arc>, + hub: Arc>, ) -> JoinHandle<()> { tokio::spawn(async move { let uc = user.clone().lock().await.to_string(); info!("start watchdog for {uc}"); + loop { tokio::time::sleep(Duration::from_secs(2)).await; + let h = { + let lock = hub.read().unwrap(); + lock.clone() + }; + { - let h = hub.lock().await; let curr_user = user.lock().await.to_string(); // info!("{}: checking invalid ...", curr_user);