feat: add uid field check, kick old ws connection
Signed-off-by: Pakin <pakin.t@forth.co.th>
This commit is contained in:
parent
fecdb94841
commit
819bd08bc3
1 changed files with 158 additions and 44 deletions
202
src/main.rs
202
src/main.rs
|
|
@ -1,5 +1,11 @@
|
|||
use std::{
|
||||
collections::VecDeque, env, fs::File, io::Read, path::PathBuf, sync::Arc, time::Duration,
|
||||
collections::{HashMap, VecDeque},
|
||||
env,
|
||||
fs::File,
|
||||
io::Read,
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use async_compression::tokio::bufread::BrotliDecoder;
|
||||
|
|
@ -9,6 +15,7 @@ use axum::{
|
|||
State, WebSocketUpgrade,
|
||||
ws::{CloseFrame, Message, WebSocket},
|
||||
},
|
||||
http::HeaderMap,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
serve::ListenerExt,
|
||||
|
|
@ -135,11 +142,7 @@ async fn invoke_checkout_request(
|
|||
}
|
||||
|
||||
pub async fn create_recipe_repo_router() -> Router<Arc<AppState>> {
|
||||
Router::new()
|
||||
// .route("/", get(get_root_files))
|
||||
.route("/ws", get(websocket_handler))
|
||||
// .route("/edit", post())
|
||||
// .route("/{country}/", method_router)
|
||||
Router::new().route("/ws", get(websocket_handler))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
|
@ -149,6 +152,15 @@ struct SysMessage {
|
|||
payload: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Request to generate api rotated key
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiSessionRequest {
|
||||
// uid from user login
|
||||
uid: serde_json::Value,
|
||||
//
|
||||
timestamp: serde_json::Value,
|
||||
}
|
||||
|
||||
async fn post_from_other_system(
|
||||
State(mut state): State<Arc<AppState>>,
|
||||
Json(msg): Json<SysMessage>,
|
||||
|
|
@ -169,22 +181,72 @@ async fn post_from_other_system(
|
|||
}
|
||||
}
|
||||
|
||||
async fn request_api_session_key(
|
||||
State(mut state): State<Arc<AppState>>,
|
||||
Json(msg): Json<ApiSessionRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let mut rcl = state.redis_cli.clone();
|
||||
|
||||
// gen key
|
||||
let generated = uuid::Uuid::new_v4().to_string().replace("-", "");
|
||||
|
||||
if let Some(uid_n) = msg.uid.as_str() {
|
||||
let saved_key = format!("{uid_n}-ak");
|
||||
let _ = rcl.set_ex(saved_key, generated.clone(), 3600);
|
||||
}
|
||||
|
||||
(axum::http::StatusCode::OK, generated).into_response()
|
||||
}
|
||||
|
||||
async fn websocket_handler(
|
||||
State(mut state): State<Arc<AppState>>,
|
||||
ws: WebSocketUpgrade,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
let state_clone = Arc::clone(&state);
|
||||
let hub_clone = Arc::clone(&state_clone.connectors_mapping);
|
||||
|
||||
let mut uid_n = String::new();
|
||||
|
||||
if let Some(uid) = headers.get("x-auth-uid") {
|
||||
let uid_from_web = uid.to_str().unwrap_or_default().to_string();
|
||||
|
||||
uid_n = uid_from_web;
|
||||
info!("user connect {uid_n}");
|
||||
}
|
||||
|
||||
if uid_n.is_empty() {
|
||||
return (axum::http::StatusCode::BAD_REQUEST, "").into_response();
|
||||
}
|
||||
|
||||
ws.on_failed_upgrade(|error| println!("Error upgrading websocket: {}", error))
|
||||
.on_upgrade(async |s| handle_socket(s, state_clone).await.unwrap_or(()))
|
||||
.on_upgrade(async |s| {
|
||||
handle_socket(s, state_clone, uid_n, hub_clone)
|
||||
.await
|
||||
.unwrap_or(())
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_socket(
|
||||
mut socket: WebSocket,
|
||||
mut state: Arc<AppState>,
|
||||
uid: String,
|
||||
hub: Arc<Mutex<Hub>>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
// internal channel
|
||||
let (tx, mut rx) = mpsc::channel::<serde_json::Value>(2);
|
||||
let (tx, mut rx) = mpsc::channel::<TxControlMessage>(2);
|
||||
|
||||
{
|
||||
// register & kick old tx
|
||||
|
||||
let mut h = hub.try_lock().unwrap();
|
||||
if let Some(old_tx) = h.clients.insert(uid, tx.clone()) {
|
||||
warn!("disconnect old connection");
|
||||
let _ = old_tx.send(TxControlMessage::CloseExist);
|
||||
}
|
||||
}
|
||||
|
||||
let user_sys_rx = state.system_tx.subscribe();
|
||||
|
||||
let last_seen = Arc::new(Mutex::new(Instant::now()));
|
||||
|
|
@ -208,9 +270,9 @@ async fn handle_socket(
|
|||
if last.elapsed() > TIMEOUT {
|
||||
warn!("Timeout close connection");
|
||||
let _ = tx
|
||||
.send(serde_json::json!({
|
||||
.send(TxControlMessage::Payload(serde_json::json!({
|
||||
"timeout": "watchdog"
|
||||
}))
|
||||
})))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
|
|
@ -331,7 +393,7 @@ async fn read(
|
|||
// redis: redis::Client,
|
||||
mut state: Arc<AppState>,
|
||||
mut receiver: SplitStream<WebSocket>,
|
||||
tx: Sender<serde_json::Value>,
|
||||
tx: Sender<TxControlMessage>,
|
||||
mut system_rx: tokio::sync::broadcast::Receiver<serde_json::Value>,
|
||||
last_seen: Arc<Mutex<Instant>>, // cmd_atom: crossbeam_queue::ArrayQueue<CommandRequestPayload>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
|
|
@ -343,7 +405,10 @@ async fn read(
|
|||
// Send back to client from services
|
||||
while let Ok(s_msg) = system_rx.recv().await {
|
||||
if convert_sys_msg_command(&s_msg).is_some()
|
||||
&& let Some(err) = tx_to_client.send(s_msg).await.err()
|
||||
&& let Some(err) = tx_to_client
|
||||
.send(TxControlMessage::Payload(s_msg))
|
||||
.await
|
||||
.err()
|
||||
{
|
||||
println!("[SYS] failed to send back to client: {err}");
|
||||
}
|
||||
|
|
@ -355,6 +420,7 @@ async fn read(
|
|||
Message::Text(t) => {
|
||||
let req: WebsocketMessageRequest = serde_json::from_str(t.as_str())?;
|
||||
let req_clone = req.clone();
|
||||
info!("get msg: {}", req.type_w);
|
||||
match req.type_w.as_str() {
|
||||
"recipe" if req.payload.is_some() => {
|
||||
// guard expect value
|
||||
|
|
@ -438,7 +504,7 @@ async fn read(
|
|||
error!("File corrupted, invalid json format");
|
||||
}
|
||||
|
||||
let _ = tx.send(serde_json::json!({
|
||||
let _ = tx.send(TxControlMessage::Payload(serde_json::json!({
|
||||
"type": "notify",
|
||||
"payload": {
|
||||
"from": "system_tx",
|
||||
|
|
@ -446,7 +512,7 @@ async fn read(
|
|||
"msg": format!("Some requested file on cache is corrupt, {} version {}", recipe_param.country, latest_version),
|
||||
"to": ""
|
||||
}
|
||||
})).await;
|
||||
}))).await;
|
||||
|
||||
return Err(e.into());
|
||||
}
|
||||
|
|
@ -549,7 +615,7 @@ async fn read(
|
|||
if let Some(_) = state.system_tx.send(p).err() {
|
||||
info!("failed to send command request");
|
||||
let _ = tx
|
||||
.send(serde_json::json!({
|
||||
.send(TxControlMessage::Payload(serde_json::json!({
|
||||
"type": "notify",
|
||||
"payload": {
|
||||
"from": "system_tx",
|
||||
|
|
@ -557,7 +623,7 @@ async fn read(
|
|||
"msg": "send request fail",
|
||||
"to": ""
|
||||
}
|
||||
}))
|
||||
})))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
|
@ -652,31 +718,46 @@ async fn read(
|
|||
|
||||
async fn write(
|
||||
mut sender: SplitSink<WebSocket, Message>,
|
||||
mut rx: Receiver<serde_json::Value>,
|
||||
mut rx: Receiver<TxControlMessage>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
while let Some(res) = rx.recv().await {
|
||||
// force close
|
||||
if let Some(force_timeout_by) = res.get("timeout")
|
||||
&& let Some(from_who) = force_timeout_by.as_str()
|
||||
&& from_who.eq("watchdog")
|
||||
{
|
||||
warn!("receive close from watchdog");
|
||||
let _ = sender.send(Message::Close(None)).await;
|
||||
break;
|
||||
match res {
|
||||
TxControlMessage::Payload(res) => {
|
||||
// force close
|
||||
if let Some(force_timeout_by) = res.get("timeout")
|
||||
&& let Some(from_who) = force_timeout_by.as_str()
|
||||
&& from_who.eq("watchdog")
|
||||
{
|
||||
warn!("receive close from watchdog");
|
||||
let _ = sender.send(Message::Close(None)).await;
|
||||
break;
|
||||
}
|
||||
|
||||
// if let Some(res_n) = res.as_object()
|
||||
// && let Some(res_payload) = res_n.get("payload")
|
||||
// && let Some(res_payload_val) = res_payload.as_object()
|
||||
// && let Some(recv_ident) = res_payload_val.get("to")
|
||||
// && let Some(recv_ident_str) = recv_ident.as_str()
|
||||
// {}
|
||||
|
||||
let payload_size = res.to_string().len();
|
||||
|
||||
if payload_size >= 100000 {
|
||||
// large payload
|
||||
warn!(
|
||||
"sending large payload to client ... ({})",
|
||||
res.to_string().len()
|
||||
);
|
||||
}
|
||||
|
||||
let _ = sender.send(res.to_string().into()).await;
|
||||
}
|
||||
TxControlMessage::CloseExist => {
|
||||
let _ = sender.close().await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let payload_size = res.to_string().len();
|
||||
|
||||
if payload_size >= 100000 {
|
||||
// large payload
|
||||
warn!(
|
||||
"sending large payload to client ... ({})",
|
||||
res.to_string().len()
|
||||
);
|
||||
}
|
||||
|
||||
let _ = sender.send(res.to_string().into()).await;
|
||||
|
||||
// リミットブレく - limito breaku!! (uncomment to slow down messages)
|
||||
// let _ = tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
|
|
@ -721,7 +802,7 @@ fn get_key_cache(country: String, version: String, is_patch: bool, retry_cnt: i3
|
|||
|
||||
async fn throttle_send_recipe(
|
||||
recipe: &Recipe,
|
||||
tx: &Sender<serde_json::Value>,
|
||||
tx: &Sender<TxControlMessage>,
|
||||
country: String,
|
||||
version: String,
|
||||
) {
|
||||
|
|
@ -748,7 +829,7 @@ async fn throttle_send_recipe(
|
|||
let sid = ss.get_id();
|
||||
info!("starting {sid}");
|
||||
|
||||
if let Some(err) = tx.send(ss.as_msg()).await.err() {
|
||||
if let Some(err) = tx.send(TxControlMessage::Payload(ss.as_msg())).await.err() {
|
||||
println!("ERR: send tx error, {err:?}");
|
||||
}
|
||||
|
||||
|
|
@ -758,7 +839,7 @@ async fn throttle_send_recipe(
|
|||
let sda = StreamDataChunk::new(&sid, index * CHUNK_SIZE, chunk.to_vec());
|
||||
|
||||
// no validate
|
||||
if let Some(err) = tx.send(sda.as_msg()).await.err() {
|
||||
if let Some(err) = tx.send(TxControlMessage::Payload(sda.as_msg())).await.err() {
|
||||
println!("ERR: send tx error, {err:?}");
|
||||
}
|
||||
}
|
||||
|
|
@ -770,7 +851,11 @@ async fn throttle_send_recipe(
|
|||
|
||||
let extra_matset = StreamDataExtra::new(&curr_ch_id, &extp, chunk.to_vec());
|
||||
|
||||
if let Some(err) = tx.send(extra_matset.as_msg()).await.err() {
|
||||
if let Some(err) = tx
|
||||
.send(TxControlMessage::Payload(extra_matset.as_msg()))
|
||||
.await
|
||||
.err()
|
||||
{
|
||||
println!("ERR: send tx extra error: {err:?}");
|
||||
}
|
||||
}
|
||||
|
|
@ -779,7 +864,11 @@ async fn throttle_send_recipe(
|
|||
for (index, chunk) in recipe.Topping.ToppingList.chunks(CHUNK_SIZE).enumerate() {
|
||||
let curr_ch_id = format!("{mat_exid}_tl{index}");
|
||||
let extra_topplist = StreamDataExtra::new(&curr_ch_id, &extl, chunk.to_vec());
|
||||
if let Some(err) = tx.send(extra_topplist.as_msg()).await.err() {
|
||||
if let Some(err) = tx
|
||||
.send(TxControlMessage::Payload(extra_topplist.as_msg()))
|
||||
.await
|
||||
.err()
|
||||
{
|
||||
println!("ERR: send tx extra2 error: {err:?}");
|
||||
}
|
||||
}
|
||||
|
|
@ -788,7 +877,11 @@ async fn throttle_send_recipe(
|
|||
for (index, chunk) in recipe.Topping.ToppingGroup.chunks(CHUNK_SIZE).enumerate() {
|
||||
let curr_ch_id = format!("{mat_exid}_tg{index}");
|
||||
let extra_toppgrp = StreamDataExtra::new(&curr_ch_id, &extg, chunk.to_vec());
|
||||
if let Some(err) = tx.send(extra_toppgrp.as_msg()).await.err() {
|
||||
if let Some(err) = tx
|
||||
.send(TxControlMessage::Payload(extra_toppgrp.as_msg()))
|
||||
.await
|
||||
.err()
|
||||
{
|
||||
println!("ERR: send tx extra2 error: {err:?}");
|
||||
}
|
||||
}
|
||||
|
|
@ -803,15 +896,32 @@ async fn throttle_send_recipe(
|
|||
// return sid;
|
||||
let end_msg = StreamDataEnd::new(&sid);
|
||||
|
||||
if let Some(err) = tx.send(end_msg.as_msg()).await.err() {
|
||||
if let Some(err) = tx
|
||||
.send(TxControlMessage::Payload(end_msg.as_msg()))
|
||||
.await
|
||||
.err()
|
||||
{
|
||||
println!("ERR: send tx error, {err:?}");
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Hub {
|
||||
pub clients: HashMap<String, Sender<TxControlMessage>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum TxControlMessage {
|
||||
Payload(serde_json::Value),
|
||||
CloseExist,
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
dev_config: DevConfig,
|
||||
redis_cli: redis::Client,
|
||||
system_tx: tokio::sync::broadcast::Sender<serde_json::Value>,
|
||||
// saved client uid:client uuid
|
||||
connectors_mapping: Arc<Mutex<Hub>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
|
|
@ -831,6 +941,9 @@ impl AppState {
|
|||
dev_config,
|
||||
redis_cli,
|
||||
system_tx,
|
||||
connectors_mapping: Arc::new(Mutex::new(Hub {
|
||||
clients: HashMap::new(),
|
||||
})),
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
|
|
@ -988,6 +1101,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
let app = Router::new()
|
||||
// .route("/sessionLogin", post(session_login))
|
||||
.route("/syscb", post(post_from_other_system))
|
||||
.route("/regas", post(request_api_session_key))
|
||||
.nest("/recipe", rp_router)
|
||||
.nest("/docs", doc_router)
|
||||
.with_state(app_state);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue