Skip to main content

Grove/WASM/
FunctionExport.rs

1//! Function Export Module
2//!
3//! Handles exporting host functions to WASM modules.
4//! Provides registration and management of functions that WASM can call.
5
6use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use wasmtime::{Caller, Linker};
12
13use crate::{
14	WASM::HostBridge::{
15		FunctionSignature,
16		HostBridgeImpl,
17		HostBridgeImpl as HostBridge,
18		HostFunctionCallback,
19		ParamType,
20		ReturnType,
21	},
22	dev_log,
23};
24
25/// Host function registry for WASM exports
26pub struct HostFunctionRegistry {
27	/// Registered host functions
28	functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
29
30	/// Associated host bridge
31	#[allow(dead_code)]
32	bridge:Arc<HostBridge>,
33}
34
35/// Registered host function with metadata
36#[derive(Debug, Clone)]
37struct RegisteredHostFunction {
38	/// Function name
39	#[allow(dead_code)]
40	name:String,
41
42	/// Function signature
43	#[allow(dead_code)]
44	signature:FunctionSignature,
45
46	/// Synchronous callback
47	callback:Option<HostFunctionCallback>,
48
49	/// Registration timestamp
50	#[allow(dead_code)]
51	registered_at:u64,
52
53	/// Call statistics
54	stats:FunctionStats,
55}
56
/// Per-function call statistics; `Default` yields all-zero counters and no last-call time.
#[derive(Debug, Clone, Default)]
pub struct FunctionStats {
	/// How many times the function has been invoked.
	pub call_count:u64,

	/// Cumulative execution time across all calls, in nanoseconds.
	pub total_execution_ns:u64,

	/// Timestamp of the most recent call, or `None` if it has never been called.
	pub last_call_at:Option<u64>,

