feat: change uid check

- change the uid check due to a limitation of header-based auth
- refactor code

Signed-off-by: Pakin <pakin.t@forth.co.th>
This commit is contained in:
Pakin 2026-04-28 16:43:22 +07:00
parent 819bd08bc3
commit da956d39a7
16 changed files with 1398 additions and 1451 deletions

276
src/app.rs Normal file
View file

@ -0,0 +1,276 @@
use crate::websocket::{core::*, model::*};
use axum::{
Router,
routing::{get, post},
serve::ListenerExt,
};
use log::{error, info};
use redis::TypedCommands;
use std::{
collections::{HashMap, VecDeque},
env,
sync::Arc,
};
use tokio::sync::{Mutex, mpsc::Sender};
/// Registry of live websocket connections.
#[derive(Clone)]
pub struct Hub {
    /// Sender half of each connection's control channel, keyed by client uid
    /// (or a `temp-<uuid>` placeholder until the connection authenticates).
    pub clients: HashMap<String, Sender<TxControlMessage>>,
}
/// Environment-derived configuration for the dev API endpoints.
#[derive(Clone)]
pub struct DevConfig {
    /// Value sent in the `X-API-Key` header.
    pub api_key: String,
    /// Base domain of the dev API.
    pub api_domain: String,
    /// Path of the recipe service under `api_domain`.
    pub api_recipe_service: String,
    /// Full `redis://host:port` connection URL.
    pub api_redis_url: String,
    /// Base URL of the resolver service.
    pub api_resolver: String,
}
impl DevConfig {
    /// Assemble a configuration from its raw parts.
    pub fn new(
        key: String,
        domain: String,
        rp_service: String,
        api_redis_url: String,
        api_resolver: String,
    ) -> DevConfig {
        DevConfig {
            api_key: key,
            api_domain: domain,
            api_recipe_service: rp_service,
            api_redis_url,
            api_resolver,
        }
    }

    /// Base URL of the recipe service (domain followed by the service path).
    pub fn get_recipe_url(&self) -> String {
        let mut url = self.api_domain.clone();
        url.push_str(&self.api_recipe_service);
        url
    }

    /// Checkout URL for a single file inside the recipe repository.
    pub fn get_file_from_recipe_repo(&self, path: String) -> String {
        format!("{}/checkout?path={}", self.get_recipe_url(), path)
    }

    /// `(header-name, header-value)` pair used to authenticate API calls.
    pub fn get_api_header(&self) -> (String, String) {
        (String::from("X-API-Key"), self.api_key.clone())
    }

    /// Resolve endpoint of the yuki resolver service.
    pub fn get_yuki_resolver(&self) -> String {
        format!("{}/resolve", self.api_resolver)
    }
}
/// Shared application state handed to every handler via `Arc<AppState>`.
pub struct AppState {
    /// Immutable configuration read at startup.
    pub dev_config: DevConfig,
    /// Redis client cloned by background tasks as needed.
    pub redis_cli: redis::Client,
    /// Broadcast channel fanning system/command messages out to all tasks.
    pub system_tx: tokio::sync::broadcast::Sender<serde_json::Value>,
    // saved client uid:client uuid
    pub connectors_mapping: Arc<Mutex<Hub>>,
}
impl AppState {
pub fn get_cfg(&self) -> DevConfig {
self.dev_config.clone()
}
pub async fn new(
dev_config: DevConfig,
redis_cli: redis::Client,
system_tx: tokio::sync::broadcast::Sender<serde_json::Value>,
mut system_rx: tokio::sync::broadcast::Receiver<serde_json::Value>,
) -> Arc<AppState> {
let redis_cli_clone = redis_cli.clone();
let tx_new = system_tx.clone();
let result = Arc::new(AppState {
dev_config,
redis_cli,
system_tx,
connectors_mapping: Arc::new(Mutex::new(Hub {
clients: HashMap::new(),
})),
});
tokio::spawn(async move {
let mut lredis = redis_cli_clone.clone();
let current_queue: crossbeam_queue::ArrayQueue<CommandRequestPayload> =
crossbeam_queue::ArrayQueue::new(1);
let mut pending_command: VecDeque<CommandRequestPayload> = VecDeque::new();
let mut check_available_path = String::new();
loop {
if let Ok(rmsg) = system_rx.recv().await {
let sys_msg = crate::websocket::helper::convert_sys_msg_command(&rmsg);
// add queue process
let command_req: CommandRequestPayload = match serde_json::from_value(rmsg) {
Ok(cmd) => cmd,
Err(e) => {
if sys_msg.is_none() {
// maybe error
error!("error deserialize: {e:?} ---> Skip");
}
continue;
} // reject
};
info!("get cmd: {}", command_req.srv_name);
if let Err(fail_payload) = current_queue.push(command_req.clone()) {
if pending_command.len() < 10 {
pending_command.push_back(fail_payload)
} else {
let user_name = fail_payload
.user_info
.get("displayName")
.unwrap_or_default();
let _ = tx_new.send(serde_json::json!({
"type": "notify",
"payload": {
"from": "system_tx",
"msg": "request queue full, try again later",
"level": "ERR",
"to": user_name,
}
}));
}
} else {
// set check to latest push to queue ok
check_available_path = format!("{}/status", command_req.srv_name);
info!("checking {check_available_path}");
}
}
// send process
if let Ok(Some(status)) = lredis.get(&check_available_path) {
info!("status: {status}");
match status.as_str() {
"ok" | "OK" | "Ok" => {
info!("queue: {}", current_queue.len());
//
if current_queue.is_full()
&& let Some(cmd) = current_queue.pop()
{
// get one
let channel = format!("{}/job", cmd.srv_name);
info!("channel job: {channel}");
info!("job: {cmd:?}");
let prep = serde_json::json!({
"type": "command",
"payload": cmd
});
let result_pub = lredis.publish(
channel,
serde_json::to_string(&prep).unwrap_or("{}".to_string()),
);
info!("published: {result_pub:?}");
// queue next
if let Some(next_cmd) = pending_command.pop_front() {
check_available_path = format!("{}/status", next_cmd.srv_name);
// ignore result
let _ = current_queue.push(next_cmd);
} else {
check_available_path = String::new();
}
} else if current_queue.is_empty() {
check_available_path = String::new();
}
}
_ => {}
}
} else if current_queue.is_empty()
&& let Some(next_cmd) = pending_command.pop_front()
{
// case empty queue, fetch next
check_available_path = format!("{}/status", next_cmd.srv_name);
// ignore result
let _ = current_queue.push(next_cmd);
}
}
});
result
}
}
/// Fetch the raw body of `path` from the recipe repository's checkout
/// endpoint; any transport error is stringified into the boxed error.
pub async fn invoke_checkout_request(
    config: DevConfig,
    path: String,
) -> Result<String, Box<dyn std::error::Error>> {
    let request_url = config.get_file_from_recipe_repo(path);
    // println!("dbg: {request_url}");
    let response = reqwest::Client::new().get(request_url).send().await?;
    response.text().await.map_err(|e| format!("{e}").into())
}
/// Router for the recipe subsystem: currently a single `/ws` upgrade
/// endpoint; state is injected when nested into the parent router.
pub async fn create_recipe_repo_router() -> Router<Arc<AppState>> {
    let ws_route = get(crate::websocket::handler::websocket_handler);
    Router::new().route("/ws", ws_route)
}
/// Read configuration from the environment, build the shared state and serve
/// the app on `0.0.0.0:<PORT>` (default 36579).
///
/// Required env vars (startup panics without them, by design):
/// `DEV_API_KEY`, `DEV_API_DOMAIN`, `DEV_API_RECIPE_SERVICE`,
/// `RESOLVER_SERVICE_URL`. Optional with defaults: `PORT`, `DEV_API_REDIS`,
/// `DEV_API_REDIS_PORT`.
pub async fn initialize() -> Result<(), Box<dyn std::error::Error>> {
    let server_port = env::var("PORT").unwrap_or("36579".to_string());
    let api_key = env::var("DEV_API_KEY").expect("no api key");
    let api_domain = env::var("DEV_API_DOMAIN").expect("no domain");
    let api_recipe_service = env::var("DEV_API_RECIPE_SERVICE").expect("no service");
    let api_redis = env::var("DEV_API_REDIS").unwrap_or("0.0.0.0".to_string());
    let api_redis_port = env::var("DEV_API_REDIS_PORT").unwrap_or("6379".to_string());
    let api_resolver = env::var("RESOLVER_SERVICE_URL").expect("no available resolver");
    let dev_cfg = crate::app::DevConfig::new(
        api_key,
        api_domain,
        api_recipe_service,
        format!("redis://{api_redis}:{api_redis_port}"),
        api_resolver,
    );
    // test_send(dev_cfg).await?;
    //
    let redis_cli = redis::Client::open(dev_cfg.api_redis_url.clone())?;
    // broadcast channel shared by all websocket sessions and the scheduler
    let (sys_tx, sys_rx) = tokio::sync::broadcast::channel::<serde_json::Value>(16);
    let app_state = AppState::new(dev_cfg, redis_cli, sys_tx, sys_rx).await;
    let rp_router = create_recipe_repo_router().await;
    // let doc_router = create_tx_patcher_route().await;
    let app = Router::new()
        // .route("/sessionLogin", post(session_login))
        .route(
            "/syscb",
            post(crate::websocket::handler::post_from_other_system),
        )
        // .route("/regas", post(request_api_session_key))
        .nest("/recipe", rp_router)
        // .nest("/docs", doc_router)
        .with_state(app_state);
    // feature: no delay, full throttle
    let nodelay_listener = || async {
        tokio::net::TcpListener::bind(format!("0.0.0.0:{server_port}"))
            .await
            .unwrap()
            .tap_io(|tcp_stream| {
                // TCP_NODELAY is best-effort; a failure only costs latency
                if let Err(err) = tcp_stream.set_nodelay(true) {
                    error!("failed to set TCP_NODELAY on incoming connection: {err:#?}");
                }
            })
    };
    axum::serve(nodelay_listener().await, app).await?;
    Ok(())
}

