1use std::{collections::HashMap, path::PathBuf, sync::Arc};
7
8use async_trait::async_trait;
9use base64::Engine;
10use bytes::Bytes;
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13
14use crate::{
15 Transport::{
16 Strategy::{TransportStats, TransportStrategy, TransportType},
17 TransportConfig,
18 },
19 WASM::{
20 HostBridge::HostBridgeImpl,
21 MemoryManager::{MemoryLimits, MemoryManagerImpl},
22 Runtime::{WASMConfig, WASMRuntime},
23 WASMStats,
24 },
25 dev_log,
26};
27
28#[derive(Clone, Debug)]
30pub struct WASMTransportImpl {
31 runtime:Arc<WASMRuntime>,
33
34 memory_manager:Arc<RwLock<MemoryManagerImpl>>,
36
37 bridge:Arc<HostBridgeImpl>,
39
40 modules:Arc<RwLock<HashMap<String, WASMModuleInfo>>>,
42
43 #[allow(dead_code)]
45 config:TransportConfig,
46
47 connected:Arc<RwLock<bool>>,
49
50 stats:Arc<RwLock<TransportStats>>,
52}
53
54#[derive(Debug, Clone)]
56pub struct WASMModuleInfo {
57 pub id:String,
59
60 pub name:Option<String>,
62
63 pub path:Option<PathBuf>,
65
66 pub loaded_at:u64,
68
69 pub function_stats:HashMap<String, FunctionCallStats>,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct FunctionCallStats {
76 pub call_count:u64,
78
79 pub total_time_us:u64,
81
82 pub last_call_at:Option<u64>,
84
85 pub error_count:u64,
87}
88
89impl FunctionCallStats {
90 pub fn record_call(&mut self, time_us:u64) {
92 self.call_count += 1;
93
94 self.total_time_us += time_us;
95
96 self.last_call_at = Some(
97 std::time::SystemTime::now()
98 .duration_since(std::time::UNIX_EPOCH)
99 .map(|d| d.as_secs())
100 .unwrap_or(0),
101 );
102 }
103
104 pub fn record_error(&mut self) { self.error_count += 1; }
106}
107
108impl Default for FunctionCallStats {
109 fn default() -> Self { Self { call_count:0, total_time_us:0, last_call_at:None, error_count:0 } }
110}
111
112impl WASMTransportImpl {
113 pub fn new(enable_wasi:bool, memory_limit_mb:u64, max_execution_time_ms:u64) -> anyhow::Result<Self> {
115 let config = WASMConfig::new(memory_limit_mb, max_execution_time_ms, enable_wasi);
116
117 let runtime_result = tokio::runtime::Runtime::new()
120 .map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
121 .block_on(WASMRuntime::new(config.clone()))
122 .map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
123
124 let runtime = Arc::new(runtime_result);
125
126 let memory_limits = MemoryLimits::new(memory_limit_mb, (memory_limit_mb as f64 * 0.75) as u64, 100);
127
128 let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
129
130 let bridge = Arc::new(HostBridgeImpl::new());
131
132 Ok(Self {
133 runtime,
134 memory_manager,
135 bridge,
136 modules:Arc::new(RwLock::new(HashMap::new())),
137 config:TransportConfig::default(),
138 connected:Arc::new(RwLock::new(true)), stats:Arc::new(RwLock::new(TransportStats::default())),
140 })
141 }
142
143 pub fn with_config(wasm_config:WASMConfig, transport_config:TransportConfig) -> anyhow::Result<Self> {
145 let runtime_result = tokio::runtime::Runtime::new()
146 .map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
147 .block_on(WASMRuntime::new(wasm_config.clone()))
148 .map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
149
150 let runtime = Arc::new(runtime_result);
151
152 let memory_limits = MemoryLimits::new(
153 wasm_config.memory_limit_mb,
154 (wasm_config.memory_limit_mb as f64 * 0.75) as u64,
155 100,
156 );
157
158 let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
159
160 let bridge = Arc::new(HostBridgeImpl::new());
161
162 Ok(Self {
163 runtime,
164 memory_manager,
165 bridge,
166 modules:Arc::new(RwLock::new(HashMap::new())),
167 config:transport_config,
168 connected:Arc::new(RwLock::new(true)),
169 stats:Arc::new(RwLock::new(TransportStats::default())),
170 })
171 }
172
173 pub fn runtime(&self) -> &Arc<WASMRuntime> { &self.runtime }
175
176 pub fn memory_manager(&self) -> &Arc<RwLock<MemoryManagerImpl>> { &self.memory_manager }
178
179 pub fn bridge(&self) -> &Arc<HostBridgeImpl> { &self.bridge }
181
182 pub async fn get_modules(&self) -> HashMap<String, WASMModuleInfo> { self.modules.read().await.clone() }
184
185 pub async fn get_wasm_stats(&self) -> WASMStats {
187 let memory_manager = self.memory_manager.read().await;
188
189 let managers = self.modules.read().await;
190
191 WASMStats {
192 modules_loaded:managers.len(),
193
194 active_instances:managers.len(), total_memory_mb:memory_manager.current_usage_mb() as u64,
196
197 total_execution_time_ms:0, function_calls:self.stats.read().await.messages_sent,
199 }
200 }
201
202 pub async fn call_wasm_function(
204 &self,
205
206 module_id:&str,
207
208 function_name:&str,
209
210 args:Vec<Bytes>,
211 ) -> anyhow::Result<Bytes> {
212 let start = std::time::Instant::now();
213
214 dev_log!(
215 "wasm",
216 "Calling WASM function: {}::{} with {} arguments",
217 module_id,
218 function_name,
219 args.len()
220 );
221
222 let modules = self.modules.read().await;
223
224 let _module = modules
225 .get(module_id)
226 .ok_or_else(|| anyhow::anyhow!("Module not found: {}", module_id))?;
227
228 let response = Bytes::new();
231
232 let mut modules_mut = self.modules.write().await;
234
235 if let Some(module) = modules_mut.get_mut(module_id) {
236 let stats = module.function_stats.entry(function_name.to_string()).or_default();
237
238 stats.record_call(start.elapsed().as_micros() as u64);
239 }
240
241 drop(modules_mut);
242
243 let mut stats = self.stats.write().await;
245
246 stats.record_sent(args.iter().map(|b| b.len() as u64).sum(), start.elapsed().as_micros() as u64);
247
248 stats.record_received(response.len() as u64);
249
250 Ok(response)
251 }
252}
253
254#[async_trait]
255impl TransportStrategy for WASMTransportImpl {
256 type Error = WASMTransportError;
257
258 async fn connect(&self) -> Result<(), Self::Error> {
259 dev_log!("transport", "WASM transport connecting");
260
261 *self.connected.write().await = true;
263
264 dev_log!("transport", "WASM transport connected");
265
266 Ok(())
267 }
268
269 async fn send(&self, request:&[u8]) -> Result<Vec<u8>, Self::Error> {
270 let start = std::time::Instant::now();
271
272 if !self.is_connected() {
273 return Err(WASMTransportError::NotConnected);
274 }
275
276 dev_log!("transport", "Sending WASM transport request ({} bytes)", request.len());
277
278 let request_str =
281 std::str::from_utf8(request).map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?;
282
283 let parts:Vec<&str> = request_str.splitn(3, ':').collect();
284
285 if parts.len() < 3 {
286 return Err(WASMTransportError::InvalidRequest("Invalid request format".to_string()));
287 }
288
289 let module_id = parts[0];
290
291 let function_name = parts[1];
292
293 let args_base64 = parts[2];
294
295 use base64::engine::general_purpose::STANDARD;
297
298 let args = vec![Bytes::from(
299 STANDARD
300 .decode(args_base64)
301 .map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?,
302 )];
303
304 let response = self
306 .call_wasm_function(module_id, function_name, args)
307 .await
308 .map_err(|e| WASMTransportError::FunctionCallFailed(e.to_string()))?;
309
310 let response_vec = response.to_vec();
312
313 let latency_us = start.elapsed().as_micros() as u64;
314
315 dev_log!("transport", "WASM transport request completed in {}µs", latency_us);
316
317 Ok(response_vec)
318 }
319
320 async fn send_no_response(&self, data:&[u8]) -> Result<(), Self::Error> {
321 if !self.is_connected() {
322 return Err(WASMTransportError::NotConnected);
323 }
324
325 dev_log!(
326 "transport",
327 "Sending WASM transport request without response ({} bytes)",
328 data.len()
329 );
330
331 self.send(data).await?;
333
334 Ok(())
335 }
336
337 async fn close(&self) -> Result<(), Self::Error> {
338 dev_log!("transport", "Closing WASM transport");
339
340 *self.connected.write().await = false;
341
342 dev_log!("transport", "WASM transport closed");
343
344 Ok(())
345 }
346
347 fn is_connected(&self) -> bool { self.connected.blocking_read().to_owned() }
348
349 fn transport_type(&self) -> TransportType { TransportType::WASM }
350}
351
352#[derive(Debug, thiserror::Error)]
354pub enum WASMTransportError {
355 #[error("Module not found: {0}")]
357 ModuleNotFound(String),
358
359 #[error("Function not found: {0}")]
361 FunctionNotFound(String),
362
363 #[error("Function call failed: {0}")]
365 FunctionCallFailed(String),
366
367 #[error("Memory error: {0}")]
369 MemoryError(String),
370
371 #[error("Runtime error: {0}")]
373 RuntimeError(String),
374
375 #[error("Invalid request: {0}")]
377 InvalidRequest(String),
378
379 #[error("Not connected")]
381 NotConnected,
382
383 #[error("Compilation failed: {0}")]
385 CompilationFailed(String),
386
387 #[error("Timeout")]
389 Timeout,
390}
391
#[cfg(test)]
mod tests {

	use super::*;
	use crate::Transport::Strategy::TransportStrategy;

	/// A freshly constructed transport starts out connected.
	#[test]
	fn test_wasm_transport_creation() {
		let transport = WASMTransportImpl::new(true, 512, 30000).expect("transport should be created");

		assert!(transport.is_connected());
	}

	#[test]
	fn test_function_call_stats() {
		let mut stats = FunctionCallStats::default();

		stats.record_call(100);

		assert_eq!(stats.call_count, 1);

		assert_eq!(stats.total_time_us, 100);

		assert!(stats.last_call_at.is_some());
	}

	/// Errors are counted separately and do not touch call counters.
	#[test]
	fn test_function_call_stats_errors() {
		let mut stats = FunctionCallStats::default();

		stats.record_error();

		stats.record_error();

		assert_eq!(stats.error_count, 2);

		assert_eq!(stats.call_count, 0);

		assert!(stats.last_call_at.is_none());
	}

	#[tokio::test]
	async fn test_wasm_transport_not_connected_after_close() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();

		// Fail loudly instead of silently discarding a close error.
		transport.close().await.expect("close should succeed");

		assert!(!transport.is_connected());
	}

	#[tokio::test]
	async fn test_get_wasm_stats() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();

		let stats = transport.get_wasm_stats().await;

		assert_eq!(stats.modules_loaded, 0);

		assert_eq!(stats.active_instances, 0);
	}
}