	/// How many invocations ended in an error.
	pub error_count:u64,
}
72
73/// Export configuration for WASM functions
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ExportConfig {
76	/// Enable function export by default
77	pub auto_export:bool,
78
79	/// Enable timing statistics
80	pub enable_stats:bool,
81
82	/// Maximum number of functions that can be exported
83	pub max_functions:usize,
84
85	/// Function name prefix for exports
86	pub name_prefix:Option<String>,
87}
88
89impl Default for ExportConfig {
90	fn default() -> Self {
91		Self {
92			auto_export:true,
93
94			enable_stats:true,
95
96			max_functions:1000,
97
98			name_prefix:Some("host_".to_string()),
99		}
100	}
101}
102
103/// Function export for WASM
104pub struct FunctionExportImpl {
105	registry:Arc<HostFunctionRegistry>,
106
107	config:ExportConfig,
108}
109
110impl FunctionExportImpl {
111	/// Create a new function export manager
112	pub fn new(bridge:Arc<HostBridge>) -> Self {
113		Self {
114			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
115
116			config:ExportConfig::default(),
117		}
118	}
119
120	/// Create with custom configuration
121	pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
122		Self {
123			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
124
125			config,
126		}
127	}
128
129	/// Register a host function for export to WASM
130	pub async fn register_function(
131		&self,
132
133		name:&str,
134
135		signature:FunctionSignature,
136
137		callback:HostFunctionCallback,
138	) -> Result<()> {
139		dev_log!("wasm", "Registering host function for export: {}", name);
140
141		let functions = self.registry.functions.read().await;
142
143		// Check max function limit
144		if functions.len() >= self.config.max_functions {
145			return Err(anyhow::anyhow!(
146				"Maximum number of exported functions reached: {}",
147				self.config.max_functions
148			));
149		}
150
151		drop(functions);
152
153		let mut functions = self.registry.functions.write().await;
154
155		let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
156
157		functions.insert(
158			name.to_string(),
159			RegisteredHostFunction {
160				name:name.to_string(),
161				signature,
162				callback:Some(callback),
163				registered_at,
164				stats:FunctionStats::default(),
165			},
166		);
167
168		dev_log!("wasm", "Host function registered for WASM export: {}", name);
169
170		Ok(())
171	}
172
173	/// Register multiple host functions
174	pub async fn register_functions(
175		&self,
176
177		signatures:Vec<FunctionSignature>,
178
179		callbacks:Vec<HostFunctionCallback>,
180	) -> Result<()> {
181		if signatures.len() != callbacks.len() {
182			return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
183		}
184
185		for (sig, callback) in signatures.into_iter().zip(callbacks) {
186			let name = sig.name.clone();
187
188			self.register_function(&name, sig, callback).await?;
189		}
190
191		Ok(())
192	}
193
194	/// Export all registered functions to a WASMtime linker
195	pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
196	where
197		T: Send + 'static, {
198		dev_log!(
199			"wasm",
200			"Exporting {} host functions to linker",
201			self.registry.functions.read().await.len()
202		);
203
204		let functions = self.registry.functions.read().await;
205
206		for (name, func) in functions.iter() {
207			self.export_single_function(linker, name, func)?;
208		}
209
210		dev_log!("wasm", "All host functions exported to linker");
211
212		Ok(())
213	}
214
215	/// Export a single function to the linker
216	fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
217	where
218		T: Send + 'static, {
219		dev_log!("wasm", "Exporting function: {}", name);
220
221		let callback = func
222			.callback
223			.ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
224
225		let func_name = if let Some(prefix) = &self.config.name_prefix {
226			format!("{}{}", prefix, name)
227		} else {
228			name.to_string()
229		};
230
231		let func_name_for_debug = func_name.clone();
232
233		let func_name_inner = func_name.clone();
234
235		// Create a wrapper function that handles stats and error handling
236		let _wrapped_callback =
237			move |_caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
238				let _start = std::time::Instant::now();
239
240				// Convert args to bytes
241				let args_bytes:Result<Vec<bytes::Bytes>, _> = args
242					.iter()
243					.map(|arg| {
244						match arg {
245							wasmtime::Val::I32(i) => {
246								serde_json::to_vec(i)
247									.map(bytes::Bytes::from)
248									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
249							},
250							wasmtime::Val::I64(i) => {
251								serde_json::to_vec(i)
252									.map(bytes::Bytes::from)
253									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
254							},
255							wasmtime::Val::F32(f) => {
256								serde_json::to_vec(f)
257									.map(bytes::Bytes::from)
258									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
259							},
260							wasmtime::Val::F64(f) => {
261								serde_json::to_vec(f)
262									.map(bytes::Bytes::from)
263									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
264							},
265							_ => Err(anyhow::anyhow!("Unsupported argument type")),
266						}
267					})
268					.collect();
269
270				let args_bytes = args_bytes.map_err(|_| {
271					dev_log!("wasm", "warn: error converting arguments for function '{}'", func_name_inner);
272					wasmtime::Trap::StackOverflow
273				})?;
274
275				// Call the callback
276				let result = callback(args_bytes);
277
278				match result {
279					Ok(response_bytes) => {
280						// Deserialize response
281						let result_val:serde_json::Value = serde_json::from_slice(&response_bytes).map_err(|_| {
282							dev_log!("wasm", "warn: error deserializing response for function '{}'", func_name_inner);
283							wasmtime::Trap::StackOverflow
284						})?;
285
286						let ret_val = match result_val {
287							serde_json::Value::Number(n) => {
288								if let Some(i) = n.as_i64() {
289									wasmtime::Val::I32(i as i32)
290								} else if let Some(f) = n.as_f64() {
291									wasmtime::Val::I64(f as i64)
292								} else {
293									dev_log!("wasm", "warn: invalid number format for function '{}'", func_name_inner);
294
295									return Err(wasmtime::Trap::StackOverflow);
296								}
297							},
298
299							_ => {
300								dev_log!("wasm", "warn: unsupported response type for function '{}'", func_name_inner);
301
302								return Err(wasmtime::Trap::StackOverflow);
303							},
304						};
305
306						Ok(vec![ret_val])
307					},
308
309					Err(e) => {
310						// Error handling
311						dev_log!("wasm", "host function '{}' returned error: {}", func_name_inner, e);
312
313						Err(wasmtime::Trap::StackOverflow)
314					},
315				}
316			};
317
318		// Define the function signature for WASMtime
319		let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
320
321		// Register host function with the linker using simple i32->i32 signature
322		// In Wasmtime 20, func_wrap expects parameters to be inferred from the closure
323		// signature
324		let func_name_for_logging = func_name.clone();
325
326		linker
327			.func_wrap(
328				"_host", // Module name for host functions
329				&func_name,
330				move |_caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
331					// Track function call for metrics
332					let start = std::time::Instant::now();
333
334					// Convert input parameter to bytes for callback
335					let args_bytes = match serde_json::to_vec(&input_param).map(bytes::Bytes::from) {
336						Ok(b) => b,
337						Err(e) => {
338							dev_log!(
339								"wasm",
340								"warn: serialization error for function '{}': {}",
341								func_name_for_logging,
342								e
343							);
344							return -1i32;
345						},
346					};
347
348					// Call the registered callback
349					let result = callback(vec![args_bytes]);
350
351					match result {
352						Ok(response_bytes) => {
353							// Deserialize response
354							let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
355								Ok(v) => v,
356								Err(_) => {
357									dev_log!(
358										"wasm",
359										"warn: error deserializing response for function '{}'",
360										func_name_for_logging
361									);
362									return -1i32;
363								},
364							};
365
366							// Extract result value
367							let ret_val = match result_val {
368								serde_json::Value::Number(n) => {
369									if let Some(i) = n.as_i64() {
370										i as i32
371									} else if let Some(f) = n.as_f64() {
372										f as i32
373									} else {
374										dev_log!(
375											"wasm",
376											"warn: invalid number format for function '{}'",
377											func_name_for_logging
378										);
379										-1i32
380									}
381								},
382								serde_json::Value::Bool(b) => {
383									if b {
384										1i32
385									} else {
386										0i32
387									}
388								},
389								_ => {
390									dev_log!(
391										"wasm",
392										"warn: unsupported response type for function '{}', expected number or bool",
393										func_name_for_logging
394									);
395									-1i32
396								},
397							};
398
399							// Log successful call
400							let duration = start.elapsed();
401							dev_log!(
402								"wasm",
403								"[FunctionExport] Host function '{}' executed successfully in {}µs",
404								func_name_for_logging,
405								duration.as_micros()
406							);
407
408							ret_val
409						},
410						Err(e) => {
411							// Error handling - return error code to WASM caller
412							dev_log!(
413								"wasm",
414								"[FunctionExport] Host function '{}' returned error: {}",
415								func_name_for_logging,
416								e
417							);
418							// Return -1 to indicate error to WASM caller
419							-1i32
420						},
421					}
422				},
423			)
424			.map_err(|e| {
425				dev_log!(
426					"wasm",
427					"warn: [FunctionExport] failed to wrap host function '{}': {}",
428					func_name_for_debug,
429					e
430				);
431				e
432			})?;
433
434		dev_log!(
435			"wasm",
436			"[FunctionExport] Host function '{}' registered successfully",
437			func_name_for_debug
438		);
439
440		Ok(())
441	}
442
443	/// Convert our signature to WASMtime signature type
444	#[allow(dead_code)]
445	fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
446		// This is a placeholder - actual implementation depends on the exact types
447		// In production, this would map ParamType and ReturnType to WASMtime types
448		Ok(wasmparser::FuncType::new([], []))
449	}
450
451	/// Get all registered function names
452	pub async fn get_function_names(&self) -> Vec<String> {
453		self.registry.functions.read().await.keys().cloned().collect()
454	}
455
456	/// Get function statistics
457	pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
458		self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
459	}
460
461	/// Unregister a function
462	pub async fn unregister_function(&self, name:&str) -> Result<bool> {
463		let mut functions = self.registry.functions.write().await;
464
465		let removed = functions.remove(name).is_some();
466
467		if removed {
468			dev_log!("wasm", "Unregistered host function: {}", name);
469		} else {
470			dev_log!("wasm", "warn: attempted to unregister non-existent function: {}", name);
471		}
472
473		Ok(removed)
474	}
475
476	/// Clear all registered functions
477	pub async fn clear(&self) {
478		dev_log!("wasm", "Clearing all registered host functions");
479
480		self.registry.functions.write().await.clear();
481	}
482}
483
#[cfg(test)]
mod tests {

	use super::*;

	/// Build an `i32 -> i32` signature named `name` for use in tests.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),

			param_types:vec![ParamType::I32],

			return_type:Some(ReturnType::I32),

			is_async:false,
		}
	}

	#[tokio::test]
	async fn test_function_export_creation() {
		let bridge = Arc::new(HostBridgeImpl::new());

		let export = FunctionExportImpl::new(bridge);

		assert!(export.get_function_names().await.is_empty());
	}

	#[tokio::test]
	async fn test_register_function() {
		let bridge = Arc::new(HostBridgeImpl::new());

		let export = FunctionExportImpl::new(bridge);

		let callback = |args:Vec<bytes::Bytes>| Ok(args.first().cloned().unwrap_or_default());

		let result:anyhow::Result<()> = export.register_function("echo", i32_signature("echo"), callback).await;

		assert!(result.is_ok());

		assert_eq!(export.get_function_names().await, vec!["echo".to_string()]);
	}

	#[tokio::test]
	async fn test_reregistering_same_name_replaces_entry() {
		let bridge = Arc::new(HostBridgeImpl::new());

		let export = FunctionExportImpl::new(bridge);

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		export
			.register_function("dup", i32_signature("dup"), callback)
			.await
			.expect("first registration should succeed");

		export
			.register_function("dup", i32_signature("dup"), callback)
			.await
			.expect("re-registration should succeed");

		assert_eq!(export.get_function_names().await.len(), 1);
	}

	#[tokio::test]
	async fn test_register_function_respects_max_functions() {
		let bridge = Arc::new(HostBridgeImpl::new());

		let config = ExportConfig { max_functions:1, ..ExportConfig::default() };

		let export = FunctionExportImpl::with_config(bridge, config);

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		export
			.register_function("first", i32_signature("first"), callback)
			.await
			.expect("first registration fits within the limit");

		assert!(export.register_function("second", i32_signature("second"), callback).await.is_err());

		assert_eq!(export.get_function_names().await, vec!["first".to_string()]);
	}

	#[tokio::test]
	async fn test_unregister_function() {
		let bridge = Arc::new(HostBridgeImpl::new());

		let export = FunctionExportImpl::new(bridge);

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		export
			.register_function("test", i32_signature("test"), callback)
			.await
			.expect("registration should succeed");

		let result:bool = export.unregister_function("test").await.unwrap();

		assert!(result);

		assert!(export.get_function_names().await.is_empty());

		// A second removal finds nothing.
		assert!(!export.unregister_function("test").await.unwrap());
	}

	#[tokio::test]
	async fn test_get_function_stats() {
		let bridge = Arc::new(HostBridgeImpl::new());

		let export = FunctionExportImpl::new(bridge);

		assert!(export.get_function_stats("missing").await.is_none());

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		export
			.register_function("stats", i32_signature("stats"), callback)
			.await
			.expect("registration should succeed");

		let stats = export.get_function_stats("stats").await.expect("registered function has stats");

		assert_eq!(stats.call_count, 0);

		assert_eq!(stats.error_count, 0);
	}

	#[test]
	fn test_export_config_default() {
		let config = ExportConfig::default();

		assert!(config.auto_export);

		assert!(config.enable_stats);

		assert_eq!(config.max_functions, 1000);

		assert_eq!(config.name_prefix.as_deref(), Some("host_"));
	}

	#[test]
	fn test_function_stats_default() {
		let stats = FunctionStats::default();

		assert_eq!(stats.call_count, 0);

		assert_eq!(stats.total_execution_ns, 0);

		assert_eq!(stats.last_call_at, None);

		assert_eq!(stats.error_count, 0);
	}
}