feat: Trace generation
This commit is contained in:
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1 +1,11 @@
|
||||
/target
|
||||
/traces
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
99
Cargo.lock
generated
99
Cargo.lock
generated
@@ -2,12 +2,27 @@
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53"
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-rs"
|
||||
version = "1.17.0"
|
||||
@@ -78,6 +93,19 @@ version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0"
|
||||
dependencies = [
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"wasm-bindgen",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.58"
|
||||
@@ -162,7 +190,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -383,6 +411,30 @@ dependencies = [
|
||||
"windows-registry",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.65"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470"
|
||||
dependencies = [
|
||||
"android_system_properties",
|
||||
"core-foundation-sys",
|
||||
"iana-time-zone-haiku",
|
||||
"js-sys",
|
||||
"log",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone-haiku"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icu_collections"
|
||||
version = "2.2.0"
|
||||
@@ -635,6 +687,15 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.4"
|
||||
@@ -822,6 +883,7 @@ dependencies = [
|
||||
name = "ren"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -1515,6 +1577,41 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.62.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb"
|
||||
dependencies = [
|
||||
"windows-implement",
|
||||
"windows-interface",
|
||||
"windows-link",
|
||||
"windows-result",
|
||||
"windows-strings",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-implement"
|
||||
version = "0.60.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-interface"
|
||||
version = "0.59.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-link"
|
||||
version = "0.2.1"
|
||||
|
||||
@@ -4,6 +4,7 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
chrono = "0.4.44"
|
||||
reqwest = { version = "0.13.4", features = ["json"] }
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
serde_json = "1.0.150"
|
||||
|
||||
213
src/main.rs
213
src/main.rs
@@ -1,30 +1,29 @@
|
||||
mod trace;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::process::Command;
|
||||
use std::time::Instant;
|
||||
use trace::Trace;
|
||||
|
||||
// ── Message Types ─────────────────────────────────────────────────────────────
|
||||
//
|
||||
// A single unified struct covers all four roles (system, user, assistant, tool).
|
||||
// Optional fields are skipped during serialization so the wire format stays clean.
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
pub content: Option<String>,
|
||||
|
||||
// Only present on role: "tool" messages
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
pub tool_call_id: Option<String>,
|
||||
|
||||
// Only present on role: "assistant" messages that invoke tools
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
fn system(content: &str) -> Self {
|
||||
pub fn system(content: &str) -> Self {
|
||||
Self {
|
||||
role: "system".into(),
|
||||
content: Some(content.into()),
|
||||
@@ -32,8 +31,7 @@ impl ChatMessage {
|
||||
tool_calls: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn user(content: &str) -> Self {
|
||||
pub fn user(content: &str) -> Self {
|
||||
Self {
|
||||
role: "user".into(),
|
||||
content: Some(content.into()),
|
||||
@@ -41,11 +39,7 @@ impl ChatMessage {
|
||||
tool_calls: None,
|
||||
}
|
||||
}
|
||||
|
||||
// Captures the assistant turn including any tool_calls it made.
|
||||
// This must be pushed into history before the tool result messages,
|
||||
// or most backends will reject the conversation.
|
||||
fn assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
|
||||
pub fn assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
|
||||
Self {
|
||||
role: "assistant".into(),
|
||||
content,
|
||||
@@ -53,9 +47,7 @@ impl ChatMessage {
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
|
||||
// One tool result message per tool call, keyed by the call's id.
|
||||
fn tool_result(tool_call_id: &str, content: &str) -> Self {
|
||||
pub fn tool_result(tool_call_id: &str, content: &str) -> Self {
|
||||
Self {
|
||||
role: "tool".into(),
|
||||
content: Some(content.into()),
|
||||
@@ -67,20 +59,18 @@ impl ChatMessage {
|
||||
|
||||
// ── API Wire Types ────────────────────────────────────────────────────────────
|
||||
|
||||
// Needs both Serialize + Deserialize: the model returns ToolCalls in its
|
||||
// response, and we push them straight back into the next request's messages.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct ToolCall {
|
||||
id: String,
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
call_type: String, // always "function"
|
||||
function: FunctionCall,
|
||||
pub call_type: String,
|
||||
pub function: FunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct FunctionCall {
|
||||
name: String,
|
||||
arguments: String, // JSON-encoded string from the model
|
||||
pub struct FunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -116,18 +106,39 @@ struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
// content is Option<String> because some models return null content
|
||||
// on assistant turns that contain only tool_calls.
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ResponseMessage {
|
||||
content: Option<String>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
// ── Tool Registry ─────────────────────────────────────────────────────────────
|
||||
// ── Agent Output ──────────────────────────────────────────────────────────────
|
||||
//
|
||||
// This will eventually be loaded from a manifest file / plugin system.
|
||||
// For now it's a plain Vec that gets passed into every LLM request.
|
||||
// run_agent always returns AgentOutput — even on convergence failure.
|
||||
// This guarantees the message history is always accessible for trace recording,
|
||||
// regardless of whether the run succeeded.
|
||||
//
|
||||
// The outer Result<_, _> is reserved for hard, unrecoverable failures:
|
||||
// network errors, JSON deserialization failures, etc.
|
||||
|
||||
pub struct AgentOutput {
|
||||
/// The assistant's final text response. None if the run failed to converge.
|
||||
pub answer: Option<String>,
|
||||
/// Set if the run hit MAX_ITERATIONS without a final answer.
|
||||
pub error: Option<String>,
|
||||
/// Full message history, always returned so the trace can be saved.
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub iterations: usize,
|
||||
pub duration_ms: u64,
|
||||
}
|
||||
|
||||
impl AgentOutput {
|
||||
pub fn succeeded(&self) -> bool {
|
||||
self.error.is_none() && self.answer.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tool Registry ─────────────────────────────────────────────────────────────
|
||||
|
||||
fn build_tools() -> Vec<Tool> {
|
||||
vec![
|
||||
@@ -139,15 +150,8 @@ fn build_tools() -> Vec<Tool> {
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The full HTTP/HTTPS URL to query"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"enum": ["GET", "POST"],
|
||||
"default": "GET"
|
||||
}
|
||||
"url": { "type": "string", "description": "The full HTTP/HTTPS URL to query" },
|
||||
"method": { "type": "string", "enum": ["GET", "POST"], "default": "GET" }
|
||||
},
|
||||
"required": ["url"]
|
||||
}),
|
||||
@@ -157,14 +161,11 @@ fn build_tools() -> Vec<Tool> {
|
||||
tool_type: "function".into(),
|
||||
function: ToolFunction {
|
||||
name: "run_shell".into(),
|
||||
description: "Run a single shell command and return stdout. Use for filesystem ops, grepping, listing files, etc.".into(),
|
||||
description: "Run a shell command and return stdout + stderr. Use for filesystem ops, text processing, etc.".into(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
}
|
||||
"command": { "type": "string", "description": "The shell command to execute" }
|
||||
},
|
||||
"required": ["command"]
|
||||
}),
|
||||
@@ -174,12 +175,8 @@ fn build_tools() -> Vec<Tool> {
|
||||
}
|
||||
|
||||
// ── Tool Execution ────────────────────────────────────────────────────────────
|
||||
//
|
||||
// Each tool returns a String that gets fed back as a tool_result message.
|
||||
// Truncation is intentional: we don't want one noisy tool call consuming
|
||||
// the entire remaining context budget.
|
||||
|
||||
const MAX_TOOL_OUTPUT: usize = 2_000; // chars
|
||||
const MAX_TOOL_OUTPUT: usize = 2_000;
|
||||
|
||||
fn truncate(s: String) -> String {
|
||||
if s.len() > MAX_TOOL_OUTPUT {
|
||||
@@ -201,9 +198,7 @@ fn execute_tool(call: &ToolCall) -> String {
|
||||
"run_curl" => {
|
||||
let url = args["url"].as_str().unwrap_or("");
|
||||
let method = args["method"].as_str().unwrap_or("GET");
|
||||
|
||||
println!(" [tool::run_curl] {} {}", method, url);
|
||||
|
||||
match Command::new("curl")
|
||||
.args(["-s", "-X", method, url])
|
||||
.output()
|
||||
@@ -212,12 +207,9 @@ fn execute_tool(call: &ToolCall) -> String {
|
||||
Err(e) => format!("[error] curl failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
"run_shell" => {
|
||||
let command = args["command"].as_str().unwrap_or("");
|
||||
|
||||
println!(" [tool::run_shell] {}", command);
|
||||
|
||||
match Command::new("sh").args(["-c", command]).output() {
|
||||
Ok(out) => {
|
||||
let stdout = String::from_utf8_lossy(&out.stdout).into_owned();
|
||||
@@ -232,12 +224,11 @@ fn execute_tool(call: &ToolCall) -> String {
|
||||
Err(e) => format!("[error] shell failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
unknown => format!("[error] unknown tool: {unknown}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── LLM Client ───────────────────────────────────────────────────────────────
|
||||
// ── LLM Client ────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn chat(
|
||||
client: &reqwest::Client,
|
||||
@@ -268,16 +259,6 @@ async fn chat(
|
||||
}
|
||||
|
||||
// ── Agent Loop ────────────────────────────────────────────────────────────────
|
||||
//
|
||||
// Iteration contract:
|
||||
// 1. Send messages → LLM
|
||||
// 2. Push the assistant turn into history (always, even if it has tool_calls)
|
||||
// 3a. If tool_calls present: execute each, push tool_result messages, repeat.
|
||||
// 3b. If no tool_calls: the content is the final answer, return it.
|
||||
//
|
||||
// MAX_ITERATIONS is a hard safety rail. A well-prompted small model should
|
||||
// resolve most tasks in 2-4 iterations. If we hit the cap, surface the error
|
||||
// rather than silently returning a partial result.
|
||||
|
||||
const MAX_ITERATIONS: usize = 10;
|
||||
|
||||
@@ -285,16 +266,14 @@ async fn run_agent(
|
||||
client: &reqwest::Client,
|
||||
tools: &[Tool],
|
||||
mut messages: Vec<ChatMessage>,
|
||||
) -> Result<String, Box<dyn std::error::Error>> {
|
||||
) -> Result<AgentOutput, Box<dyn std::error::Error>> {
|
||||
let start = Instant::now();
|
||||
|
||||
for i in 0..MAX_ITERATIONS {
|
||||
println!("\n── iteration {} ──────────────────────────────", i + 1);
|
||||
|
||||
let response = chat(client, &messages, tools).await?;
|
||||
|
||||
// Step 2: push the full assistant turn before adding tool results.
|
||||
// Backends that validate conversation structure will reject the request
|
||||
// if tool_result messages appear without a preceding assistant message
|
||||
// that contains the matching tool_call ids.
|
||||
messages.push(ChatMessage::assistant(
|
||||
response.content.clone(),
|
||||
response.tool_calls.clone(),
|
||||
@@ -306,26 +285,32 @@ async fn run_agent(
|
||||
for call in &calls {
|
||||
let result = execute_tool(call);
|
||||
println!(" → {} chars returned", result.len());
|
||||
// Each tool result is a separate message, keyed by call.id
|
||||
messages.push(ChatMessage::tool_result(&call.id, &result));
|
||||
}
|
||||
// Continue: LLM will see the results and decide whether to call
|
||||
// more tools or produce a final answer.
|
||||
}
|
||||
_ => {
|
||||
// No tool calls → this is the terminal response.
|
||||
let answer = response.content.unwrap_or_default();
|
||||
println!("\n── done in {} iteration(s) ───────────────────", i + 1);
|
||||
return Ok(answer);
|
||||
return Ok(AgentOutput {
|
||||
answer: Some(answer),
|
||||
error: None,
|
||||
messages,
|
||||
iterations: i + 1,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(format!(
|
||||
"Agent did not converge within {MAX_ITERATIONS} iterations. \
|
||||
Consider increasing the limit or checking the system prompt."
|
||||
)
|
||||
.into())
|
||||
// Convergence failure — still return messages for trace recording.
|
||||
Ok(AgentOutput {
|
||||
answer: None,
|
||||
error: Some(format!(
|
||||
"Agent did not converge within {MAX_ITERATIONS} iterations."
|
||||
)),
|
||||
messages,
|
||||
iterations: MAX_ITERATIONS,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Entry Point ───────────────────────────────────────────────────────────────
|
||||
@@ -345,9 +330,57 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
),
|
||||
];
|
||||
|
||||
match run_agent(&client, &tools, messages).await {
|
||||
Ok(answer) => println!("\n── final answer ──\n{answer}"),
|
||||
Err(e) => eprintln!("\n── agent error ──\n{e}"),
|
||||
// Capture the task before messages ownership moves into run_agent.
|
||||
let task = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "user")
|
||||
.and_then(|m| m.content.as_deref())
|
||||
.unwrap_or("unknown task")
|
||||
.to_string();
|
||||
|
||||
let output = run_agent(&client, &tools, messages).await?;
|
||||
|
||||
// ── Print result ──
|
||||
if let Some(ref answer) = output.answer {
|
||||
println!("\n── final answer ──\n{answer}");
|
||||
}
|
||||
if let Some(ref err) = output.error {
|
||||
eprintln!("\n── agent error ──\n{err}");
|
||||
}
|
||||
println!(
|
||||
"\n[stats] {} iteration(s), {}ms, {} tool call(s)",
|
||||
output.iterations,
|
||||
output.duration_ms,
|
||||
output
|
||||
.messages
|
||||
.iter()
|
||||
.filter_map(|m| m.tool_calls.as_ref())
|
||||
.map(|c| c.len())
|
||||
.sum::<usize>()
|
||||
);
|
||||
|
||||
// ── Save trace ──
|
||||
let trace = if output.succeeded() {
|
||||
Trace::success(
|
||||
&task,
|
||||
output.messages,
|
||||
output.answer.unwrap(),
|
||||
output.iterations,
|
||||
output.duration_ms,
|
||||
)
|
||||
} else {
|
||||
Trace::failure(
|
||||
&task,
|
||||
output.messages,
|
||||
output.error.unwrap_or_else(|| "unknown failure".into()),
|
||||
output.iterations,
|
||||
output.duration_ms,
|
||||
)
|
||||
};
|
||||
|
||||
match trace.save() {
|
||||
Ok(path) => println!("[trace] saved → {path}"),
|
||||
Err(e) => eprintln!("[trace] failed to save: {e}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
157
src/trace.rs
Normal file
157
src/trace.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
use crate::ChatMessage;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
// ── Status ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TraceStatus {
|
||||
Success,
|
||||
Failure,
|
||||
}
|
||||
|
||||
// ── Trace ─────────────────────────────────────────────────────────────────────
|
||||
//
|
||||
// A complete snapshot of one agent run. The `messages` field is the raw
|
||||
// material for recipe extraction — it contains every turn in the conversation:
|
||||
// system prompt, user task, all assistant turns (including tool_calls), all
|
||||
// tool_result turns, and the final answer turn.
|
||||
//
|
||||
// Successful traces are candidates for recipe extraction.
|
||||
// Failed traces are candidates for prompt/tool improvement.
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Trace {
|
||||
/// Unix timestamp in milliseconds. Used as the unique ID and for ordering.
|
||||
pub id: u64,
|
||||
|
||||
/// ISO 8601 formatted timestamp of when the trace was created.
|
||||
pub timestamp: String,
|
||||
|
||||
/// The original user message that kicked off this run.
|
||||
pub task: String,
|
||||
|
||||
pub status: TraceStatus,
|
||||
|
||||
pub iterations: usize,
|
||||
|
||||
pub duration_ms: u64,
|
||||
|
||||
/// Full conversation history. Order is always:
|
||||
/// system → user → [assistant (+tool_calls) → tool_result(s)]* → assistant (final)
|
||||
pub messages: Vec<ChatMessage>,
|
||||
|
||||
/// The assistant's final text answer. None on failure.
|
||||
pub final_answer: Option<String>,
|
||||
|
||||
/// The error string if the run failed. None on success.
|
||||
pub error: Option<String>,
|
||||
|
||||
/// Deduplicated, ordered list of tool names actually invoked.
|
||||
/// Extracted from assistant messages that had tool_calls.
|
||||
/// Used by the recipe extractor to know which tools a recipe depends on.
|
||||
pub tools_invoked: Vec<String>,
|
||||
}
|
||||
|
||||
impl Trace {
|
||||
pub fn success(
|
||||
task: &str,
|
||||
messages: Vec<ChatMessage>,
|
||||
answer: String,
|
||||
iterations: usize,
|
||||
duration_ms: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: now_ms(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
task: task.to_string(),
|
||||
status: TraceStatus::Success,
|
||||
iterations,
|
||||
duration_ms,
|
||||
tools_invoked: extract_tools(&messages),
|
||||
messages,
|
||||
final_answer: Some(answer),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn failure(
|
||||
task: &str,
|
||||
messages: Vec<ChatMessage>,
|
||||
error: String,
|
||||
iterations: usize,
|
||||
duration_ms: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: now_ms(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
task: task.to_string(),
|
||||
status: TraceStatus::Failure,
|
||||
iterations,
|
||||
duration_ms,
|
||||
tools_invoked: extract_tools(&messages),
|
||||
messages,
|
||||
final_answer: None,
|
||||
error: Some(error),
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes the trace to `traces/<id>_<task_slug>.json`.
|
||||
/// The directory is created if it doesn't exist.
|
||||
pub fn save(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
std::fs::create_dir_all("traces")?;
|
||||
let path = format!("traces/{}_{}.json", self.id, slugify(&self.task));
|
||||
std::fs::write(&path, serde_json::to_string_pretty(self)?)?;
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
/// How many tool calls were made across the entire run.
|
||||
pub fn total_tool_calls(&self) -> usize {
|
||||
self.messages
|
||||
.iter()
|
||||
.filter_map(|m| m.tool_calls.as_ref())
|
||||
.map(|calls| calls.len())
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
|
||||
|
||||
/// Turns a string into a short slug for use in a file name.
///
/// Characters that are neither alphanumeric nor whitespace are dropped first;
/// of what remains, at most the first 40 characters are kept. Each whitespace
/// character becomes `_`, and ASCII letters are lowercased (non-ASCII
/// alphanumerics pass through unchanged).
pub fn slugify(s: &str) -> String {
    let mut slug = String::new();

    let kept = s
        .chars()
        .filter(|ch| ch.is_alphanumeric() || ch.is_whitespace())
        .take(40);

    for ch in kept {
        let mapped = if ch.is_whitespace() {
            '_'
        } else {
            ch.to_ascii_lowercase()
        };
        slug.push(mapped);
    }

    slug
}
|
||||
|
||||
/// Walks the message list and collects tool names in invocation order,
|
||||
/// deduplicating while preserving first-seen order.
|
||||
fn extract_tools(messages: &[ChatMessage]) -> Vec<String> {
|
||||
let mut seen: Vec<String> = Vec::new();
|
||||
for msg in messages {
|
||||
if let Some(calls) = &msg.tool_calls {
|
||||
for call in calls {
|
||||
if !seen.contains(&call.function.name) {
|
||||
seen.push(call.function.name.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
seen
|
||||
}
|
||||
Reference in New Issue
Block a user