From d3e60a2702600c74e86a729e7c58575b82cb24ba Mon Sep 17 00:00:00 2001 From: Aditya Gupta Date: Thu, 4 Jun 2026 17:52:46 +0530 Subject: [PATCH] feat: Trace generation --- .gitignore | 10 +++ Cargo.lock | 99 +++++++++++++++++++++++- Cargo.toml | 1 + src/main.rs | 213 +++++++++++++++++++++++++++++---------------------- src/trace.rs | 157 +++++++++++++++++++++++++++++++++++++ 5 files changed, 389 insertions(+), 91 deletions(-) create mode 100644 src/trace.rs diff --git a/.gitignore b/.gitignore index ea8c4bf..1af7c8a 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,11 @@ /target +/traces + + + + + + + + + diff --git a/Cargo.lock b/Cargo.lock index 42b0ab1..65658f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,27 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "atomic-waker" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "autocfg" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" + [[package]] name = "aws-lc-rs" version = "1.17.0" @@ -78,6 +93,19 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-link", +] + [[package]] name = "cmake" version = "0.1.58" @@ -162,7 +190,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -383,6 +411,30 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "2.2.0" @@ -635,6 +687,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.21.4" @@ -822,6 +883,7 @@ dependencies = [ name = "ren" version = "0.1.0" dependencies = [ + "chrono", "reqwest", "serde", "serde_json", @@ -1515,6 +1577,41 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" diff --git a/Cargo.toml b/Cargo.toml index 18c94b5..f04e4ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] +chrono = "0.4.44" reqwest = { version = "0.13.4", features = ["json"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.150" diff --git a/src/main.rs b/src/main.rs index a3888fe..879b90b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,30 +1,29 @@ +mod trace; + use serde::{Deserialize, Serialize}; use serde_json::json; use std::process::Command; +use std::time::Instant; +use trace::Trace; // ── Message Types ───────────────────────────────────────────────────────────── -// -// A single unified struct covers all four roles (system, user, assistant, tool). -// Optional fields are skipped during serialization so the wire format stays clean. #[derive(Serialize, Deserialize, Debug, Clone)] -struct ChatMessage { - role: String, +pub struct ChatMessage { + pub role: String, #[serde(skip_serializing_if = "Option::is_none")] - content: Option, + pub content: Option, - // Only present on role: "tool" messages #[serde(skip_serializing_if = "Option::is_none")] - tool_call_id: Option, + pub tool_call_id: Option, - // Only present on role: "assistant" messages that invoke tools #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, + pub tool_calls: Option>, } impl ChatMessage { - fn system(content: &str) -> Self { + pub fn system(content: &str) -> Self { Self { role: "system".into(), content: Some(content.into()), @@ -32,8 +31,7 @@ impl ChatMessage { tool_calls: None, } } - - fn user(content: &str) -> Self { + pub fn user(content: &str) -> Self { Self { role: "user".into(), content: Some(content.into()), @@ -41,11 +39,7 @@ impl ChatMessage { tool_calls: None, } } - - // Captures the assistant turn including any tool_calls it made. - // This must be pushed into history before the tool result messages, - // or most backends will reject the conversation. - fn assistant(content: Option, tool_calls: Option>) -> Self { + pub fn assistant(content: Option, tool_calls: Option>) -> Self { Self { role: "assistant".into(), content, @@ -53,9 +47,7 @@ impl ChatMessage { tool_calls, } } - - // One tool result message per tool call, keyed by the call's id. - fn tool_result(tool_call_id: &str, content: &str) -> Self { + pub fn tool_result(tool_call_id: &str, content: &str) -> Self { Self { role: "tool".into(), content: Some(content.into()), @@ -67,20 +59,18 @@ impl ChatMessage { // ── API Wire Types ──────────────────────────────────────────────────────────── -// Needs both Serialize + Deserialize: the model returns ToolCalls in its -// response, and we push them straight back into the next request's messages. #[derive(Serialize, Deserialize, Debug, Clone)] -struct ToolCall { - id: String, +pub struct ToolCall { + pub id: String, #[serde(rename = "type")] - call_type: String, // always "function" - function: FunctionCall, + pub call_type: String, + pub function: FunctionCall, } #[derive(Serialize, Deserialize, Debug, Clone)] -struct FunctionCall { - name: String, - arguments: String, // JSON-encoded string from the model +pub struct FunctionCall { + pub name: String, + pub arguments: String, } #[derive(Serialize)] @@ -116,18 +106,39 @@ struct Choice { message: ResponseMessage, } -// content is Option because some models return null content -// on assistant turns that contain only tool_calls. #[derive(Deserialize, Debug)] struct ResponseMessage { content: Option, tool_calls: Option>, } -// ── Tool Registry ───────────────────────────────────────────────────────────── +// ── Agent Output ────────────────────────────────────────────────────────────── // -// This will eventually be loaded from a manifest file / plugin system. -// For now it's a plain Vec that gets passed into every LLM request. +// run_agent always returns AgentOutput — even on convergence failure. +// This guarantees the message history is always accessible for trace recording, +// regardless of whether the run succeeded. +// +// The outer Result<_, _> is reserved for hard, unrecoverable failures: +// network errors, JSON deserialization failures, etc. + +pub struct AgentOutput { + /// The assistant's final text response. None if the run failed to converge. + pub answer: Option, + /// Set if the run hit MAX_ITERATIONS without a final answer. + pub error: Option, + /// Full message history, always returned so the trace can be saved. + pub messages: Vec, + pub iterations: usize, + pub duration_ms: u64, +} + +impl AgentOutput { + pub fn succeeded(&self) -> bool { + self.error.is_none() && self.answer.is_some() + } +} + +// ── Tool Registry ───────────────────────────────────────────────────────────── fn build_tools() -> Vec { vec![ @@ -139,15 +150,8 @@ fn build_tools() -> Vec { parameters: json!({ "type": "object", "properties": { - "url": { - "type": "string", - "description": "The full HTTP/HTTPS URL to query" - }, - "method": { - "type": "string", - "enum": ["GET", "POST"], - "default": "GET" - } + "url": { "type": "string", "description": "The full HTTP/HTTPS URL to query" }, + "method": { "type": "string", "enum": ["GET", "POST"], "default": "GET" } }, "required": ["url"] }), @@ -157,14 +161,11 @@ fn build_tools() -> Vec { tool_type: "function".into(), function: ToolFunction { name: "run_shell".into(), - description: "Run a single shell command and return stdout. Use for filesystem ops, grepping, listing files, etc.".into(), + description: "Run a shell command and return stdout + stderr. Use for filesystem ops, text processing, etc.".into(), parameters: json!({ "type": "object", "properties": { - "command": { - "type": "string", - "description": "The shell command to execute" - } + "command": { "type": "string", "description": "The shell command to execute" } }, "required": ["command"] }), @@ -174,12 +175,8 @@ fn build_tools() -> Vec { } // ── Tool Execution ──────────────────────────────────────────────────────────── -// -// Each tool returns a String that gets fed back as a tool_result message. -// Truncation is intentional: we don't want one noisy tool call consuming -// the entire remaining context budget. -const MAX_TOOL_OUTPUT: usize = 2_000; // chars +const MAX_TOOL_OUTPUT: usize = 2_000; fn truncate(s: String) -> String { if s.len() > MAX_TOOL_OUTPUT { @@ -201,9 +198,7 @@ fn execute_tool(call: &ToolCall) -> String { "run_curl" => { let url = args["url"].as_str().unwrap_or(""); let method = args["method"].as_str().unwrap_or("GET"); - println!(" [tool::run_curl] {} {}", method, url); - match Command::new("curl") .args(["-s", "-X", method, url]) .output() @@ -212,12 +207,9 @@ fn execute_tool(call: &ToolCall) -> String { Err(e) => format!("[error] curl failed: {e}"), } } - "run_shell" => { let command = args["command"].as_str().unwrap_or(""); - println!(" [tool::run_shell] {}", command); - match Command::new("sh").args(["-c", command]).output() { Ok(out) => { let stdout = String::from_utf8_lossy(&out.stdout).into_owned(); @@ -232,12 +224,11 @@ fn execute_tool(call: &ToolCall) -> String { Err(e) => format!("[error] shell failed: {e}"), } } - unknown => format!("[error] unknown tool: {unknown}"), } } -// ── LLM Client ─────────────────────────────────────────────────────────────── +// ── LLM Client ──────────────────────────────────────────────────────────────── async fn chat( client: &reqwest::Client, @@ -268,16 +259,6 @@ async fn chat( } // ── Agent Loop ──────────────────────────────────────────────────────────────── -// -// Iteration contract: -// 1. Send messages → LLM -// 2. Push the assistant turn into history (always, even if it has tool_calls) -// 3a. If tool_calls present: execute each, push tool_result messages, repeat. -// 3b. If no tool_calls: the content is the final answer, return it. -// -// MAX_ITERATIONS is a hard safety rail. A well-prompted small model should -// resolve most tasks in 2-4 iterations. If we hit the cap, surface the error -// rather than silently returning a partial result. const MAX_ITERATIONS: usize = 10; @@ -285,16 +266,14 @@ async fn run_agent( client: &reqwest::Client, tools: &[Tool], mut messages: Vec, -) -> Result> { +) -> Result> { + let start = Instant::now(); + for i in 0..MAX_ITERATIONS { println!("\n── iteration {} ──────────────────────────────", i + 1); let response = chat(client, &messages, tools).await?; - // Step 2: push the full assistant turn before adding tool results. - // Backends that validate conversation structure will reject the request - // if tool_result messages appear without a preceding assistant message - // that contains the matching tool_call ids. messages.push(ChatMessage::assistant( response.content.clone(), response.tool_calls.clone(), @@ -306,26 +285,32 @@ async fn run_agent( for call in &calls { let result = execute_tool(call); println!(" → {} chars returned", result.len()); - // Each tool result is a separate message, keyed by call.id messages.push(ChatMessage::tool_result(&call.id, &result)); } - // Continue: LLM will see the results and decide whether to call - // more tools or produce a final answer. } _ => { - // No tool calls → this is the terminal response. let answer = response.content.unwrap_or_default(); - println!("\n── done in {} iteration(s) ───────────────────", i + 1); - return Ok(answer); + return Ok(AgentOutput { + answer: Some(answer), + error: None, + messages, + iterations: i + 1, + duration_ms: start.elapsed().as_millis() as u64, + }); } } } - Err(format!( - "Agent did not converge within {MAX_ITERATIONS} iterations. \ - Consider increasing the limit or checking the system prompt." - ) - .into()) + // Convergence failure — still return messages for trace recording. + Ok(AgentOutput { + answer: None, + error: Some(format!( + "Agent did not converge within {MAX_ITERATIONS} iterations." + )), + messages, + iterations: MAX_ITERATIONS, + duration_ms: start.elapsed().as_millis() as u64, + }) } // ── Entry Point ─────────────────────────────────────────────────────────────── @@ -345,9 +330,57 @@ async fn main() -> Result<(), Box> { ), ]; - match run_agent(&client, &tools, messages).await { - Ok(answer) => println!("\n── final answer ──\n{answer}"), - Err(e) => eprintln!("\n── agent error ──\n{e}"), + // Capture the task before messages ownership moves into run_agent. + let task = messages + .iter() + .find(|m| m.role == "user") + .and_then(|m| m.content.as_deref()) + .unwrap_or("unknown task") + .to_string(); + + let output = run_agent(&client, &tools, messages).await?; + + // ── Print result ── + if let Some(ref answer) = output.answer { + println!("\n── final answer ──\n{answer}"); + } + if let Some(ref err) = output.error { + eprintln!("\n── agent error ──\n{err}"); + } + println!( + "\n[stats] {} iteration(s), {}ms, {} tool call(s)", + output.iterations, + output.duration_ms, + output + .messages + .iter() + .filter_map(|m| m.tool_calls.as_ref()) + .map(|c| c.len()) + .sum::() + ); + + // ── Save trace ── + let trace = if output.succeeded() { + Trace::success( + &task, + output.messages, + output.answer.unwrap(), + output.iterations, + output.duration_ms, + ) + } else { + Trace::failure( + &task, + output.messages, + output.error.unwrap_or_else(|| "unknown failure".into()), + output.iterations, + output.duration_ms, + ) + }; + + match trace.save() { + Ok(path) => println!("[trace] saved → {path}"), + Err(e) => eprintln!("[trace] failed to save: {e}"), } Ok(()) diff --git a/src/trace.rs b/src/trace.rs new file mode 100644 index 0000000..d36e37b --- /dev/null +++ b/src/trace.rs @@ -0,0 +1,157 @@ +use crate::ChatMessage; +use serde::{Deserialize, Serialize}; +use std::time::{SystemTime, UNIX_EPOCH}; + +// ── Status ──────────────────────────────────────────────────────────────────── + +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "lowercase")] +pub enum TraceStatus { + Success, + Failure, +} + +// ── Trace ───────────────────────────────────────────────────────────────────── +// +// A complete snapshot of one agent run. The `messages` field is the raw +// material for recipe extraction — it contains every turn in the conversation: +// system prompt, user task, all assistant turns (including tool_calls), all +// tool_result turns, and the final answer turn. +// +// Successful traces are candidates for recipe extraction. +// Failed traces are candidates for prompt/tool improvement. + +#[derive(Serialize, Deserialize, Debug)] +pub struct Trace { + /// Unix timestamp in milliseconds. Used as the unique ID and for ordering. + pub id: u64, + + /// ISO 8601 formatted timestamp of when the trace was created. + pub timestamp: String, + + /// The original user message that kicked off this run. + pub task: String, + + pub status: TraceStatus, + + pub iterations: usize, + + pub duration_ms: u64, + + /// Full conversation history. Order is always: + /// system → user → [assistant (+tool_calls) → tool_result(s)]* → assistant (final) + pub messages: Vec, + + /// The assistant's final text answer. None on failure. + pub final_answer: Option, + + /// The error string if the run failed. None on success. + pub error: Option, + + /// Deduplicated, ordered list of tool names actually invoked. + /// Extracted from assistant messages that had tool_calls. + /// Used by the recipe extractor to know which tools a recipe depends on. + pub tools_invoked: Vec, +} + +impl Trace { + pub fn success( + task: &str, + messages: Vec, + answer: String, + iterations: usize, + duration_ms: u64, + ) -> Self { + Self { + id: now_ms(), + timestamp: chrono::Utc::now().to_rfc3339(), + task: task.to_string(), + status: TraceStatus::Success, + iterations, + duration_ms, + tools_invoked: extract_tools(&messages), + messages, + final_answer: Some(answer), + error: None, + } + } + + pub fn failure( + task: &str, + messages: Vec, + error: String, + iterations: usize, + duration_ms: u64, + ) -> Self { + Self { + id: now_ms(), + timestamp: chrono::Utc::now().to_rfc3339(), + task: task.to_string(), + status: TraceStatus::Failure, + iterations, + duration_ms, + tools_invoked: extract_tools(&messages), + messages, + final_answer: None, + error: Some(error), + } + } + + /// Writes the trace to `traces/_.json`. + /// The directory is created if it doesn't exist. + pub fn save(&self) -> Result> { + std::fs::create_dir_all("traces")?; + let path = format!("traces/{}_{}.json", self.id, slugify(&self.task)); + std::fs::write(&path, serde_json::to_string_pretty(self)?)?; + Ok(path) + } + + /// How many tool calls were made across the entire run. + pub fn total_tool_calls(&self) -> usize { + self.messages + .iter() + .filter_map(|m| m.tool_calls.as_ref()) + .map(|calls| calls.len()) + .sum() + } +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +fn now_ms() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0) +} + +/// Converts the first 40 chars of a string into a filename-safe slug. +pub fn slugify(s: &str) -> String { + s.chars() + .filter(|c| c.is_alphanumeric() || c.is_whitespace()) + .take(40) + .map(|c| { + if c.is_whitespace() { + '_' + } else { + c.to_ascii_lowercase() + } + }) + .collect() +} + +/// Walks the message list and collects tool names in invocation order, +/// deduplicating while preserving first-seen order. +fn extract_tools(messages: &[ChatMessage]) -> Vec { + let mut seen: Vec = Vec::new(); + for msg in messages { + if let Some(calls) = &msg.tool_calls { + for call in calls { + if !seen.contains(&call.function.name) { + seen.push(call.function.name.clone()); + } + } + } + } + seen +}