refactor: simulate agent loop

This commit is contained in:
2026-06-04 16:52:06 +05:30
parent 3da8c0a70b
commit 5138de5a1c
3 changed files with 320 additions and 136 deletions

20
Cargo.lock generated
View File

@@ -591,16 +591,6 @@ version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0"
[[package]]
name = "llama_api_client"
version = "0.1.0"
dependencies = [
"reqwest",
"serde",
"serde_json",
"tokio",
]
[[package]]
name = "lock_api"
version = "0.4.14"
@@ -828,6 +818,16 @@ dependencies = [
"bitflags",
]
[[package]]
name = "ren"
version = "0.1.0"
dependencies = [
"reqwest",
"serde",
"serde_json",
"tokio",
]
[[package]]
name = "reqwest"
version = "0.13.4"

View File

@@ -1,5 +1,5 @@
[package]
name = "llama_api_client"
name = "ren"
version = "0.1.0"
edition = "2024"

View File

@@ -1,64 +1,86 @@
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::process::Command;
#[derive(Serialize)]
// ── Message Types ─────────────────────────────────────────────────────────────
//
// A single unified struct covers all four roles (system, user, assistant, tool).
// Optional fields are skipped during serialization so the wire format stays clean.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct ChatMessage {
role: String,
content: String,
}
#[derive(Serialize)]
struct JsonSchemaObject {
name: String,
schema: serde_json::Value,
}
#[derive(Serialize)]
struct ResponseFormat {
#[serde(rename = "type")]
format_type: String, // "json_object" or "json_schema"
#[serde(skip_serializing_if = "Option::is_none")]
json_schema: Option<JsonSchemaObject>,
}
content: Option<String>,
#[derive(Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")] // <-- Add this line
response_format: Option<ResponseFormat>, // <-- Wrap in Option
// Only present on role: "tool" messages
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<Tool>>,
}
tool_call_id: Option<String>,
#[derive(Deserialize, Debug)]
struct Choice {
message: ResponseMessage,
}
#[derive(Deserialize, Debug)]
struct ChatCompletionResponse {
choices: Vec<Choice>,
}
#[derive(Deserialize, Debug)]
struct ResponseMessage {
content: String,
// Only present on role: "assistant" messages that invoke tools
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Deserialize, Debug)]
impl ChatMessage {
fn system(content: &str) -> Self {
Self {
role: "system".into(),
content: Some(content.into()),
tool_call_id: None,
tool_calls: None,
}
}
fn user(content: &str) -> Self {
Self {
role: "user".into(),
content: Some(content.into()),
tool_call_id: None,
tool_calls: None,
}
}
// Captures the assistant turn including any tool_calls it made.
// This must be pushed into history before the tool result messages,
// or most backends will reject the conversation.
fn assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
Self {
role: "assistant".into(),
content,
tool_call_id: None,
tool_calls,
}
}
// One tool result message per tool call, keyed by the call's id.
fn tool_result(tool_call_id: &str, content: &str) -> Self {
Self {
role: "tool".into(),
content: Some(content.into()),
tool_call_id: Some(tool_call_id.into()),
tool_calls: None,
}
}
}
// ── API Wire Types ────────────────────────────────────────────────────────────
// Needs both Serialize + Deserialize: the model returns ToolCalls in its
// response, and we push them straight back into the next request's messages.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct ToolCall {
id: String,
#[serde(rename = "type")]
call_type: String, // Always "function"
call_type: String, // always "function"
function: FunctionCall,
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
struct FunctionCall {
name: String,
arguments: String,
arguments: String, // JSON-encoded string from the model
}
#[derive(Serialize)]
@@ -71,99 +93,261 @@ struct ToolFunction {
#[derive(Serialize)]
struct Tool {
#[serde(rename = "type")]
tool_type: String, // Always "function"
tool_type: String,
function: ToolFunction,
}
use serde_json::json;
use std::process::Command;
#[derive(Serialize)]
struct ChatCompletionRequest<'a> {
model: String,
messages: &'a [ChatMessage],
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<&'a [Tool]>,
}
#[derive(Deserialize, Debug)]
struct ChatCompletionResponse {
choices: Vec<Choice>,
}
#[derive(Deserialize, Debug)]
struct Choice {
message: ResponseMessage,
}
// content is Option<String> because some models return null content
// on assistant turns that contain only tool_calls.
#[derive(Deserialize, Debug)]
struct ResponseMessage {
content: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
}
// ── Tool Registry ─────────────────────────────────────────────────────────────
//
// This will eventually be loaded from a manifest file / plugin system.
// For now it's a plain Vec that gets passed into every LLM request.
fn build_tools() -> Vec<Tool> {
vec![
Tool {
tool_type: "function".into(),
function: ToolFunction {
name: "run_curl".into(),
description: "Execute a network request to a public URL and return the response body.".into(),
parameters: json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The full HTTP/HTTPS URL to query"
},
"method": {
"type": "string",
"enum": ["GET", "POST"],
"default": "GET"
}
},
"required": ["url"]
}),
},
},
Tool {
tool_type: "function".into(),
function: ToolFunction {
name: "run_shell".into(),
description: "Run a single shell command and return stdout. Use for filesystem ops, grepping, listing files, etc.".into(),
parameters: json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute"
}
},
"required": ["command"]
}),
},
},
]
}
// ── Tool Execution ────────────────────────────────────────────────────────────
//
// Each tool returns a String that gets fed back as a tool_result message.
// Truncation is intentional: we don't want one noisy tool call consuming
// the entire remaining context budget.
const MAX_TOOL_OUTPUT: usize = 2_000; // chars
fn truncate(s: String) -> String {
if s.len() > MAX_TOOL_OUTPUT {
format!(
"{}…[truncated {} chars]",
&s[..MAX_TOOL_OUTPUT],
s.len() - MAX_TOOL_OUTPUT
)
} else {
s
}
}
fn execute_tool(call: &ToolCall) -> String {
let args: serde_json::Value =
serde_json::from_str(&call.function.arguments).unwrap_or(json!({}));
match call.function.name.as_str() {
"run_curl" => {
let url = args["url"].as_str().unwrap_or("");
let method = args["method"].as_str().unwrap_or("GET");
println!(" [tool::run_curl] {} {}", method, url);
match Command::new("curl")
.args(["-s", "-X", method, url])
.output()
{
Ok(out) => truncate(String::from_utf8_lossy(&out.stdout).into_owned()),
Err(e) => format!("[error] curl failed: {e}"),
}
}
"run_shell" => {
let command = args["command"].as_str().unwrap_or("");
println!(" [tool::run_shell] {}", command);
match Command::new("sh").args(["-c", command]).output() {
Ok(out) => {
let stdout = String::from_utf8_lossy(&out.stdout).into_owned();
let stderr = String::from_utf8_lossy(&out.stderr).into_owned();
let combined = if stderr.is_empty() {
stdout
} else {
format!("{stdout}\n[stderr]\n{stderr}")
};
truncate(combined)
}
Err(e) => format!("[error] shell failed: {e}"),
}
}
unknown => format!("[error] unknown tool: {unknown}"),
}
}
// ── LLM Client ───────────────────────────────────────────────────────────────
async fn chat(
client: &reqwest::Client,
messages: &[ChatMessage],
tools: &[Tool],
) -> Result<ResponseMessage, Box<dyn std::error::Error>> {
let payload = ChatCompletionRequest {
model: "local-model".to_string(),
messages,
temperature: 0.1,
tools: Some(tools),
};
let response = client
.post("http://localhost:8080/v1/chat/completions")
.json(&payload)
.send()
.await?
.json::<ChatCompletionResponse>()
.await?;
response
.choices
.into_iter()
.next()
.map(|c| c.message)
.ok_or_else(|| "No choices in response".into())
}
// ── Agent Loop ────────────────────────────────────────────────────────────────
//
// Iteration contract:
// 1. Send messages → LLM
// 2. Push the assistant turn into history (always, even if it has tool_calls)
// 3a. If tool_calls present: execute each, push tool_result messages, repeat.
// 3b. If no tool_calls: the content is the final answer, return it.
//
// MAX_ITERATIONS is a hard safety rail. A well-prompted small model should
// resolve most tasks in 2-4 iterations. If we hit the cap, surface the error
// rather than silently returning a partial result.
const MAX_ITERATIONS: usize = 10;
async fn run_agent(
client: &reqwest::Client,
tools: &[Tool],
mut messages: Vec<ChatMessage>,
) -> Result<String, Box<dyn std::error::Error>> {
for i in 0..MAX_ITERATIONS {
println!("\n── iteration {} ──────────────────────────────", i + 1);
let response = chat(client, &messages, tools).await?;
// Step 2: push the full assistant turn before adding tool results.
// Backends that validate conversation structure will reject the request
// if tool_result messages appear without a preceding assistant message
// that contains the matching tool_call ids.
messages.push(ChatMessage::assistant(
response.content.clone(),
response.tool_calls.clone(),
));
match response.tool_calls {
Some(calls) if !calls.is_empty() => {
println!(" {} tool call(s) requested", calls.len());
for call in &calls {
let result = execute_tool(call);
println!("{} chars returned", result.len());
// Each tool result is a separate message, keyed by call.id
messages.push(ChatMessage::tool_result(&call.id, &result));
}
// Continue: LLM will see the results and decide whether to call
// more tools or produce a final answer.
}
_ => {
// No tool calls → this is the terminal response.
let answer = response.content.unwrap_or_default();
println!("\n── done in {} iteration(s) ───────────────────", i + 1);
return Ok(answer);
}
}
}
Err(format!(
"Agent did not converge within {MAX_ITERATIONS} iterations. \
Consider increasing the limit or checking the system prompt."
)
.into())
}
// ── Entry Point ───────────────────────────────────────────────────────────────
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = reqwest::Client::new();
let tools = build_tools();
// Define the tool signature
let curl_tool = Tool {
tool_type: "function".to_string(),
function: ToolFunction {
name: "run_curl".to_string(),
description:
"Executes a curl network request command to fetch data from a public URL endpoint"
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"url": { "type": "string", "description": "The full HTTP/HTTPS URL endpoint to query" },
"method": { "type": "string", "enum": ["GET", "POST"], "default": "GET" }
},
"required": ["url"]
}),
},
};
let messages = vec![
ChatMessage::system(
"You are a concise system assistant. Use the provided tools to gather \
real-time information when needed. Do not guess at live data.",
),
ChatMessage::user(
"Check if wttr.in/Delhi is accessible and give me a one-line weather summary.",
),
];
let payload = ChatCompletionRequest {
model: "local-model".to_string(),
messages: vec![
ChatMessage {
role: "system".to_string(),
content: "You are a helpful system assistant. You have access to a run_curl tool. Use it whenever you want to real-time web information or status codes.".to_string(),
},
ChatMessage {
role: "user".to_string(),
content: "Check if wttr.in/Delhi is accessible right now using a GET request.".to_string(),
},
],
temperature: 0.1,
response_format: None, // <-- Pass None here since we are using tools right now
tools: Some(vec![curl_tool]),
};
let http_response = client
.post("http://localhost:8080/v1/chat/completions")
.json(&payload)
.send()
.await?;
let response: ChatCompletionResponse = http_response.json().await?;
let message = &response
.choices
.first()
.ok_or("No choices returned")?
.message;
if let Some(tool_calls) = &message.tool_calls {
for call in tool_calls {
if call.function.name == "run_curl" {
// Parse the arguments back into JSON
let args: serde_json::Value = serde_json::from_str(&call.function.arguments)?;
let url = args["url"].as_str().unwrap_or("");
let method = args["method"].as_str().unwrap_or("GET");
println!(
"\n[Assistant Triggered Tool]: Executing curl {} to {}",
method, url
);
// Run the actual system shell command safely via standard libraries
let output = Command::new("curl")
.arg("-s") // Silent mode
.arg("-X")
.arg(method)
.arg(url)
.output()?;
let result_string = String::from_utf8_lossy(&output.stdout);
println!(
"\n[System Execution Output]:\n{}",
result_string.lines().take(5).collect::<Vec<_>>().join("\n")
);
println!("... (truncated)");
}
}
} else if !message.content.is_empty() {
println!("\nAI Response: {}", message.content);
match run_agent(&client, &tools, messages).await {
Ok(answer) => println!("\n── final answer ──\n{answer}"),
Err(e) => eprintln!("\n── agent error ──\n{e}"),
}
Ok(())