View file

@ -1,345 +0,0 @@
use std::{collections::HashMap, fs::File};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use celes::Country;
use redis::TypedCommands;
use crate::AppState;
/// Country metadata record cached in Redis (`<cc>.info`) and mirrored to a
/// local `.info.<cc>.json` file. Field names match the serialized JSON keys,
/// hence the non-snake-case allowance.
#[allow(non_snake_case)]
#[derive(Serialize, Deserialize, Clone)]
pub struct CountryInfo {
    /// Logo asset path for the country plate.
    image: String,
    Brand: String,
    Country: String,
    VendingClass: String,
    Machinecompatible: String,
    MateriallistProfile: Vec<CountryInfoProfileDetail>,
    /// Any extra keys in the source JSON are preserved on round-trip.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
/// One material-list profile entry referenced from `CountryInfo`.
#[derive(Serialize, Deserialize, Clone)]
pub struct CountryInfoProfileDetail {
    /// Profile JSON file name.
    json: String,
    /// Preview image file name.
    img: String,
    /// Free-text description (may be empty).
    desc: String,
    /// Unmodeled keys preserved on round-trip.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
impl CountryInfo {
    /// Build a default info record for `country_code` (ISO alpha-3, or the
    /// special value `"dubai"`).
    ///
    /// Unknown codes fall back to Thailand; a missing brand becomes "".
    pub fn new(country_code: String, brand: Option<String>) -> CountryInfo {
        // FIX: no need to clone just to take a reference.
        // "dubai" is not a valid alpha-3 code, so it is special-cased to the
        // UAE; anything else unresolvable maps to Thailand.
        let country = match Country::from_alpha3(&country_code) {
            Ok(c) => c,
            Err(_) => {
                if country_code.eq("dubai") {
                    Country::the_united_arab_emirates()
                } else {
                    Country::thailand()
                }
            }
        };
        CountryInfo {
            image: format!("taobin_project/logo/{country_code}_plate.png"),
            Brand: brand.unwrap_or_default(),
            Country: country.long_name.to_string(),
            VendingClass: String::from("coffeethai02"),
            Machinecompatible: String::from("GEN2 and GEN32"),
            MateriallistProfile: vec![CountryInfoProfileDetail {
                json: String::from("vending_setting_and_profile_v1.json"),
                img: String::from("vending_setting_and_profile_v1.png"),
                desc: String::new(),
                extra: HashMap::default(),
            }],
            extra: HashMap::default(),
        }
    }
}
/// Fetch the repository root listing, cache it in Redis under `root_repo`,
/// and kick off the per-country discovery pipeline in the background.
///
/// Returns the raw root listing (or `Null` when the fetch/parse fails);
/// discovery itself continues on a spawned task after this returns.
async fn get_root_files(state: AppState) -> Result<Value, Box<dyn std::error::Error>> {
    let api_header = state.get_cfg().get_api_header();
    let mut ret_result = serde_json::Value::Null;
    let client = reqwest::Client::new();
    // NOTE(review): host is hard-coded; presumably this should come from
    // `state.get_cfg().get_file_from_recipe_repo(..)` — confirm
    let res = client
        .get("http://localhost:36584/checkout?path=")
        .header(api_header.0, api_header.1)
        .send()
        .await;
    match res {
        Ok(res) => {
            // only JSON responses count as a valid listing
            if let Some(ct) = res.headers().get("content-type")
                && ct.eq("application/json")
            {
                let raw = res.text().await;
                if let Ok(raw) = raw {
                    let result: serde_json::Value =
                        serde_json::from_str(&raw).unwrap_or(serde_json::Value::Null);
                    let mut redis_client = state.clone().redis_cli;
                    let _ = redis_client.set("root_repo", result.to_string());
                    ret_result = result.clone();
                    println!("setup next");
                    tokio::spawn(async move {
                        let s1 = setup_after_get_root(state.clone(), result.clone())
                            .await
                            .ok();
                        println!("checkpoint 1: {}", s1.is_some());
                        if let Some((country_with_version, country_mapping)) = s1 {
                            println!("entries: {}", country_with_version.len());
                            let _ = get_all_file_path_of_legit_country(
                                state.clone(),
                                country_with_version,
                                country_mapping,
                            )
                            .await;
                        }
                    });
                }
            }
        }
        Err(e) => {
            println!("Error on root fetch: {e}");
        }
    }
    Ok(ret_result)
}
/// Probe every root folder for a `<folder>/version` file; folders that answer
/// with 200 + JSON are treated as valid countries.
///
/// Returns the valid folders plus a folder → version map, and caches each
/// version in Redis under `<folder>.version`.
async fn setup_after_get_root(
    state: AppState,
    roots: serde_json::Value,
) -> Result<(Vec<String>, HashMap<String, String>), Box<dyn std::error::Error>> {
    let mut legit_country_with_version = Vec::new();
    let mut country_version_mapping = HashMap::new();
    if let Some(map) = roots.as_object()
        && let Some(res) = map.get("result")
    {
        let fds: Vec<String> = res
            .as_array()
            .unwrap_or(&Vec::new())
            .iter()
            .map(|x| x.as_str().unwrap_or("").to_string())
            .collect();
        println!("pre_loop: {fds:?}");
        // TODO: build in pattern `<country_name>/version`
        // if get response ok, save
        // NOTE: filter country
        let api_header = state.get_cfg().get_api_header();
        for fd in fds {
            println!("checking {fd}");
            // try GET
            let client = reqwest::Client::new();
            let res = client
                .get(format!("http://localhost:36584/checkout?path={fd}/version"))
                .header(api_header.clone().0, api_header.clone().1)
                .send()
                .await;
            if let Ok(r) = res
                && let Some(ct) = r.headers().get("content-type")
                && r.status().eq(&reqwest::StatusCode::OK)
                && ct.eq("application/json")
                && let Ok(txt) = r.text().await
            {
                println!("{fd}.version = {txt}");
                // FIX: a malformed version payload used to panic via
                // `unwrap()`; skip the folder instead
                let vres: HashMap<String, String> = match serde_json::from_str(&txt) {
                    Ok(v) => v,
                    Err(e) => {
                        println!("bad version payload for {fd}: {e} ---> Skip");
                        continue;
                    }
                };
                let vv = vres
                    .get("result")
                    .map(|x| x.to_string())
                    .unwrap_or("".to_string());
                // get version of country
                let mut rcli = state.clone().redis_cli;
                let _ = rcli.set(format!("{fd}.version"), vv.clone());
                // generate all file paths
                legit_country_with_version.push(fd.clone());
                country_version_mapping.insert(fd.clone(), vv.clone());
            }
        }
    }
    Ok((legit_country_with_version, country_version_mapping))
}
/// For every validated country folder: cache its file listing in Redis,
/// announce it on `recipe_files_by_country`, ensure a `.info` record exists
/// (generated or fetched), and cache the latest-version recipe file.
async fn get_all_file_path_of_legit_country(
    state: AppState,
    legit_countries: Vec<String>,
    country_mapping: HashMap<String, String>,
) -> Result<(), Box<dyn std::error::Error>> {
    let api_header = state.get_cfg().get_api_header();
    // save all entries of each country
    for country in legit_countries {
        let client = reqwest::Client::new();
        // NOTE(review): host is hard-coded; presumably should come from
        // DevConfig like fetch_country_info does — confirm
        let res = client
            .get(format!("http://localhost:36584/checkout?path={country}"))
            .header(api_header.clone().0, api_header.clone().1)
            .send()
            .await;
        if let Ok(r) = res
            && let Some(ct) = r.headers().get("content-type")
            && r.status().eq(&reqwest::StatusCode::OK)
            && ct.eq("application/json")
            && let Ok(txt) = r.text().await
        {
            // get version of country & persist save
            let mut rcli = state.clone().redis_cli;
            let _ = rcli.set(country.clone(), txt.clone());
            // generate all file paths
            println!("{country} ready!");
            let files: HashMap<String, Vec<String>> =
                serde_json::from_str(&txt.clone()).unwrap_or(HashMap::new());
            // stream content
            let _ = rcli.publish(
                "recipe_files_by_country",
                json!({country.clone() : files}).to_string(),
            );
            if let Some(fl) = files.get("result") {
                let has_info = fl.contains(&".info.json".to_string());
                println!("{country} has info: {has_info}");
                // read version
                let current_latest_version = country_mapping
                    .get(&country)
                    .map(|x| x.to_string())
                    .unwrap_or("unknown".to_string());
                // files whose name mentions the latest version string
                let latest_version_file: Vec<String> = fl
                    .iter()
                    .filter(|x| x.contains(&current_latest_version))
                    .map(|x| x.to_string())
                    .collect();
                if !has_info {
                    // generate info for country
                    let _ = generate_country_info_default(state.clone(), country.clone()).await;
                } else {
                    let _ = fetch_country_info(state.clone(), country.clone()).await;
                }
                // do fetch latest version into redis
                if let Some(single) = latest_version_file.first() {
                    let res_c = client
                        .get(format!(
                            "http://localhost:36584/checkout?path={}/{single}",
                            country.clone()
                        ))
                        .header(api_header.clone().0, api_header.clone().1)
                        .send()
                        .await;
                    if let Ok(latest_raw) = res_c
                        && let Ok(latest_raw_txt) = latest_raw.text().await
                    {
                        println!("cached {single}");
                        let _ = rcli.set(
                            format!("{}/{}", country.clone(), single.clone()),
                            latest_raw_txt,
                        );
                    }
                }
            }
        }
    }
    Ok(())
}
/// Create a default `.info` record for a country with no `.info.json` in the
/// repo: cache it in Redis (`<cc>.info`) and write `.info.<cc>.json` locally.
async fn generate_country_info_default(
    state: AppState,
    cc: String,
) -> Result<(), Box<dyn std::error::Error>> {
    // brand is keyed off the country/market code
    let country_info = match cc.as_str() {
        "sgp" | "dubai" => CountryInfo::new(cc.clone(), Some("WhatTheCup".to_string())),
        "gbr" | "aus" | "hkg" | "rou" | "lva" | "est" | "etu" => {
            CountryInfo::new(cc.clone(), Some("Flying Turtle".to_string()))
        }
        _ => CountryInfo::new(cc.clone(), Some("Taobin".to_string())),
    };
    // save country info
    let mut rcli = state.clone().redis_cli;
    let _ = rcli.set(format!("{cc}.info"), serde_json::to_string(&country_info)?);
    // save local copy; round-trip through Value to keep the on-disk output
    // identical to what the previous implementation produced
    let json = serde_json::to_string(&country_info)?;
    let json2: serde_json::Value = serde_json::from_str(&json)?;
    // FIX: propagate file-creation errors instead of panicking via `unwrap()`
    let writer = File::create(format!(".info.{cc}.json"))?;
    let _ = serde_json::to_writer_pretty(writer, &json2);
    Ok(())
}
/// Fetch `<cc>/.info.json` from the recipe repo, cache it in Redis and mirror
/// it to a local `.info.<cc>.json` file.
///
/// A payload that fails to deserialize falls back to a default record; a
/// non-200/non-JSON response is silently ignored.
async fn fetch_country_info(state: AppState, cc: String) -> Result<(), Box<dyn std::error::Error>> {
    let api_header = state.get_cfg().get_api_header();
    let client = reqwest::Client::new();
    let res = client
        .get(
            state
                .get_cfg()
                .get_file_from_recipe_repo(format!("{cc}/.info.json")),
        )
        .header(api_header.clone().0, api_header.clone().1)
        .send()
        .await;
    // only a 200 JSON response is treated as a hit
    if let Ok(r) = res
        && let Some(ct) = r.headers().get("content-type")
        && r.status().eq(&reqwest::StatusCode::OK)
        && ct.eq("application/json")
        && let Ok(txt) = r.text().await
    {
        let mut rcli = state.clone().redis_cli;
        let info: CountryInfo =
            serde_json::from_str(&txt.clone()).unwrap_or(CountryInfo::new(cc.clone(), None));
        let _ = rcli.set(format!("{cc}.info"), serde_json::to_string(&info)?);
        let json = serde_json::to_string(&info)?;
        let json2: serde_json::Value = serde_json::from_str(&json)?;
        // FIX: propagate file-creation errors instead of panicking via `unwrap()`
        let writer = File::create(format!(".info.{cc}.json"))?;
        let _ = serde_json::to_writer_pretty(writer, &json2);
    }
    Ok(())
}
/// Run the root-repository fetch on its own task and wait for it to finish,
/// logging the outcome either way.
pub async fn cold_start_process(state: AppState) -> Result<(), Box<dyn std::error::Error>> {
    let task_state = state.clone();
    println!("starting cold process");
    let handle = tokio::spawn(async move {
        match get_root_files(task_state).await {
            Ok(res) => println!("cold start ok, {}", res),
            Err(e) => println!("cold start error: {e}"),
        }
    });
    let _ = handle.await;
    Ok(())
}

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,3 @@
use rayon::iter::Either;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
@ -62,6 +61,7 @@ pub struct StreamDataChunk<T> {
pub start_idx: usize,
/// Chunked data, split into N items per chunk
pub data: Vec<T>,
uid: String,
}
impl<T> IntoStreamMessage for StreamDataChunk<T>
@ -86,11 +86,12 @@ impl<T> StreamDataChunk<T>
where
T: Serialize,
{
pub fn new(sid: &str, start_idx: usize, data: Vec<T>) -> Self {
pub fn new(sid: &str, start_idx: usize, data: Vec<T>, uid: String) -> Self {
Self {
stream_id: sid.to_string(),
start_idx,
data,
uid,
}
}

22
src/websocket/core.rs Normal file
View file

@ -0,0 +1,22 @@
use std::time::Duration;
/// CONFIG: chunk size for each payload
///
/// note: used when streaming recipes to the client
pub const CHUNK_SIZE: usize = 5;
/// CONFIG: default timeout for each socket connection (15 minutes)
pub const TIMEOUT: Duration = Duration::from_secs(60 * 15);
/// Control messages sent over a connection's internal mpsc channel to the
/// writer task.
#[derive(Clone)]
pub enum TxControlMessage {
    /// Forward this JSON value to the client (subject to the writer's
    /// addressing rules).
    Payload(serde_json::Value),
    /// Close the socket because a newer connection replaced this one.
    CloseExist,
}
/// Authentication state of a websocket session.
// NOTE(review): variants use SCREAMING_CASE rather than idiomatic CamelCase;
// renaming would break existing matches, so it is only flagged here.
pub enum UserWebSocketAuthState {
    UNAUTHORIZED,
    AUTHORIZED,
}
/// Common result alias for websocket task handlers.
pub type WebsocketMessageResult = Result<(), Box<dyn std::error::Error + Send + Sync>>;

141
src/websocket/handler.rs Normal file
View file

@ -0,0 +1,141 @@
use axum::{
Json,
extract::{State, WebSocketUpgrade, ws::WebSocket},
response::IntoResponse,
};
use futures::StreamExt;
use log::{info, warn};
use redis::TypedCommands;
use std::sync::Arc;
use tokio::{sync::Mutex, sync::mpsc, time::Instant};
use uuid::Uuid;
use super::{core::*, model::*};
use crate::app::{AppState, Hub};
/// HTTP callback (`/syscb`) allowing other services to inject a `SysMessage`
/// into the internal broadcast channel.
pub async fn post_from_other_system(
    State(state): State<Arc<AppState>>,
    Json(msg): Json<SysMessage>,
) -> impl IntoResponse {
    // re-serialize the message so it can travel over the broadcast channel
    let Ok(sys_payload) = serde_json::to_value(&msg) else {
        return (
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            "cannot create payload",
        )
            .into_response();
    };
    if state.system_tx.send(sys_payload).is_ok() {
        (axum::http::StatusCode::OK, "").into_response()
    } else {
        (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "send fail").into_response()
    }
}
pub async fn request_api_session_key(
State(state): State<Arc<AppState>>,
Json(msg): Json<ApiSessionRequest>,
) -> impl IntoResponse {
let mut rcl = state.redis_cli.clone();
// gen key
let generated = uuid::Uuid::new_v4().to_string().replace("-", "");
if let Some(uid_n) = msg.uid.as_str() {
let saved_key = format!("{uid_n}-ak");
let _ = rcl.set_ex(saved_key, generated.clone(), 3600);
}
(axum::http::StatusCode::OK, generated).into_response()
}
/// Main websocket handler
///
/// Upgrades the HTTP request to a websocket and hands the socket to
/// `handle_socket`. The former header-based uid check is kept below,
/// commented out: the uid is now bound later via an `auth` message
/// (see `tasks::auth`).
pub async fn websocket_handler(
    State(state): State<Arc<AppState>>,
    ws: WebSocketUpgrade,
) -> impl IntoResponse {
    let state_clone = Arc::clone(&state);
    let hub_clone = Arc::clone(&state_clone.connectors_mapping);
    // former header-based uid extraction (x-auth-uid), superseded by the
    // delayed-auth flow:
    // let mut uid_n = String::new();
    // if let Some(uid) = headers.get("x-auth-uid") {
    //     let uid_from_web = uid.to_str().unwrap_or_default().to_string();
    //     uid_n = uid_from_web;
    //     info!("user connect {uid_n}");
    // }
    // if uid_n.is_empty() {
    //     return (axum::http::StatusCode::BAD_REQUEST, "").into_response();
    // }
    ws.on_failed_upgrade(|error| println!("Error upgrading websocket: {}", error))
        .on_upgrade(async |s| handle_socket(s, state_clone, hub_clone).await.unwrap_or(()))
}
/// Per-connection setup: splits the socket, registers it in the hub under a
/// temporary identity, and runs the reader/writer/watchdog tasks until one
/// of them finishes.
///
/// The session starts unauthenticated as `temp-<uuid>`; the `auth` message
/// handler later swaps that hub key for the real uid.
async fn handle_socket(
    socket: WebSocket,
    state: Arc<AppState>,
    hub: Arc<Mutex<Hub>>,
) -> Result<(), Box<dyn std::error::Error>> {
    let (sender, receiver) = socket.split();
    // internal channel
    let (tx, rx) = mpsc::channel::<TxControlMessage>(2);
    // TODO: change auth method from header to delay auth message timeout within 5 secs
    //
    // {
    //     // register & kick old tx
    //     let mut h = hub.try_lock().unwrap();
    //     if let Some(old_tx) = h.clients.insert(uid_auth_session.to_string(), tx.clone()) {
    //         warn!("disconnect old connection");
    //         let _ = old_tx.send(TxControlMessage::CloseExist);
    //     }
    // }
    // spawn as unknown identity until an auth message arrives
    // FIX: dropped the redundant `String::from(format!(..))` wrapper and the
    // redundant `Uuid::to_string()` inside `format!`
    let user = Arc::new(Mutex::new(format!("temp-{}", Uuid::new_v4())));
    let temp_session = user.try_lock().unwrap().to_string();
    info!("{} connected", temp_session);
    {
        let mut h = hub.try_lock().unwrap();
        h.clients.insert(temp_session.clone(), tx.clone());
    }
    let user_sys_rx = state.system_tx.subscribe();
    // shared liveness timestamp, refreshed by heartbeat/ping frames
    let last_seen = Arc::new(Mutex::new(Instant::now()));
    let reader_last_seen = last_seen.clone();
    let watchdog_last_seen = last_seen.clone();
    let sender = tokio::spawn(super::rw::write(sender, rx, user.clone()));
    let reader = tokio::spawn(super::rw::read(
        state,
        receiver,
        tx.clone(),
        user_sys_rx,
        reader_last_seen,
        user.clone(),
        hub.clone(),
    ));
    let watchdog = super::tasks::watchdog::get_watchdog_task(
        tx,
        watchdog_last_seen,
        user.clone(),
        hub.clone(),
    );
    let _ = tokio::join!(reader, sender, watchdog);
    Ok(())
}

66
src/websocket/helper.rs Normal file
View file

@ -0,0 +1,66 @@
use super::model::*;
use axum::extract::ws::{CloseFrame, Message, WebSocket};
use redis::{TypedCommands, cmd};
/// Best-effort close: send a close frame with the given code/reason and
/// ignore the result (the peer may already be gone).
#[deprecated]
pub async fn send_close_message(mut socket: WebSocket, code: u16, reason: &str) {
    let frame = CloseFrame {
        code,
        reason: reason.into(),
    };
    let _ = socket.send(Message::Close(Some(frame))).await;
}
/// Look up `key` on a cloned Redis client; every failure mode (missing key or
/// transport error) is reported as a `String`.
#[deprecated]
pub async fn fetch_content_from_redis(redis: redis::Client, key: &str) -> Result<String, String> {
    let mut rcli = redis.clone();
    match rcli.get(key) {
        Ok(Some(res)) => Ok(res),
        Ok(None) => Err(format!("result error from key: {key}")),
        Err(e) => Err(format!("redis get failed: {e}")),
    }
}
/// Fetch the raw bytes stored under `key`.
///
/// NOTE: failures are deliberately swallowed — connection or GET errors are
/// printed and an empty buffer is returned; this function never yields `Err`.
pub async fn fetch_content_from_redis_byte(
    redis: redis::Client,
    key: &str,
) -> Result<Vec<u8>, String> {
    let mut conn = match redis.get_connection() {
        Ok(cnn) => cnn,
        Err(e) => {
            println!("get connection fail, {e}");
            return Ok(vec![]);
        }
    };
    match cmd("GET").arg(key).query::<Vec<u8>>(&mut conn) {
        Ok(res) => Ok(res),
        Err(e) => {
            println!("get fail, {e}");
            Ok(vec![])
        }
    }
}
/// Try to interpret a JSON value as a `CommandRequestPayload`; a failed
/// deserialization simply yields `None`.
#[deprecated]
pub fn convert_ack_command(cmd_req: &serde_json::Value) -> Option<CommandRequestPayload> {
    serde_json::from_value(cmd_req.clone()).ok()
}
/// Try to interpret a JSON value as a `SysMessage`; a failed deserialization
/// simply yields `None`.
pub fn convert_sys_msg_command(msg: &serde_json::Value) -> Option<SysMessage> {
    serde_json::from_value(msg.clone()).ok()
}

6
src/websocket/mod.rs Normal file
View file

@ -0,0 +1,6 @@
/// Shared constants and core channel/message types.
pub mod core;
/// HTTP/websocket entry points (upgrade handler, system callback).
pub mod handler;
/// Small helpers shared by the websocket tasks.
pub mod helper;
/// Serde models for websocket requests and payloads.
pub mod model;
/// Reader/writer halves of an upgraded socket.
mod rw;
/// Per-message-type task handlers (auth, command, recipe, ...).
mod tasks;

86
src/websocket/model.rs Normal file
View file

@ -0,0 +1,86 @@
use serde::{Deserialize, Serialize};
/// system message to send back to client, this may be called from other services
#[derive(Debug, Serialize, Deserialize)]
pub struct SysMessage {
    /// Message discriminator (serialized as `type`).
    #[serde(rename = "type")]
    pub stype: String,
    /// Arbitrary message body; its shape depends on `stype`.
    pub payload: serde_json::Value,
}
/// Request to generate api rotated key
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiSessionRequest {
    // uid from user login
    // NOTE(review): kept as a raw JSON value — handlers only use it when it
    // is a string; confirm whether non-string uids are expected
    pub uid: serde_json::Value,
    // request timestamp; not validated here
    pub timestamp: serde_json::Value,
}
/// General message struct from websocket request
#[derive(Serialize, Deserialize, Clone)]
pub struct WebsocketMessageRequest {
    /// Message discriminator (serialized as `type`), e.g. "recipe",
    /// "command", "heartbeat", "auth".
    #[serde(rename = "type")]
    pub type_w: String,
    /// Type-specific body; most handlers require it to be present.
    pub payload: Option<serde_json::Value>,
}
/// Recipe request payload struct
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RecipeRequestPayload {
    /// For validate request is acceptable
    pub auth: Option<String>,
    /// Only grep partial of file, will be sent with json patch
    pub partial: Option<bool>,
    /// Country of recipe
    pub country: String,
    /// version of recipe
    pub version: i64,
    /// Extended infos, required parameters or unimplemented fields in the current struct. Expected pattern `<key1>=<val1>,<key2>=<val2>,...`
    pub parameters: Option<String>,
}
/// Command request for external services
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CommandRequestPayload {
    /// User info expect at least id, token, name
    pub user_info: serde_json::Value,
    /// Target service
    pub srv_name: String,
    /// Values
    pub values: serde_json::Value,
}
/// For logging user's action
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LogReportPayload {
    // expect either `email` or `unknown`
    pub user: String,
    /// Free-form action label.
    pub action: String,
    // expect either country name or `unknown dep`
    pub country: String,
    /// Action-specific details.
    pub values: serde_json::Value,
}
/// Message for saving recipe
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SaveRecipePayload {
    pub user: String,
    pub country: String,
    /// Recipe content to persist.
    pub values: serde_json::Value,
}
/// Message for authentication before use m2 service
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct AuthPayload {
    pub user: AuthUserField,
}
/// Internal field for auth payload
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct AuthUserField {
    /// Authenticated user id; an empty string is treated as "not authenticated".
    pub uid: String,
    pub name: String,
    pub email: String,
    pub permissions: String,
}

216
src/websocket/rw.rs Normal file
View file

@ -0,0 +1,216 @@
use super::{core::*, helper::*, model::*};
use crate::{app::*, websocket::tasks};
use std::sync::Arc;
use axum::extract::ws::{Message, WebSocket};
use futures::{
SinkExt, StreamExt,
stream::{SplitSink, SplitStream},
};
use log::{error, info, warn};
use tokio::{
sync::{
Mutex,
mpsc::{Receiver, Sender},
},
time::Instant,
};
/// Reader half of the websocket: consumes client frames and dispatches them
/// to the per-message-type task handlers.
///
/// Also spawns a relay task that forwards `SysMessage`-shaped values from the
/// `system_rx` broadcast back to this client's writer via `tx`.
pub async fn read(
    // redis: redis::Client,
    state: Arc<AppState>,
    mut receiver: SplitStream<WebSocket>,
    tx: Sender<TxControlMessage>,
    mut system_rx: tokio::sync::broadcast::Receiver<serde_json::Value>,
    last_seen: Arc<Mutex<Instant>>, // cmd_atom: crossbeam_queue::ArrayQueue<CommandRequestPayload>,
    uid: Arc<Mutex<String>>,
    hub: Arc<Mutex<Hub>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let redis = state.redis_cli.clone();
    let config = state.dev_config.clone();
    let tx_to_client = tx.clone();
    tokio::spawn(async move {
        // Send back to client from services
        while let Ok(s_msg) = system_rx.recv().await {
            // only values parseable as SysMessage are relayed
            if convert_sys_msg_command(&s_msg).is_some()
                && let Some(err) = tx_to_client
                    .send(TxControlMessage::Payload(s_msg))
                    .await
                    .err()
            {
                println!("[SYS] failed to send back to client: {err}");
            }
        }
    });
    let uid_clone = uid.clone();
    while let Some(Ok(msg)) = receiver.next().await {
        match msg {
            Message::Text(t) => {
                // NOTE(review): a malformed frame aborts the whole reader via
                // `?` — confirm that dropping the session on bad JSON is intended
                let req: WebsocketMessageRequest = serde_json::from_str(t.as_str())?;
                info!("get msg: {}", req.type_w);
                // dispatch on message type; arms without a payload fall through
                match req.type_w.as_str() {
                    "recipe" if req.payload.is_some() => {
                        tasks::recipe::handle_recipe_request(
                            config.clone(),
                            redis.clone(),
                            tx.clone(),
                            req,
                            uid_clone.clone(),
                        )
                        .await?
                    }
                    "command" if req.payload.is_some() => {
                        tasks::command::handle_command_request(state.clone(), tx.clone(), req)
                            .await?;
                    }
                    "heartbeat" => {
                        // refresh the watchdog's liveness timestamp
                        *last_seen.lock().await = Instant::now();
                    }
                    "sheet" if req.payload.is_some() => {
                        if tasks::sheet::handle_sheet_request(redis.clone(), req)
                            .await
                            .is_err()
                        {
                            continue;
                        }
                    }
                    "log_report" if let Some(log_payload) = req.payload => {
                        let log_report_payload: LogReportPayload =
                            match serde_json::from_value(log_payload) {
                                Ok(lreq) => lreq,
                                Err(e) => {
                                    error!("error deserialize body log request: {e:?} ---> Skip");
                                    continue;
                                }
                            };
                        // generate timestamp
                        //
                        // NOTE(review): `log_report_payload` and `now` are
                        // parsed but never used — the actual logging is still TODO
                        let now = Instant::now();
                    }
                    "save_recipe" if let Some(save_recipe_payload) = req.payload => {
                        // NOTE(review): payload is validated but not yet
                        // persisted — implementation pending
                        let save_recipe_payload: SaveRecipePayload =
                            match serde_json::from_value(save_recipe_payload) {
                                Ok(lreq) => lreq,
                                Err(e) => {
                                    error!("error deserialize body log request: {e:?} ---> Skip");
                                    continue;
                                }
                            };
                    }
                    "auth" if req.payload.is_some() => {
                        // delayed auth: binds the real uid to this connection
                        tasks::auth::handle_auth_request(
                            state.clone(),
                            tx.clone(),
                            req,
                            hub.clone(),
                            uid_clone.clone(),
                        )
                        .await?;
                    }
                    _ => {
                        // not implemented
                    }
                }
            }
            Message::Ping(_) => {
                *last_seen.lock().await = Instant::now();
            }
            Message::Close(_) => {
                info!("get close message");
                // remove current uid and park the connection under a
                // `temp-…-wait-clean` key until the watchdog sweeps it
                {
                    let mut h = hub.try_lock().unwrap();
                    let curr_user = uid.try_lock().unwrap().to_string();
                    if let Some(ent) = h.clients.remove_entry(&curr_user) {
                        let curr_connection = ent.1;
                        let new_uid = format!("temp-{}-wait-clean", ent.0).to_string();
                        if let Some(old_tx) = h.clients.insert(new_uid, curr_connection) {
                            let _ = old_tx.send(TxControlMessage::CloseExist);
                        }
                    }
                    let mut ouid = uid.try_lock().unwrap();
                    *ouid = format!("temp-{ouid}-wait-clean").to_string();
                }
                // client disconnect by themselves
                let _ = tx
                    .send(TxControlMessage::Payload(serde_json::json!({
                        "timeout": "disconnection"
                    })))
                    .await;
            }
            _ => {
                // unhanled, ignore
            }
        }
    }
    Ok(())
}
/// Writer half of the websocket: drains the internal control channel and
/// forwards addressed payloads to the connected client.
///
/// Terminates on channel close, on a forced timeout/disconnect payload, or on
/// a `CloseExist` control message.
pub async fn write(
    mut sender: SplitSink<WebSocket, Message>,
    mut rx: Receiver<TxControlMessage>,
    uid: Arc<Mutex<String>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    while let Some(res) = rx.recv().await {
        match res {
            TxControlMessage::Payload(res) => {
                // force close
                if let Some(force_timeout_by) = res.get("timeout")
                    && let Some(from_who) = force_timeout_by.as_str()
                    && (from_who.eq("watchdog") || from_who.eq("disconnection"))
                {
                    warn!("receive close from {from_who}");
                    if from_who.eq("disconnection") {
                        // client already went away: just close our half
                        let _ = sender.close().await;
                    } else {
                        // watchdog timeout: tell the client we are closing
                        let _ = sender.send(Message::Close(None)).await;
                    }
                    break;
                }
                let current_uid = uid.try_lock().unwrap();
                // only deliver payloads addressed to this uid, or broadcasts ("*")
                if let Some(res_n) = res.as_object()
                    && let Some(res_payload) = res_n.get("payload")
                    && let Some(res_payload_val) = res_payload.as_object()
                    && let Some(recv_ident) = res_payload_val.get("to")
                    && let Some(recv_ident_str) = recv_ident.as_str()
                    && (current_uid.to_string().eq(recv_ident_str)
                        || current_uid.to_string().eq("*"))
                {
                    // FIX: serialize once — the previous code re-serialized
                    // the payload up to three times per message
                    let serialized = res.to_string();
                    if serialized.len() >= 100000 {
                        // large payload
                        warn!(
                            "sending large payload to client ... ({})",
                            serialized.len()
                        );
                    }
                    let _ = sender.send(serialized.into()).await;
                }
            }
            TxControlMessage::CloseExist => {
                let _ = sender.close().await;
                break;
            }
        }
        // limit break!! (uncomment to slow down messages)
        // let _ = tokio::time::sleep(Duration::from_millis(100)).await;
    }
    Ok(())
}

View file

@ -0,0 +1,70 @@
use crate::app::*;
use crate::websocket::{core::*, model::*};
use log::{error, info, warn};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc::Sender};
/// Handle request of command type from websocket (read)
pub async fn handle_auth_request(
state: Arc<AppState>,
tx: Sender<TxControlMessage>,
req: WebsocketMessageRequest,
hub: Arc<Mutex<Hub>>,
curr_uid: Arc<Mutex<String>>,
) -> WebsocketMessageResult {
// do command send to other services
// // guard expect value
let auth_request: AuthPayload = match serde_json::from_value(req.payload.unwrap()) {
Ok(areq) => areq,
Err(e) => {
error!("error body auth: {e:?}");
return Err(format!("unexpected auth: {e:?}").into());
}
};
let new_uid = auth_request.user.uid;
if !new_uid.is_empty() {
let old_uid = curr_uid.try_lock().unwrap().to_string();
let mut h = hub.try_lock().unwrap();
if let Some(ent) = h.clients.remove_entry(&old_uid) {
let curr_connection = ent.1;
// close all existed temp & actual uid
// case auth success but already have some tx left
if let Some(old_tx) = h.clients.insert(new_uid.clone(), curr_connection) {
warn!("disconnecting old connection");
let _ = old_tx.send(TxControlMessage::CloseExist);
}
info!("re-new auth successful");
}
{
let mut ouid = curr_uid.try_lock().unwrap();
*ouid = new_uid.clone();
}
}
// TODO
// - Queue requests
// - Send if service available
// if let Some(_) = state.system_tx.send(p).err() {
// info!("failed to send command request");
// let _ = tx
// .send(TxControlMessage::Payload(serde_json::json!({
// "type": "notify",
// "payload": {
// "from": "system_tx",
// "level": "error",
// "msg": "send request fail",
// "to": ""
// }
// })))
// .await;
// }
Ok(())
}

View file

@ -0,0 +1,38 @@
use crate::app::*;
use crate::websocket::{core::*, model::*};
use log::info;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
/// Handle request of command type from websocket (read)
///
/// Forwards the raw payload onto the system broadcast channel; when the send
/// fails, the client is notified via its control channel.
pub async fn handle_command_request(
    state: Arc<AppState>,
    tx: Sender<TxControlMessage>,
    req: WebsocketMessageRequest,
) -> WebsocketMessageResult {
    // FIX: the caller routes only `command` messages with a payload here,
    // but fail soft instead of `unwrap()`-panicking if that invariant breaks
    let Some(p) = req.payload else {
        info!("command request without payload ---> Skip");
        return Ok(());
    };
    info!("get command request");
    // TODO
    // - Queue requests
    // - Send if service available
    // FIX: idiomatic `is_err()` instead of `if let Some(_) = ....err()`
    if state.system_tx.send(p).is_err() {
        info!("failed to send command request");
        let _ = tx
            .send(TxControlMessage::Payload(serde_json::json!({
                "type": "notify",
                "payload": {
                    "from": "system_tx",
                    "level": "error",
                    "msg": "send request fail",
                    "to": ""
                }
            })))
            .await;
    }
    Ok(())
}

View file

@ -0,0 +1,5 @@
/// Delayed-auth handling: binds a uid to an open connection.
pub mod auth;
/// Forwards command requests to backend services.
pub mod command;
/// Recipe checkout and chunked streaming to the client.
pub mod recipe;
/// Sheet request handling.
pub mod sheet;
/// Idle-connection watchdog task.
pub mod watchdog;

View file

@ -0,0 +1,355 @@
use crate::app::*;
use crate::stream::model::{
IntoStreamMessage, StreamDataChunk, StreamDataEnd, StreamDataExtra, StreamDataStart,
};
use crate::websocket::{core::*, helper::*, model::*};
use std::{fs::File, io::Read, path::PathBuf, sync::Arc};
use async_compression::tokio::bufread::BrotliDecoder;
use axum::extract::ws::{Message, WebSocket};
use futures::{
SinkExt, StreamExt,
stream::{SplitSink, SplitStream},
};
use libtbr::models::recipe::{MaterialSetting, Recipe, Recipe01};
use log::{error, info, warn};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use redis::{self, TypedCommands};
use tokio::{
io::{AsyncReadExt, BufReader},
sync::{
Mutex,
mpsc::{Receiver, Sender},
},
time::Instant,
};
/// True when the request targets a partial (patch) recipe file: a concrete
/// version was pinned (`version != -1`) and `partial` is explicitly `true`.
pub fn is_req_patch(param: &RecipeRequestPayload) -> bool {
    // `== Some(true)` replaces the `is_some() && unwrap()` pattern; `None`
    // and `Some(false)` both mean "not a patch request", as before.
    param.version != -1 && param.partial == Some(true)
}
/// Open a file by name relative to the working directory, propagating the
/// underlying I/O error to the caller.
pub fn get_local_file(filename: String) -> Result<File, std::io::Error> {
    let path = PathBuf::from(&filename);
    File::open(path)
}
/// Build the cache key used to look up a recipe file.
///
/// Patch requests always map to the single `stx_*` key. Full requests walk a
/// sequence of fallback keys selected by `retry_cnt`: redis-backed master
/// copies first (1-2), then git-checkout paths (3-5, with 5 being the
/// premium variant). Any other count yields an empty string.
pub fn get_key_cache(country: String, version: String, is_patch: bool, retry_cnt: i32) -> String {
    if is_patch {
        return format!("stx_{country}_{version}.json");
    }
    match retry_cnt {
        // fetched from redis
        1 => format!("master:{country}/coffeethai02_{version}_{country}.json"),
        2 => format!("master:{country}/coffeethai02_{version}.json"),
        // fetched via git checkout
        3 => format!("{country}/coffeethai02_{version}_{country}.json"),
        4 => format!("{country}/coffeethai02_{version}.json"),
        // checkout, premium naming ("1" prefixed to the version)
        5 => format!("{country}/coffeethai02_1{version}.json"),
        _ => String::new(),
    }
}
pub async fn throttle_send_recipe(
recipe: &Recipe,
tx: &Sender<TxControlMessage>,
country: String,
version: String,
uid: Arc<Mutex<String>>,
) {
let r01s: Vec<Recipe01> = recipe
.Recipe01
.par_iter()
.flat_map(|x| {
let mut v = Vec::new();
v.push(x.clone());
if let Some(sub) = x.clone().SubMenu {
v.extend(sub);
}
v
})
.collect();
let matset: Vec<MaterialSetting> = recipe.MaterialSetting.clone();
// test stream start model
let ss = StreamDataStart::new(
r01s.len(),
CHUNK_SIZE,
Some(uid.try_lock().unwrap().to_string()),
);
let sid = ss.get_id();
info!("starting {sid}");
if let Some(err) = tx.send(TxControlMessage::Payload(ss.as_msg())).await.err() {
println!("ERR: send tx error, {err:?}");
}
// split send
let uidd = uid.try_lock().unwrap().to_string();
for (index, chunk) in r01s.chunks(CHUNK_SIZE).enumerate() {
let sda = StreamDataChunk::new(&sid, index * CHUNK_SIZE, chunk.to_vec(), uidd.to_string());
// no validate
if let Some(err) = tx.send(TxControlMessage::Payload(sda.as_msg())).await.err() {
println!("ERR: send tx error, {err:?}");
}
}
let mat_exid = sid.clone();
let extp = "matset";
for (index, chunk) in matset.chunks(CHUNK_SIZE).enumerate() {
let curr_ch_id = format!("{mat_exid}_{index}");
let extra_matset = StreamDataExtra::new(&curr_ch_id, &extp, chunk.to_vec());
if let Some(err) = tx
.send(TxControlMessage::Payload(extra_matset.as_msg()))
.await
.err()
{
println!("ERR: send tx extra error: {err:?}");
}
}
let extl = "topplist";
for (index, chunk) in recipe.Topping.ToppingList.chunks(CHUNK_SIZE).enumerate() {
let curr_ch_id = format!("{mat_exid}_tl{index}");
let extra_topplist = StreamDataExtra::new(&curr_ch_id, &extl, chunk.to_vec());
if let Some(err) = tx
.send(TxControlMessage::Payload(extra_topplist.as_msg()))
.await
.err()
{
println!("ERR: send tx extra2 error: {err:?}");
}
}
let extg = "toppgrp";
for (index, chunk) in recipe.Topping.ToppingGroup.chunks(CHUNK_SIZE).enumerate() {
let curr_ch_id = format!("{mat_exid}_tg{index}");
let extra_toppgrp = StreamDataExtra::new(&curr_ch_id, &extg, chunk.to_vec());
if let Some(err) = tx
.send(TxControlMessage::Payload(extra_toppgrp.as_msg()))
.await
.err()
{
println!("ERR: send tx extra2 error: {err:?}");
}
}
// NOTE: disable from case concurrent write may causes corrupted file
// let rp_clone = recipe.clone();
// tokio::task::spawn(async move {
// rp_clone.export_to_json_file(Some(format!("result.{country}.{version}.json")));
// });
info!("sending {sid}");
// return sid;
let end_msg = StreamDataEnd::new(&sid);
if let Some(err) = tx
.send(TxControlMessage::Payload(end_msg.as_msg()))
.await
.err()
{
println!("ERR: send tx error, {err:?}");
}
}
// TODO: split cases into sub function
/// Handle request of recipe type from websocket (read).
///
/// Resolves the latest recipe version for the requested country (redis
/// first, git checkout as fallback), then tries a local result file; on a
/// local miss it walks up to 5 fallback cache keys (see `get_key_cache`)
/// before giving up silently. Any successfully loaded recipe is streamed to
/// the client via `throttle_send_recipe`.
pub async fn handle_recipe_request(
    config: DevConfig,
    redis: redis::Client,
    tx: Sender<TxControlMessage>,
    req: WebsocketMessageRequest,
    uid_clone: Arc<Mutex<String>>,
) -> WebsocketMessageResult {
    // guard expect value
    // NOTE(review): panics when `payload` is None — assumed to be validated
    // by the dispatching handler; confirm at the call site.
    let p = req.payload.unwrap();
    let recipe_param: RecipeRequestPayload = serde_json::from_value(p)?;
    // get actual version
    //
    // The version pointer lives under "<country>/version" in redis and is
    // stored brotli-compressed; an empty string means "not resolved yet".
    let latest_key = format!("{country}/version", country = recipe_param.country);
    let mut latest_version = match fetch_content_from_redis_byte(redis.clone(), &latest_key).await {
        Ok(x) => {
            // decode brotli
            let mut sbuf = String::new();
            let mut decoder = BrotliDecoder::new(x.as_slice());
            match decoder.read_to_string(&mut sbuf).await {
                // stored value is a JSON string literal, so strip the quotes
                Ok(_) => sbuf.replace('"', ""),
                Err(e) => {
                    println!("decode fail: {e}");
                    "".to_string()
                }
            }
        }
        Err(e) => {
            println!("get latest fail: {e}");
            "".to_string()
        }
    };
    if latest_version.is_empty() {
        // cannot get actual version, try get from git
        latest_version = match invoke_checkout_request(config.clone(), latest_key).await {
            Ok(version) => version,
            Err(e) => {
                println!("Error on checkout: {e}");
                "".to_string()
            }
        };
    }
    // Patch and full requests use different local filename schemes.
    let req_file = if is_req_patch(&recipe_param) {
        format!(
            "stx_{country}_{version}.json",
            country = recipe_param.country,
            version = latest_version
        )
    } else {
        format!(
            "result.{country}.{version}.json",
            country = recipe_param.country,
            version = latest_version
        )
    };
    let mut retry_cnt = 0;
    println!("init req: {req_file}");
    match get_local_file(req_file) {
        // Fast path: a cached result file already exists on local disk.
        Ok(mut f) => {
            println!("get local file ok");
            let mut file_content = String::new();
            f.read_to_string(&mut file_content)?;
            if !file_content.is_empty() {
                info!("local file -> buffer OK");
            }
            // split send
            let recipe: Recipe = match serde_json::from_str(&file_content) {
                Ok(c) => c,
                Err(e) => {
                    // Deserialization failure is treated as a corrupted cache
                    // file: notify the client, then propagate the error.
                    error!("error deserialize struct fail, file may be corrupted: {e:?}");
                    if !file_content.ends_with("}") {
                        error!("File corrupted, invalid json format");
                    }
                    let _ = tx.send(TxControlMessage::Payload(serde_json::json!({
                        "type": "notify",
                        "payload": {
                            "from": "system_tx",
                            "level": "error",
                            "msg": format!("Some requested file on cache is corrupt, {} version {}", recipe_param.country, latest_version),
                            "to": ""
                        }
                    }))).await;
                    return Err(e.into());
                }
            };
            throttle_send_recipe(
                &recipe,
                &tx,
                recipe_param.country,
                latest_version,
                uid_clone.clone(),
            )
            .await;
        }
        // Slow path: no local file — walk the fallback cache keys.
        Err(_) => {
            println!("retry by fetching git");
            let lvc = latest_version.clone();
            // concurrent fetch
            // NOTE(review): despite the comment above, fetches run
            // sequentially and stop at the first successful send.
            for i in 1..6 {
                let latest_version_c = lvc.clone();
                retry_cnt = i;
                // retry #1: get from redis
                let r1_key = get_key_cache(
                    recipe_param.clone().country,
                    latest_version_c.clone(),
                    is_req_patch(&recipe_param),
                    retry_cnt,
                );
                println!("curr key: {r1_key}");
                if retry_cnt < 3 {
                    // Counts 1-2: brotli-compressed copies cached in redis.
                    match fetch_content_from_redis_byte(redis.clone(), &r1_key).await {
                        Ok(res) => {
                            let buf = BufReader::new(res.as_slice());
                            let mut sbuf = String::new();
                            let mut decoder = BrotliDecoder::new(buf);
                            if let Ok(_) = decoder.read_to_string(&mut sbuf).await {
                                let recipe: Recipe = serde_json::from_str(&sbuf)?;
                                throttle_send_recipe(
                                    &recipe,
                                    &tx,
                                    recipe_param.country,
                                    latest_version,
                                    uid_clone.clone(),
                                )
                                .await;
                                break;
                            }
                        }
                        // Redis miss/error: fall through to the next key.
                        Err(_) => {}
                    }
                } else {
                    // retry get from git
                    // Counts 3-5: checkout the raw file from the recipe repo.
                    let content = match invoke_checkout_request(config.clone(), r1_key).await {
                        Ok(file_content) => file_content,
                        Err(e) => {
                            println!("Error on checkout: {e}");
                            "".to_string()
                        }
                    };
                    let recipe = serde_json::from_str::<Recipe>(&content);
                    if let Ok(rp) = recipe {
                        throttle_send_recipe(
                            &rp,
                            &tx,
                            recipe_param.clone().country,
                            latest_version_c.clone(),
                            uid_clone.clone(),
                        )
                        .await;
                        break;
                    } else {
                        info!("fail to deserialize: {}", content);
                    }
                }
            }
        }
    }
    Ok(())
}

View file

@ -0,0 +1,61 @@
use crate::websocket::{core::*, model::*};
use log::{error, info};
use redis::TypedCommands;
/// Handle request of sheet type from websocket (read).
///
/// Deserializes the payload into a `CommandRequestPayload`, maps its "param"
/// value onto a pub/sub channel suffix, and publishes the full original
/// request JSON to `<srv_name>/<suffix>` on redis.
pub async fn handle_sheet_request(
    redis: redis::Client,
    req: WebsocketMessageRequest,
) -> WebsocketMessageResult {
    // Guard: the payload originates from the remote client, so a missing
    // payload is an input error to report rather than a panic.
    let Some(payload) = req.payload.clone() else {
        return Err("sheet request missing payload".into());
    };
    let payload_sheet_request: CommandRequestPayload = match serde_json::from_value(payload) {
        Ok(sreq) => sreq,
        Err(e) => {
            error!("error deserialize body sheet request: {e:?} ---> Skip");
            return Err(format!("unexpected sheet body: {e:?}").into());
        }
    };
    info!(
        "get sheet request: {}, {:?}",
        payload_sheet_request.srv_name, payload_sheet_request
    );
    let parameters = payload_sheet_request
        .values
        .get("param")
        .unwrap_or_default();
    // TODO: will be changed to config from yaml file
    // Unknown or non-string parameters are routed to the "junk" channel.
    let ch_target = if let Some(pm) = parameters.as_str() {
        match pm {
            "get_all_catalogs" => "catalogs",
            "get_catalog" | "enter" => "enter",
            "heartbeat" => "heartbeat",
            _ => "junk",
        }
    } else {
        "junk"
    };
    let channel = format!("{}/{}", payload_sheet_request.srv_name, ch_target);
    info!("publishing to {channel}");
    let mut rcl = redis.clone();
    // Publish the complete original request so subscribers see the raw
    // message, not just the parsed payload.
    let pub_res = rcl.publish(
        channel,
        serde_json::to_string(&req).unwrap_or_else(|_| "{}".to_string()),
    );
    if let Err(e) = pub_res {
        error!("error on publish result cmd: {e:?}");
    }
    Ok(())
}

View file

@ -0,0 +1,49 @@
use crate::{app::Hub, websocket::core::*};
use log::{debug, info, warn};
use std::{sync::Arc, time::Duration};
use tokio::{
sync::{Mutex, mpsc::Sender},
task::JoinHandle,
time::Instant,
};
/// Spawn the per-connection watchdog task.
///
/// Every 5 seconds it (1) kicks clients still registered under a temporary
/// ("temp…") uid, i.e. ones that never completed auth, and (2) closes the
/// connection when `watchdog_last_seen` is older than `TIMEOUT`. Either
/// condition sends a `{"timeout": "watchdog"}` payload over `tx` and ends
/// the task.
pub async fn get_watchdog_task(
    tx: Sender<TxControlMessage>,
    watchdog_last_seen: Arc<Mutex<Instant>>,
    user: Arc<Mutex<String>>,
    hub: Arc<Mutex<Hub>>,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(Duration::from_secs(5)).await;
            {
                // Await the locks: try_lock().unwrap() would panic whenever
                // the hub or uid mutex is briefly held by the auth handler.
                let h = hub.lock().await;
                let curr_user = user.lock().await.to_string();
                info!("{}: checking invalid ...", curr_user);
                if h.clients.contains_key(&curr_user) && curr_user.starts_with("temp") {
                    warn!("detect unauthorized -- {}", curr_user);
                    let _ = tx
                        .send(TxControlMessage::Payload(serde_json::json!({
                            "timeout": "watchdog"
                        })))
                        .await;
                    break;
                }
            }
            let last = *watchdog_last_seen.lock().await;
            if last.elapsed() > TIMEOUT {
                warn!("Timeout close connection");
                let _ = tx
                    .send(TxControlMessage::Payload(serde_json::json!({
                        "timeout": "watchdog"
                    })))
                    .await;
                break;
            }
        }
    })
}