diff --git a/src/stream/model.rs b/src/stream/model.rs index 7199725..7b61e6e 100644 --- a/src/stream/model.rs +++ b/src/stream/model.rs @@ -20,6 +20,8 @@ pub struct StreamDataStart { #[serde(rename = "ref")] #[serde(skip_serializing_if = "Option::is_none")] pub stream_ref: Option, + /// extra data, information + pub metadata: String, } impl IntoStreamMessage for StreamDataStart { @@ -35,7 +37,7 @@ impl IntoStreamMessage for StreamDataStart { serde_json::json!({ "type": StreamDataStart::MSG_NAME, - "payload": self.clone() + "payload": payload.clone() }) } @@ -45,12 +47,18 @@ impl IntoStreamMessage for StreamDataStart { } impl StreamDataStart { - pub fn new(total_size: usize, chunk_size: usize, stream_ref: Option) -> Self { + pub fn new( + total_size: usize, + chunk_size: usize, + stream_ref: Option, + metadata: String, + ) -> Self { Self { stream_id: Uuid::new_v4().to_string(), total_size, chunk_size, stream_ref, + metadata, } } @@ -113,6 +121,8 @@ where pub struct StreamDataEnd { /// Uuid v4, client must mapping later values with this stream id pub stream_id: String, + /// endpoint user + pub to: String, } impl IntoStreamMessage for StreamDataEnd { @@ -131,9 +141,10 @@ impl IntoStreamMessage for StreamDataEnd { } impl StreamDataEnd { - pub fn new(sid: &str) -> Self { + pub fn new(sid: &str, to: String) -> Self { Self { stream_id: sid.to_string(), + to, } } diff --git a/src/websocket/rw.rs b/src/websocket/rw.rs index 565d7a4..d8897f0 100644 --- a/src/websocket/rw.rs +++ b/src/websocket/rw.rs @@ -51,7 +51,7 @@ pub async fn read( Message::Text(t) => { let req: WebsocketMessageRequest = serde_json::from_str(t.as_str())?; - info!("get msg: {}", req.type_w); + // info!("get msg: {}", req.type_w); match req.type_w.as_str() { "recipe" if req.payload.is_some() => { tasks::recipe::handle_recipe_request( @@ -63,6 +63,16 @@ pub async fn read( ) .await?; } + "recipe_versions" if req.payload.is_some() => { + tasks::recipe::handle_recipe_versions_list_request( + config.clone(), + redis.clone(), + tx.clone(), + req, + uid_clone.clone(), + ) + .await?; + } "command" if req.payload.is_some() => { tasks::command::handle_command_request(state.clone(), tx.clone(), req) .await?; @@ -201,6 +211,8 @@ pub async fn write( } let _ = sender.send(res.to_string().into()).await; + } else { + warn!("failed to send message, as the receiver not detected: {res:?}"); } } TxControlMessage::CloseExist => { diff --git a/src/websocket/tasks/recipe.rs b/src/websocket/tasks/recipe.rs index 90f56e6..d07cf34 100644 --- a/src/websocket/tasks/recipe.rs +++ b/src/websocket/tasks/recipe.rs @@ -4,6 +4,7 @@ use crate::stream::model::{ }; use crate::websocket::{core::*, helper::*, model::*}; +use std::collections::HashMap; use std::{fs::File, io::Read, path::PathBuf, sync::Arc}; use async_compression::tokio::bufread::BrotliDecoder; @@ -63,6 +64,24 @@ pub fn get_key_cache(country: String, version: String, is_patch: bool, retry_cnt } } +pub fn get_extra_parameters(s: String) -> HashMap { + let mut result = HashMap::new(); + + let plist: Vec = s.split(",").map(|x| x.to_string()).collect(); + + for pl in plist { + let sm: Vec = pl.split("=").map(|x| x.to_string()).collect(); + + if sm.len() != 2 { + continue; + } + + result.insert(sm[0].to_string(), sm[1].to_string()); + } + + result +} + pub async fn throttle_send_recipe( recipe: &Recipe, tx: &Sender, @@ -92,6 +111,7 @@ pub async fn throttle_send_recipe( r01s.len(), CHUNK_SIZE, Some(uid.try_lock().unwrap().to_string()), + format!("version={version},country={country}").to_string(), ); let sid = ss.get_id(); @@ -163,7 +183,7 @@ pub async fn throttle_send_recipe( info!("sending {sid}"); // return sid; - let end_msg = StreamDataEnd::new(&sid); + let end_msg = StreamDataEnd::new(&sid, uidd.clone()); if let Some(err) = tx .send(TxControlMessage::Payload(end_msg.as_msg())) @@ -222,6 +242,23 @@ pub async fn handle_recipe_request( }; } + // detect if use different version + // parameter: use_legacy_version=true,version=888 + if let Some(extra_param) = recipe_param.clone().parameters { + let pmap = get_extra_parameters(extra_param); + + latest_version = if pmap.contains_key("use_legacy_version") + && let Some(legacy_cfg) = pmap.get("use_legacy_version") + && legacy_cfg.eq("true") + { + pmap.get("version").unwrap_or(&latest_version).to_string() + } else { + latest_version + }; + + info!("after param in recipe: {latest_version}"); + } + let req_file = if is_req_patch(&recipe_param) { format!( "stx_{country}_{version}.json", @@ -353,3 +390,48 @@ pub async fn handle_recipe_request( Ok(()) } + +pub async fn handle_recipe_versions_list_request( + config: DevConfig, + redis: redis::Client, + tx: Sender, + req: WebsocketMessageRequest, + uid_clone: Arc>, +) -> WebsocketMessageResult { + println!("trigger check versions ... "); + let p = req.payload.unwrap(); + let recipe_param: RecipeRequestPayload = serde_json::from_value(p)?; + + let version_list = format!("{country}", country = recipe_param.country); + + let country_versions_str = match invoke_checkout_request(config.clone(), version_list).await { + Ok(vs) => vs, + Err(e) => return Err(format!("Cannot find versions of expected country: {e:?}").into()), + }; + + // extract version as list + let files: Vec = country_versions_str + .split(",") + .map(|x| x.to_string()) + .collect(); + + let result: Vec = files + .iter() + .filter(|x| x.starts_with("coffeethai02") && x.ends_with(".json")) + .map(|x| x.replace("coffeethai02_", "").replace(".json", "")) + .collect(); + + let uidd = uid_clone.clone().try_lock().unwrap().to_string(); + + let _ = tx + .send(TxControlMessage::Payload(serde_json::json!({ + "type": "version_selectors", + "payload": { + "versions": result, + "to": uidd + } + }))) + .await; + + Ok(()) +} diff --git a/src/websocket/tasks/watchdog.rs b/src/websocket/tasks/watchdog.rs index 4c2ddc3..1092f9d 100644 --- a/src/websocket/tasks/watchdog.rs +++ b/src/websocket/tasks/watchdog.rs @@ -21,7 +21,7 @@ pub async fn get_watchdog_task( let h = hub.try_lock().unwrap(); let curr_user = user.try_lock().unwrap().to_string(); - info!("{}: checking invalid ...", curr_user); + // info!("{}: checking invalid ...", curr_user); if h.clients.contains_key(&curr_user) && curr_user.starts_with("temp") { warn!("detect unauthorized -- {}", curr_user);