refactor: simulate agent loop
This commit is contained in:
20
Cargo.lock
generated
20
Cargo.lock
generated
@@ -591,16 +591,6 @@ version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0"
|
||||
|
||||
[[package]]
|
||||
name = "llama_api_client"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.14"
|
||||
@@ -828,6 +818,16 @@ dependencies = [
|
||||
"bitflags",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ren"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.13.4"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "llama_api_client"
|
||||
name = "ren"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
|
||||
434
src/main.rs
434
src/main.rs
@@ -1,64 +1,86 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::process::Command;
|
||||
|
||||
#[derive(Serialize)]
|
||||
// ── Message Types ─────────────────────────────────────────────────────────────
|
||||
//
|
||||
// A single unified struct covers all four roles (system, user, assistant, tool).
|
||||
// Optional fields are skipped during serialization so the wire format stays clean.
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct JsonSchemaObject {
|
||||
name: String,
|
||||
schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ResponseFormat {
|
||||
#[serde(rename = "type")]
|
||||
format_type: String, // "json_object" or "json_schema"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
json_schema: Option<JsonSchemaObject>,
|
||||
}
|
||||
content: Option<String>,
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
temperature: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] // <-- Add this line
|
||||
response_format: Option<ResponseFormat>, // <-- Wrap in Option
|
||||
// Only present on role: "tool" messages
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<Tool>>,
|
||||
}
|
||||
tool_call_id: Option<String>,
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ResponseMessage {
|
||||
content: String,
|
||||
// Only present on role: "assistant" messages that invoke tools
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
impl ChatMessage {
|
||||
fn system(content: &str) -> Self {
|
||||
Self {
|
||||
role: "system".into(),
|
||||
content: Some(content.into()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn user(content: &str) -> Self {
|
||||
Self {
|
||||
role: "user".into(),
|
||||
content: Some(content.into()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
}
|
||||
|
||||
// Captures the assistant turn including any tool_calls it made.
|
||||
// This must be pushed into history before the tool result messages,
|
||||
// or most backends will reject the conversation.
|
||||
fn assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
|
||||
Self {
|
||||
role: "assistant".into(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
|
||||
// One tool result message per tool call, keyed by the call's id.
|
||||
fn tool_result(tool_call_id: &str, content: &str) -> Self {
|
||||
Self {
|
||||
role: "tool".into(),
|
||||
content: Some(content.into()),
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
tool_calls: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── API Wire Types ────────────────────────────────────────────────────────────
|
||||
|
||||
// Needs both Serialize + Deserialize: the model returns ToolCalls in its
|
||||
// response, and we push them straight back into the next request's messages.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct ToolCall {
|
||||
id: String,
|
||||
#[serde(rename = "type")]
|
||||
call_type: String, // Always "function"
|
||||
call_type: String, // always "function"
|
||||
function: FunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
struct FunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
arguments: String, // JSON-encoded string from the model
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -71,99 +93,261 @@ struct ToolFunction {
|
||||
#[derive(Serialize)]
|
||||
struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
tool_type: String, // Always "function"
|
||||
tool_type: String,
|
||||
function: ToolFunction,
|
||||
}
|
||||
|
||||
use serde_json::json;
|
||||
use std::process::Command;
|
||||
// Request body that `chat` serializes as JSON and POSTs to the
// /v1/chat/completions endpoint. It borrows the conversation and the tool
// list from the caller instead of owning copies.
#[derive(Serialize)]
|
||||
struct ChatCompletionRequest<'a> {
|
||||
// Model identifier sent to the server (`chat` passes "local-model").
model: String,
|
||||
// Conversation history, borrowed for the duration of the request.
messages: &'a [ChatMessage],
|
||||
// Sampling temperature (`chat` passes 0.1).
temperature: f32,
|
||||
// Left out of the serialized JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<&'a [Tool]>,
|
||||
}
|
||||
|
||||
// Top-level shape of the completion response that `chat` deserializes from
// the HTTP reply. Only the `choices` array is read.
#[derive(Deserialize, Debug)]
|
||||
struct ChatCompletionResponse {
|
||||
// Candidate completions; `chat` takes the first one and errors if the list is empty.
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
// One entry of `ChatCompletionResponse::choices`; only its message is read.
#[derive(Deserialize, Debug)]
|
||||
struct Choice {
|
||||
// The assistant message produced for this choice.
message: ResponseMessage,
|
||||
}
|
||||
|
||||
// content is Option<String> because some models return null content
|
||||
// on assistant turns that contain only tool_calls.
|
||||
// Assistant message as returned by the server. Both fields are optional, so a
// reply may carry text, tool calls, both, or neither.
#[derive(Deserialize, Debug)]
|
||||
struct ResponseMessage {
|
||||
// Text of the reply; `run_agent` treats a missing value as an empty answer.
content: Option<String>,
|
||||
// Tool invocations requested by the model; `run_agent` treats `None` or an
// empty list as "no tools requested" and ends the loop.
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
// ── Tool Registry ─────────────────────────────────────────────────────────────
|
||||
//
|
||||
// This will eventually be loaded from a manifest file / plugin system.
|
||||
// For now it's a plain Vec that gets passed into every LLM request.
|
||||
|
||||
fn build_tools() -> Vec<Tool> {
|
||||
vec![
|
||||
Tool {
|
||||
tool_type: "function".into(),
|
||||
function: ToolFunction {
|
||||
name: "run_curl".into(),
|
||||
description: "Execute a network request to a public URL and return the response body.".into(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The full HTTP/HTTPS URL to query"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"enum": ["GET", "POST"],
|
||||
"default": "GET"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}),
|
||||
},
|
||||
},
|
||||
Tool {
|
||||
tool_type: "function".into(),
|
||||
function: ToolFunction {
|
||||
name: "run_shell".into(),
|
||||
description: "Run a single shell command and return stdout. Use for filesystem ops, grepping, listing files, etc.".into(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}),
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
// ── Tool Execution ────────────────────────────────────────────────────────────
|
||||
//
|
||||
// Each tool returns a String that gets fed back as a tool_result message.
|
||||
// Truncation is intentional: we don't want one noisy tool call consuming
|
||||
// the entire remaining context budget.
|
||||
|
||||
// Upper bound on tool output fed back to the model, counted in characters
// (Unicode scalar values), not bytes.
const MAX_TOOL_OUTPUT: usize = 2_000; // chars

// Caps `s` at MAX_TOOL_OUTPUT characters.
//
// Strings within the budget are returned unchanged. Longer strings are cut
// after the MAX_TOOL_OUTPUT-th character and suffixed with
// "…[truncated N chars]", where N is the number of characters dropped.
//
// The cut is located with `char_indices`, so it always lands on a UTF-8
// character boundary. Slicing at a fixed byte offset panics when that offset
// falls inside a multi-byte character, which lossy-decoded process output
// (U+FFFD is three bytes) can easily trigger.
fn truncate(s: String) -> String {
    // Fast path: a string of at most MAX_TOOL_OUTPUT bytes cannot contain
    // more than MAX_TOOL_OUTPUT characters.
    if s.len() <= MAX_TOOL_OUTPUT {
        return s;
    }

    // Byte offset of the first character past the budget, if there is one.
    match s.char_indices().nth(MAX_TOOL_OUTPUT) {
        Some((cut, _)) => {
            let dropped = s[cut..].chars().count();
            format!("{}…[truncated {} chars]", &s[..cut], dropped)
        }
        // More than MAX_TOOL_OUTPUT bytes but not more characters: keep as-is.
        None => s,
    }
}
|
||||
|
||||
fn execute_tool(call: &ToolCall) -> String {
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&call.function.arguments).unwrap_or(json!({}));
|
||||
|
||||
match call.function.name.as_str() {
|
||||
"run_curl" => {
|
||||
let url = args["url"].as_str().unwrap_or("");
|
||||
let method = args["method"].as_str().unwrap_or("GET");
|
||||
|
||||
println!(" [tool::run_curl] {} {}", method, url);
|
||||
|
||||
match Command::new("curl")
|
||||
.args(["-s", "-X", method, url])
|
||||
.output()
|
||||
{
|
||||
Ok(out) => truncate(String::from_utf8_lossy(&out.stdout).into_owned()),
|
||||
Err(e) => format!("[error] curl failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
"run_shell" => {
|
||||
let command = args["command"].as_str().unwrap_or("");
|
||||
|
||||
println!(" [tool::run_shell] {}", command);
|
||||
|
||||
match Command::new("sh").args(["-c", command]).output() {
|
||||
Ok(out) => {
|
||||
let stdout = String::from_utf8_lossy(&out.stdout).into_owned();
|
||||
let stderr = String::from_utf8_lossy(&out.stderr).into_owned();
|
||||
let combined = if stderr.is_empty() {
|
||||
stdout
|
||||
} else {
|
||||
format!("{stdout}\n[stderr]\n{stderr}")
|
||||
};
|
||||
truncate(combined)
|
||||
}
|
||||
Err(e) => format!("[error] shell failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
unknown => format!("[error] unknown tool: {unknown}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── LLM Client ───────────────────────────────────────────────────────────────
|
||||
|
||||
async fn chat(
|
||||
client: &reqwest::Client,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[Tool],
|
||||
) -> Result<ResponseMessage, Box<dyn std::error::Error>> {
|
||||
let payload = ChatCompletionRequest {
|
||||
model: "local-model".to_string(),
|
||||
messages,
|
||||
temperature: 0.1,
|
||||
tools: Some(tools),
|
||||
};
|
||||
|
||||
let response = client
|
||||
.post("http://localhost:8080/v1/chat/completions")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?
|
||||
.json::<ChatCompletionResponse>()
|
||||
.await?;
|
||||
|
||||
response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message)
|
||||
.ok_or_else(|| "No choices in response".into())
|
||||
}
|
||||
|
||||
// ── Agent Loop ────────────────────────────────────────────────────────────────
|
||||
//
|
||||
// Iteration contract:
|
||||
// 1. Send messages → LLM
|
||||
// 2. Push the assistant turn into history (always, even if it has tool_calls)
|
||||
// 3a. If tool_calls present: execute each, push tool_result messages, repeat.
|
||||
// 3b. If no tool_calls: the content is the final answer, return it.
|
||||
//
|
||||
// MAX_ITERATIONS is a hard safety rail. A well-prompted small model should
|
||||
// resolve most tasks in 2-4 iterations. If we hit the cap, surface the error
|
||||
// rather than silently returning a partial result.
|
||||
|
||||
const MAX_ITERATIONS: usize = 10;
|
||||
|
||||
async fn run_agent(
|
||||
client: &reqwest::Client,
|
||||
tools: &[Tool],
|
||||
mut messages: Vec<ChatMessage>,
|
||||
) -> Result<String, Box<dyn std::error::Error>> {
|
||||
for i in 0..MAX_ITERATIONS {
|
||||
println!("\n── iteration {} ──────────────────────────────", i + 1);
|
||||
|
||||
let response = chat(client, &messages, tools).await?;
|
||||
|
||||
// Step 2: push the full assistant turn before adding tool results.
|
||||
// Backends that validate conversation structure will reject the request
|
||||
// if tool_result messages appear without a preceding assistant message
|
||||
// that contains the matching tool_call ids.
|
||||
messages.push(ChatMessage::assistant(
|
||||
response.content.clone(),
|
||||
response.tool_calls.clone(),
|
||||
));
|
||||
|
||||
match response.tool_calls {
|
||||
Some(calls) if !calls.is_empty() => {
|
||||
println!(" {} tool call(s) requested", calls.len());
|
||||
for call in &calls {
|
||||
let result = execute_tool(call);
|
||||
println!(" → {} chars returned", result.len());
|
||||
// Each tool result is a separate message, keyed by call.id
|
||||
messages.push(ChatMessage::tool_result(&call.id, &result));
|
||||
}
|
||||
// Continue: LLM will see the results and decide whether to call
|
||||
// more tools or produce a final answer.
|
||||
}
|
||||
_ => {
|
||||
// No tool calls → this is the terminal response.
|
||||
let answer = response.content.unwrap_or_default();
|
||||
println!("\n── done in {} iteration(s) ───────────────────", i + 1);
|
||||
return Ok(answer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(format!(
|
||||
"Agent did not converge within {MAX_ITERATIONS} iterations. \
|
||||
Consider increasing the limit or checking the system prompt."
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
// ── Entry Point ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = reqwest::Client::new();
|
||||
let tools = build_tools();
|
||||
|
||||
// Define the tool signature
|
||||
let curl_tool = Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "run_curl".to_string(),
|
||||
description:
|
||||
"Executes a curl network request command to fetch data from a public URL endpoint"
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": { "type": "string", "description": "The full HTTP/HTTPS URL endpoint to query" },
|
||||
"method": { "type": "string", "enum": ["GET", "POST"], "default": "GET" }
|
||||
},
|
||||
"required": ["url"]
|
||||
}),
|
||||
},
|
||||
};
|
||||
let messages = vec![
|
||||
ChatMessage::system(
|
||||
"You are a concise system assistant. Use the provided tools to gather \
|
||||
real-time information when needed. Do not guess at live data.",
|
||||
),
|
||||
ChatMessage::user(
|
||||
"Check if wttr.in/Delhi is accessible and give me a one-line weather summary.",
|
||||
),
|
||||
];
|
||||
|
||||
let payload = ChatCompletionRequest {
|
||||
model: "local-model".to_string(),
|
||||
messages: vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: "You are a helpful system assistant. You have access to a run_curl tool. Use it whenever you want to real-time web information or status codes.".to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Check if wttr.in/Delhi is accessible right now using a GET request.".to_string(),
|
||||
},
|
||||
],
|
||||
temperature: 0.1,
|
||||
response_format: None, // <-- Pass None here since we are using tools right now
|
||||
tools: Some(vec![curl_tool]),
|
||||
};
|
||||
|
||||
let http_response = client
|
||||
.post("http://localhost:8080/v1/chat/completions")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let response: ChatCompletionResponse = http_response.json().await?;
|
||||
|
||||
let message = &response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or("No choices returned")?
|
||||
.message;
|
||||
|
||||
if let Some(tool_calls) = &message.tool_calls {
|
||||
for call in tool_calls {
|
||||
if call.function.name == "run_curl" {
|
||||
// Parse the arguments back into JSON
|
||||
let args: serde_json::Value = serde_json::from_str(&call.function.arguments)?;
|
||||
let url = args["url"].as_str().unwrap_or("");
|
||||
let method = args["method"].as_str().unwrap_or("GET");
|
||||
|
||||
println!(
|
||||
"\n[Assistant Triggered Tool]: Executing curl {} to {}",
|
||||
method, url
|
||||
);
|
||||
|
||||
// Run the actual system shell command safely via standard libraries
|
||||
let output = Command::new("curl")
|
||||
.arg("-s") // Silent mode
|
||||
.arg("-X")
|
||||
.arg(method)
|
||||
.arg(url)
|
||||
.output()?;
|
||||
|
||||
let result_string = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
println!(
|
||||
"\n[System Execution Output]:\n{}",
|
||||
result_string.lines().take(5).collect::<Vec<_>>().join("\n")
|
||||
);
|
||||
println!("... (truncated)");
|
||||
}
|
||||
}
|
||||
} else if !message.content.is_empty() {
|
||||
println!("\nAI Response: {}", message.content);
|
||||
match run_agent(&client, &tools, messages).await {
|
||||
Ok(answer) => println!("\n── final answer ──\n{answer}"),
|
||||
Err(e) => eprintln!("\n── agent error ──\n{e}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user