feat: add uid field check, kick old ws connection
Signed-off-by: Pakin <pakin.t@forth.co.th>
This commit is contained in:
parent
fecdb94841
commit
819bd08bc3
1 changed files with 158 additions and 44 deletions
202
src/main.rs
202
src/main.rs
|
|
@ -1,5 +1,11 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::VecDeque, env, fs::File, io::Read, path::PathBuf, sync::Arc, time::Duration,
|
collections::{HashMap, VecDeque},
|
||||||
|
env,
|
||||||
|
fs::File,
|
||||||
|
io::Read,
|
||||||
|
path::PathBuf,
|
||||||
|
sync::Arc,
|
||||||
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
use async_compression::tokio::bufread::BrotliDecoder;
|
use async_compression::tokio::bufread::BrotliDecoder;
|
||||||
|
|
@ -9,6 +15,7 @@ use axum::{
|
||||||
State, WebSocketUpgrade,
|
State, WebSocketUpgrade,
|
||||||
ws::{CloseFrame, Message, WebSocket},
|
ws::{CloseFrame, Message, WebSocket},
|
||||||
},
|
},
|
||||||
|
http::HeaderMap,
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
serve::ListenerExt,
|
serve::ListenerExt,
|
||||||
|
|
@ -135,11 +142,7 @@ async fn invoke_checkout_request(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_recipe_repo_router() -> Router<Arc<AppState>> {
|
pub async fn create_recipe_repo_router() -> Router<Arc<AppState>> {
|
||||||
Router::new()
|
Router::new().route("/ws", get(websocket_handler))
|
||||||
// .route("/", get(get_root_files))
|
|
||||||
.route("/ws", get(websocket_handler))
|
|
||||||
// .route("/edit", post())
|
|
||||||
// .route("/{country}/", method_router)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
|
@ -149,6 +152,15 @@ struct SysMessage {
|
||||||
payload: serde_json::Value,
|
payload: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Request to generate api rotated key
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct ApiSessionRequest {
|
||||||
|
// uid from user login
|
||||||
|
uid: serde_json::Value,
|
||||||
|
//
|
||||||
|
timestamp: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
async fn post_from_other_system(
|
async fn post_from_other_system(
|
||||||
State(mut state): State<Arc<AppState>>,
|
State(mut state): State<Arc<AppState>>,
|
||||||
Json(msg): Json<SysMessage>,
|
Json(msg): Json<SysMessage>,
|
||||||
|
|
@ -169,22 +181,72 @@ async fn post_from_other_system(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn request_api_session_key(
|
||||||
|
State(mut state): State<Arc<AppState>>,
|
||||||
|
Json(msg): Json<ApiSessionRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let mut rcl = state.redis_cli.clone();
|
||||||
|
|
||||||
|
// gen key
|
||||||
|
let generated = uuid::Uuid::new_v4().to_string().replace("-", "");
|
||||||
|
|
||||||
|
if let Some(uid_n) = msg.uid.as_str() {
|
||||||
|
let saved_key = format!("{uid_n}-ak");
|
||||||
|
let _ = rcl.set_ex(saved_key, generated.clone(), 3600);
|
||||||
|
}
|
||||||
|
|
||||||
|
(axum::http::StatusCode::OK, generated).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
async fn websocket_handler(
|
async fn websocket_handler(
|
||||||
State(mut state): State<Arc<AppState>>,
|
State(mut state): State<Arc<AppState>>,
|
||||||
ws: WebSocketUpgrade,
|
ws: WebSocketUpgrade,
|
||||||
|
headers: HeaderMap,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let state_clone = Arc::clone(&state);
|
let state_clone = Arc::clone(&state);
|
||||||
|
let hub_clone = Arc::clone(&state_clone.connectors_mapping);
|
||||||
|
|
||||||
|
let mut uid_n = String::new();
|
||||||
|
|
||||||
|
if let Some(uid) = headers.get("x-auth-uid") {
|
||||||
|
let uid_from_web = uid.to_str().unwrap_or_default().to_string();
|
||||||
|
|
||||||
|
uid_n = uid_from_web;
|
||||||
|
info!("user connect {uid_n}");
|
||||||
|
}
|
||||||
|
|
||||||
|
if uid_n.is_empty() {
|
||||||
|
return (axum::http::StatusCode::BAD_REQUEST, "").into_response();
|
||||||
|
}
|
||||||
|
|
||||||
ws.on_failed_upgrade(|error| println!("Error upgrading websocket: {}", error))
|
ws.on_failed_upgrade(|error| println!("Error upgrading websocket: {}", error))
|
||||||
.on_upgrade(async |s| handle_socket(s, state_clone).await.unwrap_or(()))
|
.on_upgrade(async |s| {
|
||||||
|
handle_socket(s, state_clone, uid_n, hub_clone)
|
||||||
|
.await
|
||||||
|
.unwrap_or(())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_socket(
|
async fn handle_socket(
|
||||||
mut socket: WebSocket,
|
mut socket: WebSocket,
|
||||||
mut state: Arc<AppState>,
|
mut state: Arc<AppState>,
|
||||||
|
uid: String,
|
||||||
|
hub: Arc<Mutex<Hub>>,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (mut sender, mut receiver) = socket.split();
|
let (mut sender, mut receiver) = socket.split();
|
||||||
// internal channel
|
// internal channel
|
||||||
let (tx, mut rx) = mpsc::channel::<serde_json::Value>(2);
|
let (tx, mut rx) = mpsc::channel::<TxControlMessage>(2);
|
||||||
|
|
||||||
|
{
|
||||||
|
// register & kick old tx
|
||||||
|
|
||||||
|
let mut h = hub.try_lock().unwrap();
|
||||||
|
if let Some(old_tx) = h.clients.insert(uid, tx.clone()) {
|
||||||
|
warn!("disconnect old connection");
|
||||||
|
let _ = old_tx.send(TxControlMessage::CloseExist);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let user_sys_rx = state.system_tx.subscribe();
|
let user_sys_rx = state.system_tx.subscribe();
|
||||||
|
|
||||||
let last_seen = Arc::new(Mutex::new(Instant::now()));
|
let last_seen = Arc::new(Mutex::new(Instant::now()));
|
||||||
|
|
@ -208,9 +270,9 @@ async fn handle_socket(
|
||||||
if last.elapsed() > TIMEOUT {
|
if last.elapsed() > TIMEOUT {
|
||||||
warn!("Timeout close connection");
|
warn!("Timeout close connection");
|
||||||
let _ = tx
|
let _ = tx
|
||||||
.send(serde_json::json!({
|
.send(TxControlMessage::Payload(serde_json::json!({
|
||||||
"timeout": "watchdog"
|
"timeout": "watchdog"
|
||||||
}))
|
})))
|
||||||
.await;
|
.await;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -331,7 +393,7 @@ async fn read(
|
||||||
// redis: redis::Client,
|
// redis: redis::Client,
|
||||||
mut state: Arc<AppState>,
|
mut state: Arc<AppState>,
|
||||||
mut receiver: SplitStream<WebSocket>,
|
mut receiver: SplitStream<WebSocket>,
|
||||||
tx: Sender<serde_json::Value>,
|
tx: Sender<TxControlMessage>,
|
||||||
mut system_rx: tokio::sync::broadcast::Receiver<serde_json::Value>,
|
mut system_rx: tokio::sync::broadcast::Receiver<serde_json::Value>,
|
||||||
last_seen: Arc<Mutex<Instant>>, // cmd_atom: crossbeam_queue::ArrayQueue<CommandRequestPayload>,
|
last_seen: Arc<Mutex<Instant>>, // cmd_atom: crossbeam_queue::ArrayQueue<CommandRequestPayload>,
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
|
@ -343,7 +405,10 @@ async fn read(
|
||||||
// Send back to client from services
|
// Send back to client from services
|
||||||
while let Ok(s_msg) = system_rx.recv().await {
|
while let Ok(s_msg) = system_rx.recv().await {
|
||||||
if convert_sys_msg_command(&s_msg).is_some()
|
if convert_sys_msg_command(&s_msg).is_some()
|
||||||
&& let Some(err) = tx_to_client.send(s_msg).await.err()
|
&& let Some(err) = tx_to_client
|
||||||
|
.send(TxControlMessage::Payload(s_msg))
|
||||||
|
.await
|
||||||
|
.err()
|
||||||
{
|
{
|
||||||
println!("[SYS] failed to send back to client: {err}");
|
println!("[SYS] failed to send back to client: {err}");
|
||||||
}
|
}
|
||||||
|
|
@ -355,6 +420,7 @@ async fn read(
|
||||||
Message::Text(t) => {
|
Message::Text(t) => {
|
||||||
let req: WebsocketMessageRequest = serde_json::from_str(t.as_str())?;
|
let req: WebsocketMessageRequest = serde_json::from_str(t.as_str())?;
|
||||||
let req_clone = req.clone();
|
let req_clone = req.clone();
|
||||||
|
info!("get msg: {}", req.type_w);
|
||||||
match req.type_w.as_str() {
|
match req.type_w.as_str() {
|
||||||
"recipe" if req.payload.is_some() => {
|
"recipe" if req.payload.is_some() => {
|
||||||
// guard expect value
|
// guard expect value
|
||||||
|
|
@ -438,7 +504,7 @@ async fn read(
|
||||||
error!("File corrupted, invalid json format");
|
error!("File corrupted, invalid json format");
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = tx.send(serde_json::json!({
|
let _ = tx.send(TxControlMessage::Payload(serde_json::json!({
|
||||||
"type": "notify",
|
"type": "notify",
|
||||||
"payload": {
|
"payload": {
|
||||||
"from": "system_tx",
|
"from": "system_tx",
|
||||||
|
|
@ -446,7 +512,7 @@ async fn read(
|
||||||
"msg": format!("Some requested file on cache is corrupt, {} version {}", recipe_param.country, latest_version),
|
"msg": format!("Some requested file on cache is corrupt, {} version {}", recipe_param.country, latest_version),
|
||||||
"to": ""
|
"to": ""
|
||||||
}
|
}
|
||||||
})).await;
|
}))).await;
|
||||||
|
|
||||||
return Err(e.into());
|
return Err(e.into());
|
||||||
}
|
}
|
||||||
|
|
@ -549,7 +615,7 @@ async fn read(
|
||||||
if let Some(_) = state.system_tx.send(p).err() {
|
if let Some(_) = state.system_tx.send(p).err() {
|
||||||
info!("failed to send command request");
|
info!("failed to send command request");
|
||||||
let _ = tx
|
let _ = tx
|
||||||
.send(serde_json::json!({
|
.send(TxControlMessage::Payload(serde_json::json!({
|
||||||
"type": "notify",
|
"type": "notify",
|
||||||
"payload": {
|
"payload": {
|
||||||
"from": "system_tx",
|
"from": "system_tx",
|
||||||
|
|
@ -557,7 +623,7 @@ async fn read(
|
||||||
"msg": "send request fail",
|
"msg": "send request fail",
|
||||||
"to": ""
|
"to": ""
|
||||||
}
|
}
|
||||||
}))
|
})))
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -652,31 +718,46 @@ async fn read(
|
||||||
|
|
||||||
async fn write(
|
async fn write(
|
||||||
mut sender: SplitSink<WebSocket, Message>,
|
mut sender: SplitSink<WebSocket, Message>,
|
||||||
mut rx: Receiver<serde_json::Value>,
|
mut rx: Receiver<TxControlMessage>,
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
while let Some(res) = rx.recv().await {
|
while let Some(res) = rx.recv().await {
|
||||||
// force close
|
match res {
|
||||||
if let Some(force_timeout_by) = res.get("timeout")
|
TxControlMessage::Payload(res) => {
|
||||||
&& let Some(from_who) = force_timeout_by.as_str()
|
// force close
|
||||||
&& from_who.eq("watchdog")
|
if let Some(force_timeout_by) = res.get("timeout")
|
||||||
{
|
&& let Some(from_who) = force_timeout_by.as_str()
|
||||||
warn!("receive close from watchdog");
|
&& from_who.eq("watchdog")
|
||||||
let _ = sender.send(Message::Close(None)).await;
|
{
|
||||||
break;
|
warn!("receive close from watchdog");
|
||||||
|
let _ = sender.send(Message::Close(None)).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if let Some(res_n) = res.as_object()
|
||||||
|
// && let Some(res_payload) = res_n.get("payload")
|
||||||
|
// && let Some(res_payload_val) = res_payload.as_object()
|
||||||
|
// && let Some(recv_ident) = res_payload_val.get("to")
|
||||||
|
// && let Some(recv_ident_str) = recv_ident.as_str()
|
||||||
|
// {}
|
||||||
|
|
||||||
|
let payload_size = res.to_string().len();
|
||||||
|
|
||||||
|
if payload_size >= 100000 {
|
||||||
|
// large payload
|
||||||
|
warn!(
|
||||||
|
"sending large payload to client ... ({})",
|
||||||
|
res.to_string().len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = sender.send(res.to_string().into()).await;
|
||||||
|
}
|
||||||
|
TxControlMessage::CloseExist => {
|
||||||
|
let _ = sender.close().await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let payload_size = res.to_string().len();
|
|
||||||
|
|
||||||
if payload_size >= 100000 {
|
|
||||||
// large payload
|
|
||||||
warn!(
|
|
||||||
"sending large payload to client ... ({})",
|
|
||||||
res.to_string().len()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = sender.send(res.to_string().into()).await;
|
|
||||||
|
|
||||||
// リミットブレく - limito breaku!! (uncomment to slow down messages)
|
// リミットブレく - limito breaku!! (uncomment to slow down messages)
|
||||||
// let _ = tokio::time::sleep(Duration::from_millis(100)).await;
|
// let _ = tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
}
|
}
|
||||||
|
|
@ -721,7 +802,7 @@ fn get_key_cache(country: String, version: String, is_patch: bool, retry_cnt: i3
|
||||||
|
|
||||||
async fn throttle_send_recipe(
|
async fn throttle_send_recipe(
|
||||||
recipe: &Recipe,
|
recipe: &Recipe,
|
||||||
tx: &Sender<serde_json::Value>,
|
tx: &Sender<TxControlMessage>,
|
||||||
country: String,
|
country: String,
|
||||||
version: String,
|
version: String,
|
||||||
) {
|
) {
|
||||||
|
|
@ -748,7 +829,7 @@ async fn throttle_send_recipe(
|
||||||
let sid = ss.get_id();
|
let sid = ss.get_id();
|
||||||
info!("starting {sid}");
|
info!("starting {sid}");
|
||||||
|
|
||||||
if let Some(err) = tx.send(ss.as_msg()).await.err() {
|
if let Some(err) = tx.send(TxControlMessage::Payload(ss.as_msg())).await.err() {
|
||||||
println!("ERR: send tx error, {err:?}");
|
println!("ERR: send tx error, {err:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -758,7 +839,7 @@ async fn throttle_send_recipe(
|
||||||
let sda = StreamDataChunk::new(&sid, index * CHUNK_SIZE, chunk.to_vec());
|
let sda = StreamDataChunk::new(&sid, index * CHUNK_SIZE, chunk.to_vec());
|
||||||
|
|
||||||
// no validate
|
// no validate
|
||||||
if let Some(err) = tx.send(sda.as_msg()).await.err() {
|
if let Some(err) = tx.send(TxControlMessage::Payload(sda.as_msg())).await.err() {
|
||||||
println!("ERR: send tx error, {err:?}");
|
println!("ERR: send tx error, {err:?}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -770,7 +851,11 @@ async fn throttle_send_recipe(
|
||||||
|
|
||||||
let extra_matset = StreamDataExtra::new(&curr_ch_id, &extp, chunk.to_vec());
|
let extra_matset = StreamDataExtra::new(&curr_ch_id, &extp, chunk.to_vec());
|
||||||
|
|
||||||
if let Some(err) = tx.send(extra_matset.as_msg()).await.err() {
|
if let Some(err) = tx
|
||||||
|
.send(TxControlMessage::Payload(extra_matset.as_msg()))
|
||||||
|
.await
|
||||||
|
.err()
|
||||||
|
{
|
||||||
println!("ERR: send tx extra error: {err:?}");
|
println!("ERR: send tx extra error: {err:?}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -779,7 +864,11 @@ async fn throttle_send_recipe(
|
||||||
for (index, chunk) in recipe.Topping.ToppingList.chunks(CHUNK_SIZE).enumerate() {
|
for (index, chunk) in recipe.Topping.ToppingList.chunks(CHUNK_SIZE).enumerate() {
|
||||||
let curr_ch_id = format!("{mat_exid}_tl{index}");
|
let curr_ch_id = format!("{mat_exid}_tl{index}");
|
||||||
let extra_topplist = StreamDataExtra::new(&curr_ch_id, &extl, chunk.to_vec());
|
let extra_topplist = StreamDataExtra::new(&curr_ch_id, &extl, chunk.to_vec());
|
||||||
if let Some(err) = tx.send(extra_topplist.as_msg()).await.err() {
|
if let Some(err) = tx
|
||||||
|
.send(TxControlMessage::Payload(extra_topplist.as_msg()))
|
||||||
|
.await
|
||||||
|
.err()
|
||||||
|
{
|
||||||
println!("ERR: send tx extra2 error: {err:?}");
|
println!("ERR: send tx extra2 error: {err:?}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -788,7 +877,11 @@ async fn throttle_send_recipe(
|
||||||
for (index, chunk) in recipe.Topping.ToppingGroup.chunks(CHUNK_SIZE).enumerate() {
|
for (index, chunk) in recipe.Topping.ToppingGroup.chunks(CHUNK_SIZE).enumerate() {
|
||||||
let curr_ch_id = format!("{mat_exid}_tg{index}");
|
let curr_ch_id = format!("{mat_exid}_tg{index}");
|
||||||
let extra_toppgrp = StreamDataExtra::new(&curr_ch_id, &extg, chunk.to_vec());
|
let extra_toppgrp = StreamDataExtra::new(&curr_ch_id, &extg, chunk.to_vec());
|
||||||
if let Some(err) = tx.send(extra_toppgrp.as_msg()).await.err() {
|
if let Some(err) = tx
|
||||||
|
.send(TxControlMessage::Payload(extra_toppgrp.as_msg()))
|
||||||
|
.await
|
||||||
|
.err()
|
||||||
|
{
|
||||||
println!("ERR: send tx extra2 error: {err:?}");
|
println!("ERR: send tx extra2 error: {err:?}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -803,15 +896,32 @@ async fn throttle_send_recipe(
|
||||||
// return sid;
|
// return sid;
|
||||||
let end_msg = StreamDataEnd::new(&sid);
|
let end_msg = StreamDataEnd::new(&sid);
|
||||||
|
|
||||||
if let Some(err) = tx.send(end_msg.as_msg()).await.err() {
|
if let Some(err) = tx
|
||||||
|
.send(TxControlMessage::Payload(end_msg.as_msg()))
|
||||||
|
.await
|
||||||
|
.err()
|
||||||
|
{
|
||||||
println!("ERR: send tx error, {err:?}");
|
println!("ERR: send tx error, {err:?}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Hub {
|
||||||
|
pub clients: HashMap<String, Sender<TxControlMessage>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum TxControlMessage {
|
||||||
|
Payload(serde_json::Value),
|
||||||
|
CloseExist,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
dev_config: DevConfig,
|
dev_config: DevConfig,
|
||||||
redis_cli: redis::Client,
|
redis_cli: redis::Client,
|
||||||
system_tx: tokio::sync::broadcast::Sender<serde_json::Value>,
|
system_tx: tokio::sync::broadcast::Sender<serde_json::Value>,
|
||||||
|
// saved client uid:client uuid
|
||||||
|
connectors_mapping: Arc<Mutex<Hub>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
|
|
@ -831,6 +941,9 @@ impl AppState {
|
||||||
dev_config,
|
dev_config,
|
||||||
redis_cli,
|
redis_cli,
|
||||||
system_tx,
|
system_tx,
|
||||||
|
connectors_mapping: Arc::new(Mutex::new(Hub {
|
||||||
|
clients: HashMap::new(),
|
||||||
|
})),
|
||||||
});
|
});
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
|
@ -988,6 +1101,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
// .route("/sessionLogin", post(session_login))
|
// .route("/sessionLogin", post(session_login))
|
||||||
.route("/syscb", post(post_from_other_system))
|
.route("/syscb", post(post_from_other_system))
|
||||||
|
.route("/regas", post(request_api_session_key))
|
||||||
.nest("/recipe", rp_router)
|
.nest("/recipe", rp_router)
|
||||||
.nest("/docs", doc_router)
|
.nest("/docs", doc_router)
|
||||||
.with_state(app_state);
|
.with_state(app_state);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue