// Source: ren/src/main.rs (Rust; listed as 355 lines, 12 KiB in the repository file viewer)

use serde::{Deserialize, Serialize};
use serde_json::json;
use std::process::Command;
// ── Message Types ─────────────────────────────────────────────────────────────
//
// A single unified struct covers all four roles (system, user, assistant, tool).
// Optional fields are skipped during serialization so the wire format stays clean.
// One entry in the conversation history that is serialized into the request's
// `messages` array. Every field except `role` is optional and is left out of the
// serialized JSON entirely when it is `None`.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct ChatMessage {
// Role tag. The constructors in `impl ChatMessage` set it to "system", "user",
// "assistant", or "tool".
role: String,
// Message text. Optional because the assistant constructor accepts `None`.
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
// Only present on role: "tool" messages
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
// Only present on role: "assistant" messages that invoke tools
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
}
impl ChatMessage {
fn system(content: &str) -> Self {
Self {
role: "system".into(),
content: Some(content.into()),
tool_call_id: None,
tool_calls: None,
}
}
fn user(content: &str) -> Self {
Self {
role: "user".into(),
content: Some(content.into()),
tool_call_id: None,
tool_calls: None,
}
}
// Captures the assistant turn including any tool_calls it made.
// This must be pushed into history before the tool result messages,
// or most backends will reject the conversation.
fn assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
Self {
role: "assistant".into(),
content,
tool_call_id: None,
tool_calls,
}
}
// One tool result message per tool call, keyed by the call's id.
fn tool_result(tool_call_id: &str, content: &str) -> Self {
Self {
role: "tool".into(),
content: Some(content.into()),
tool_call_id: Some(tool_call_id.into()),
tool_calls: None,
}
}
}
// ── API Wire Types ────────────────────────────────────────────────────────────
// Needs both Serialize + Deserialize: the model returns ToolCalls in its
// response, and we push them straight back into the next request's messages.
// A single tool invocation requested by the model.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct ToolCall {
// Identifier of this call. The agent loop passes it to `ChatMessage::tool_result`
// so the result message carries it as `tool_call_id`.
id: String,
// Stored as `call_type` because `type` is a Rust keyword; the JSON key is "type".
#[serde(rename = "type")]
call_type: String, // always "function"
// Which function to run and the arguments to run it with.
function: FunctionCall,
}
// Name and raw arguments of the function a `ToolCall` refers to.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct FunctionCall {
// Tool name; `execute_tool` matches on it to pick the implementation.
name: String,
// Kept as the raw string; `execute_tool` parses it into a `serde_json::Value`.
arguments: String, // JSON-encoded string from the model
}
// Serialize-only description of one function offered to the model.
#[derive(Serialize)]
struct ToolFunction {
// Name the model uses to request this function.
name: String,
// Free-text explanation of what the function does.
description: String,
// Arbitrary JSON describing the accepted arguments. `build_tools` fills this with
// JSON-Schema-style objects ("type", "properties", "required").
parameters: serde_json::Value,
}
// Serialize-only envelope around a `ToolFunction`, as placed in the request's
// `tools` array.
#[derive(Serialize)]
struct Tool {
// Stored as `tool_type` because `type` is a Rust keyword; the JSON key is "type".
// `build_tools` always sets it to "function".
#[serde(rename = "type")]
tool_type: String,
// The function this entry describes.
function: ToolFunction,
}
// JSON body that `chat` posts to the chat-completions endpoint. Messages and tools
// are borrowed for the lifetime 'a, so building a request does not copy them.
#[derive(Serialize)]
struct ChatCompletionRequest<'a> {
// Model identifier sent to the backend.
model: String,
// Conversation history, in order.
messages: &'a [ChatMessage],
// Sampling temperature forwarded as-is.
temperature: f32,
// Tool definitions; the key is omitted from the JSON when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<&'a [Tool]>,
}
// The part of the chat-completions response body that this program reads.
#[derive(Deserialize, Debug)]
struct ChatCompletionResponse {
// Candidate completions; `chat` uses only the first entry and errors if there is none.
choices: Vec<Choice>,
}
// One candidate completion inside `ChatCompletionResponse::choices`.
#[derive(Deserialize, Debug)]
struct Choice {
// The assistant message produced for this candidate.
message: ResponseMessage,
}
// Assistant message as returned by the backend.
// content is Option<String> because some models return null content
// on assistant turns that contain only tool_calls.
#[derive(Deserialize, Debug)]
struct ResponseMessage {
// Text of the reply; `run_agent` returns it (or an empty string) as the final answer.
content: Option<String>,
// Tools the model wants executed. `run_agent` treats `None` and an empty list
// the same way: as a turn with no tool calls.
tool_calls: Option<Vec<ToolCall>>,
}
// ── Tool Registry ─────────────────────────────────────────────────────────────
//
// This will eventually be loaded from a manifest file / plugin system.
// For now it's a plain Vec that gets passed into every LLM request.
fn build_tools() -> Vec<Tool> {
vec![
Tool {
tool_type: "function".into(),
function: ToolFunction {
name: "run_curl".into(),
description: "Execute a network request to a public URL and return the response body.".into(),
parameters: json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The full HTTP/HTTPS URL to query"
},
"method": {
"type": "string",
"enum": ["GET", "POST"],
"default": "GET"
}
},
"required": ["url"]
}),
},
},
Tool {
tool_type: "function".into(),
function: ToolFunction {
name: "run_shell".into(),
description: "Run a single shell command and return stdout. Use for filesystem ops, grepping, listing files, etc.".into(),
parameters: json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute"
}
},
"required": ["command"]
}),
},
},
]
}
// ── Tool Execution ────────────────────────────────────────────────────────────
//
// Each tool returns a String that gets fed back as a tool_result message.
// Truncation is intentional: we don't want one noisy tool call consuming
// the entire remaining context budget.
// Maximum number of characters (Unicode scalar values) of tool output that is
// passed back to the model.
const MAX_TOOL_OUTPUT: usize = 2_000; // chars

// Caps `s` at MAX_TOOL_OUTPUT characters.
//
// Strings within the limit are returned unchanged. Longer strings are cut after
// the MAX_TOOL_OUTPUT-th character and suffixed with `…[truncated N chars]`,
// where N is the number of characters removed.
//
// The cut is located by character, not by byte offset: slicing a `String` at an
// arbitrary byte index panics when that index falls inside a multi-byte UTF-8
// sequence, which tool output (web pages, localized CLI messages) readily contains.
fn truncate(s: String) -> String {
    // Fast path: every char occupies at least one byte, so a string of at most
    // MAX_TOOL_OUTPUT bytes cannot exceed MAX_TOOL_OUTPUT chars.
    if s.len() <= MAX_TOOL_OUTPUT {
        return s;
    }
    // Byte offset of the first character past the limit, if there is one.
    match s.char_indices().nth(MAX_TOOL_OUTPUT) {
        Some((cut, _)) => {
            let dropped = s[cut..].chars().count();
            format!("{}…[truncated {} chars]", &s[..cut], dropped)
        }
        // More than MAX_TOOL_OUTPUT bytes but not more than MAX_TOOL_OUTPUT chars.
        None => s,
    }
}
// Upper bound, in seconds, that curl may spend on one request (`--max-time`).
// Without it a stalled server would block the agent loop indefinitely.
const CURL_MAX_TIME_SECS: &str = "30";

// Runs the tool named in `call` and returns the text to hand back to the model.
//
// Failures are reported in-band as `[error] …` strings, never as a panic or `Err`,
// so the model can read what went wrong and retry with corrected arguments.
//
// NOTE: both tools wait for a child process via `Command::output`, which blocks
// the calling thread until the child exits.
fn execute_tool(call: &ToolCall) -> String {
    // Report malformed arguments and stop. Falling back to an empty object would
    // run the tool with an empty URL or command and hide the real problem.
    let args: serde_json::Value = match serde_json::from_str(&call.function.arguments) {
        Ok(value) => value,
        Err(e) => return format!("[error] invalid tool arguments (expected a JSON object): {e}"),
    };
    match call.function.name.as_str() {
        "run_curl" => exec_curl(&args),
        "run_shell" => exec_shell(&args),
        unknown => format!("[error] unknown tool: {unknown}"),
    }
}

// Implements `run_curl`: fetches `args["url"]` with the method in `args["method"]`
// (default "GET") by spawning the `curl` binary.
fn exec_curl(args: &serde_json::Value) -> String {
    let url = args["url"].as_str().unwrap_or("").trim();
    let method = args["method"].as_str().unwrap_or("GET");
    println!(" [tool::run_curl] {} {}", method, url);

    // The URL is model-controlled. Requiring an explicit http(s) scheme keeps it
    // from being parsed as a curl option (leading '-') and rules out schemes such
    // as file:// that would read local files.
    let lowered = url.to_ascii_lowercase();
    if !(lowered.starts_with("http://") || lowered.starts_with("https://")) {
        return format!("[error] run_curl requires a full http:// or https:// URL, got {url:?}");
    }
    // Enforce the enum advertised in the tool schema.
    if !matches!(method, "GET" | "POST") {
        return format!("[error] run_curl supports only GET or POST, got {method:?}");
    }

    // -sS: no progress meter, but still print errors to stderr.
    // "--": everything after it is treated as a URL, never as an option.
    match Command::new("curl")
        .args(["-sS", "--max-time", CURL_MAX_TIME_SECS, "-X", method, "--", url])
        .output()
    {
        Ok(out) => render_process_output(&out),
        Err(e) => format!("[error] curl failed: {e}"),
    }
}

// Implements `run_shell`: runs `args["command"]` through `sh -c`.
//
// SECURITY: this executes model-chosen text with the privileges of this process.
// That is the purpose of the tool, but it means the agent must only be run in an
// environment where arbitrary command execution is acceptable.
fn exec_shell(args: &serde_json::Value) -> String {
    let command = args["command"].as_str().unwrap_or("");
    println!(" [tool::run_shell] {}", command);

    if command.trim().is_empty() {
        return String::from("[error] run_shell requires a non-empty \"command\" argument");
    }

    match Command::new("sh").args(["-c", command]).output() {
        Ok(out) => render_process_output(&out),
        Err(e) => format!("[error] shell failed: {e}"),
    }
}

// Formats a finished child process for the model: stdout, then stderr under a
// `[stderr]` marker when non-empty, truncated to the tool-output budget. A failing
// exit status is appended after truncation so it is never the part that is cut.
fn render_process_output(out: &std::process::Output) -> String {
    let mut body = String::from_utf8_lossy(&out.stdout).into_owned();
    let stderr = String::from_utf8_lossy(&out.stderr);
    if !stderr.is_empty() {
        body.push_str("\n[stderr]\n");
        body.push_str(&stderr);
    }
    let mut rendered = truncate(body);
    if !out.status.success() {
        rendered.push_str(&format!("\n[{}]", out.status));
    }
    rendered
}
// ── LLM Client ───────────────────────────────────────────────────────────────
// Sends one chat-completions request and returns the first choice's message.
//
// `messages` is the full conversation so far; `tools` are the definitions offered
// to the model (the `tools` key is omitted when the slice is empty).
//
// Errors:
// - transport failures from `send`,
// - any 4xx/5xx HTTP status (surfaced as such, not as a JSON decode error on
//   the error body),
// - a body that does not deserialize as `ChatCompletionResponse`,
// - a response whose `choices` array is empty.
async fn chat(
    client: &reqwest::Client,
    messages: &[ChatMessage],
    tools: &[Tool],
) -> Result<ResponseMessage, Box<dyn std::error::Error>> {
    let payload = ChatCompletionRequest {
        model: "local-model".to_string(),
        messages,
        temperature: 0.1,
        // Leave the key out entirely when no tools are offered; `Some(&[])` would
        // serialize as `"tools": []`.
        tools: if tools.is_empty() { None } else { Some(tools) },
    };
    let response = client
        .post("http://localhost:8080/v1/chat/completions")
        .json(&payload)
        .send()
        .await?
        // Turn HTTP error statuses into an error before attempting to decode.
        .error_for_status()?
        .json::<ChatCompletionResponse>()
        .await?;
    response
        .choices
        .into_iter()
        .next()
        .map(|c| c.message)
        .ok_or_else(|| "No choices in response".into())
}
// ── Agent Loop ────────────────────────────────────────────────────────────────
//
// Iteration contract:
// 1. Send messages → LLM
// 2. Push the assistant turn into history (always, even if it has tool_calls)
// 3a. If tool_calls present: execute each, push tool_result messages, repeat.
// 3b. If no tool_calls: the content is the final answer, return it.
//
// MAX_ITERATIONS is a hard safety rail. A well-prompted small model should
// resolve most tasks in 2-4 iterations. If we hit the cap, surface the error
// rather than silently returning a partial result.
const MAX_ITERATIONS: usize = 10;
async fn run_agent(
client: &reqwest::Client,
tools: &[Tool],
mut messages: Vec<ChatMessage>,
) -> Result<String, Box<dyn std::error::Error>> {
for i in 0..MAX_ITERATIONS {
println!("\n── iteration {} ──────────────────────────────", i + 1);
let response = chat(client, &messages, tools).await?;
// Step 2: push the full assistant turn before adding tool results.
// Backends that validate conversation structure will reject the request
// if tool_result messages appear without a preceding assistant message
// that contains the matching tool_call ids.
messages.push(ChatMessage::assistant(
response.content.clone(),
response.tool_calls.clone(),
));
match response.tool_calls {
Some(calls) if !calls.is_empty() => {
println!(" {} tool call(s) requested", calls.len());
for call in &calls {
let result = execute_tool(call);
println!("{} chars returned", result.len());
// Each tool result is a separate message, keyed by call.id
messages.push(ChatMessage::tool_result(&call.id, &result));
}
// Continue: LLM will see the results and decide whether to call
// more tools or produce a final answer.
}
_ => {
// No tool calls → this is the terminal response.
let answer = response.content.unwrap_or_default();
println!("\n── done in {} iteration(s) ───────────────────", i + 1);
return Ok(answer);
}
}
}
Err(format!(
"Agent did not converge within {MAX_ITERATIONS} iterations. \
Consider increasing the limit or checking the system prompt."
)
.into())
}
// ── Entry Point ───────────────────────────────────────────────────────────────
// Builds the HTTP client, tool list, and seed conversation, runs the agent once,
// and prints its answer.
//
// On agent failure the error is written to stderr and the process exits with
// status 1, so shells and scripts can tell a failed run from a successful one.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    let tools = build_tools();
    let messages = vec![
        ChatMessage::system(
            "You are a concise system assistant. Use the provided tools to gather \
             real-time information when needed. Do not guess at live data.",
        ),
        ChatMessage::user(
            "Check if wttr.in/Delhi is accessible and give me a one-line weather summary.",
        ),
    ];
    match run_agent(&client, &tools, messages).await {
        Ok(answer) => {
            println!("\n── final answer ──\n{answer}");
            Ok(())
        }
        Err(e) => {
            eprintln!("\n── agent error ──\n{e}");
            // Exit explicitly: returning `Err` would make the runtime print the
            // error a second time in Debug form.
            std::process::exit(1);
        }
    }
}