355 lines
12 KiB
Rust
355 lines
12 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
use serde_json::json;
|
|
use std::process::Command;
|
|
|
|
// ── Message Types ─────────────────────────────────────────────────────────────
|
|
//
|
|
// A single unified struct covers all four roles (system, user, assistant, tool).
|
|
// Optional fields are skipped during serialization so the wire format stays clean.
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
struct ChatMessage {
|
|
role: String,
|
|
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
content: Option<String>,
|
|
|
|
// Only present on role: "tool" messages
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tool_call_id: Option<String>,
|
|
|
|
// Only present on role: "assistant" messages that invoke tools
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tool_calls: Option<Vec<ToolCall>>,
|
|
}
|
|
|
|
impl ChatMessage {
|
|
fn system(content: &str) -> Self {
|
|
Self {
|
|
role: "system".into(),
|
|
content: Some(content.into()),
|
|
tool_call_id: None,
|
|
tool_calls: None,
|
|
}
|
|
}
|
|
|
|
fn user(content: &str) -> Self {
|
|
Self {
|
|
role: "user".into(),
|
|
content: Some(content.into()),
|
|
tool_call_id: None,
|
|
tool_calls: None,
|
|
}
|
|
}
|
|
|
|
// Captures the assistant turn including any tool_calls it made.
|
|
// This must be pushed into history before the tool result messages,
|
|
// or most backends will reject the conversation.
|
|
fn assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
|
|
Self {
|
|
role: "assistant".into(),
|
|
content,
|
|
tool_call_id: None,
|
|
tool_calls,
|
|
}
|
|
}
|
|
|
|
// One tool result message per tool call, keyed by the call's id.
|
|
fn tool_result(tool_call_id: &str, content: &str) -> Self {
|
|
Self {
|
|
role: "tool".into(),
|
|
content: Some(content.into()),
|
|
tool_call_id: Some(tool_call_id.into()),
|
|
tool_calls: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── API Wire Types ────────────────────────────────────────────────────────────
|
|
|
|
// Needs both Serialize + Deserialize: the model returns ToolCalls in its
|
|
// response, and we push them straight back into the next request's messages.
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
struct ToolCall {
|
|
id: String,
|
|
#[serde(rename = "type")]
|
|
call_type: String, // always "function"
|
|
function: FunctionCall,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
struct FunctionCall {
|
|
name: String,
|
|
arguments: String, // JSON-encoded string from the model
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ToolFunction {
|
|
name: String,
|
|
description: String,
|
|
parameters: serde_json::Value,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct Tool {
|
|
#[serde(rename = "type")]
|
|
tool_type: String,
|
|
function: ToolFunction,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ChatCompletionRequest<'a> {
|
|
model: String,
|
|
messages: &'a [ChatMessage],
|
|
temperature: f32,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tools: Option<&'a [Tool]>,
|
|
}
|
|
|
|
#[derive(Deserialize, Debug)]
|
|
struct ChatCompletionResponse {
|
|
choices: Vec<Choice>,
|
|
}
|
|
|
|
#[derive(Deserialize, Debug)]
|
|
struct Choice {
|
|
message: ResponseMessage,
|
|
}
|
|
|
|
// content is Option<String> because some models return null content
|
|
// on assistant turns that contain only tool_calls.
|
|
#[derive(Deserialize, Debug)]
|
|
struct ResponseMessage {
|
|
content: Option<String>,
|
|
tool_calls: Option<Vec<ToolCall>>,
|
|
}
|
|
|
|
// ── Tool Registry ─────────────────────────────────────────────────────────────
|
|
//
|
|
// This will eventually be loaded from a manifest file / plugin system.
|
|
// For now it's a plain Vec that gets passed into every LLM request.
|
|
|
|
fn build_tools() -> Vec<Tool> {
|
|
vec![
|
|
Tool {
|
|
tool_type: "function".into(),
|
|
function: ToolFunction {
|
|
name: "run_curl".into(),
|
|
description: "Execute a network request to a public URL and return the response body.".into(),
|
|
parameters: json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"url": {
|
|
"type": "string",
|
|
"description": "The full HTTP/HTTPS URL to query"
|
|
},
|
|
"method": {
|
|
"type": "string",
|
|
"enum": ["GET", "POST"],
|
|
"default": "GET"
|
|
}
|
|
},
|
|
"required": ["url"]
|
|
}),
|
|
},
|
|
},
|
|
Tool {
|
|
tool_type: "function".into(),
|
|
function: ToolFunction {
|
|
name: "run_shell".into(),
|
|
description: "Run a single shell command and return stdout. Use for filesystem ops, grepping, listing files, etc.".into(),
|
|
parameters: json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"command": {
|
|
"type": "string",
|
|
"description": "The shell command to execute"
|
|
}
|
|
},
|
|
"required": ["command"]
|
|
}),
|
|
},
|
|
},
|
|
]
|
|
}
|
|
|
|
// ── Tool Execution ────────────────────────────────────────────────────────────
|
|
//
|
|
// Each tool returns a String that gets fed back as a tool_result message.
|
|
// Truncation is intentional: we don't want one noisy tool call consuming
|
|
// the entire remaining context budget.
|
|
|
|
/// Upper bound on tool output fed back to the model, in characters
/// (Unicode scalar values), not bytes.
const MAX_TOOL_OUTPUT: usize = 2_000;

/// Caps `s` at [`MAX_TOOL_OUTPUT`] characters.
///
/// Output within the limit is returned untouched. Longer output keeps its
/// first `MAX_TOOL_OUTPUT` characters and gains a `…[truncated N chars]`
/// suffix, where `N` is the number of characters dropped.
///
/// The cut is always placed on a character boundary. Slicing at a fixed byte
/// offset panics when that offset lands inside a multi-byte UTF-8 sequence,
/// which tool output (web pages, localized shell messages) can easily cause.
fn truncate(s: String) -> String {
    // Byte offset of the first character beyond the limit; `None` means `s` fits.
    let cut = s
        .char_indices()
        .nth(MAX_TOOL_OUTPUT)
        .map(|(offset, _)| offset);

    match cut {
        None => s,
        Some(offset) => {
            let dropped = s[offset..].chars().count();
            format!("{}…[truncated {} chars]", &s[..offset], dropped)
        }
    }
}
|
|
|
|
fn execute_tool(call: &ToolCall) -> String {
|
|
let args: serde_json::Value =
|
|
serde_json::from_str(&call.function.arguments).unwrap_or(json!({}));
|
|
|
|
match call.function.name.as_str() {
|
|
"run_curl" => {
|
|
let url = args["url"].as_str().unwrap_or("");
|
|
let method = args["method"].as_str().unwrap_or("GET");
|
|
|
|
println!(" [tool::run_curl] {} {}", method, url);
|
|
|
|
match Command::new("curl")
|
|
.args(["-s", "-X", method, url])
|
|
.output()
|
|
{
|
|
Ok(out) => truncate(String::from_utf8_lossy(&out.stdout).into_owned()),
|
|
Err(e) => format!("[error] curl failed: {e}"),
|
|
}
|
|
}
|
|
|
|
"run_shell" => {
|
|
let command = args["command"].as_str().unwrap_or("");
|
|
|
|
println!(" [tool::run_shell] {}", command);
|
|
|
|
match Command::new("sh").args(["-c", command]).output() {
|
|
Ok(out) => {
|
|
let stdout = String::from_utf8_lossy(&out.stdout).into_owned();
|
|
let stderr = String::from_utf8_lossy(&out.stderr).into_owned();
|
|
let combined = if stderr.is_empty() {
|
|
stdout
|
|
} else {
|
|
format!("{stdout}\n[stderr]\n{stderr}")
|
|
};
|
|
truncate(combined)
|
|
}
|
|
Err(e) => format!("[error] shell failed: {e}"),
|
|
}
|
|
}
|
|
|
|
unknown => format!("[error] unknown tool: {unknown}"),
|
|
}
|
|
}
|
|
|
|
// ── LLM Client ───────────────────────────────────────────────────────────────
|
|
|
|
async fn chat(
|
|
client: &reqwest::Client,
|
|
messages: &[ChatMessage],
|
|
tools: &[Tool],
|
|
) -> Result<ResponseMessage, Box<dyn std::error::Error>> {
|
|
let payload = ChatCompletionRequest {
|
|
model: "local-model".to_string(),
|
|
messages,
|
|
temperature: 0.1,
|
|
tools: Some(tools),
|
|
};
|
|
|
|
let response = client
|
|
.post("http://localhost:8080/v1/chat/completions")
|
|
.json(&payload)
|
|
.send()
|
|
.await?
|
|
.json::<ChatCompletionResponse>()
|
|
.await?;
|
|
|
|
response
|
|
.choices
|
|
.into_iter()
|
|
.next()
|
|
.map(|c| c.message)
|
|
.ok_or_else(|| "No choices in response".into())
|
|
}
|
|
|
|
// ── Agent Loop ────────────────────────────────────────────────────────────────
|
|
//
|
|
// Iteration contract:
|
|
// 1. Send messages → LLM
|
|
// 2. Push the assistant turn into history (always, even if it has tool_calls)
|
|
// 3a. If tool_calls present: execute each, push tool_result messages, repeat.
|
|
// 3b. If no tool_calls: the content is the final answer, return it.
|
|
//
|
|
// MAX_ITERATIONS is a hard safety rail. A well-prompted small model should
|
|
// resolve most tasks in 2-4 iterations. If we hit the cap, surface the error
|
|
// rather than silently returning a partial result.
|
|
|
|
const MAX_ITERATIONS: usize = 10;
|
|
|
|
async fn run_agent(
|
|
client: &reqwest::Client,
|
|
tools: &[Tool],
|
|
mut messages: Vec<ChatMessage>,
|
|
) -> Result<String, Box<dyn std::error::Error>> {
|
|
for i in 0..MAX_ITERATIONS {
|
|
println!("\n── iteration {} ──────────────────────────────", i + 1);
|
|
|
|
let response = chat(client, &messages, tools).await?;
|
|
|
|
// Step 2: push the full assistant turn before adding tool results.
|
|
// Backends that validate conversation structure will reject the request
|
|
// if tool_result messages appear without a preceding assistant message
|
|
// that contains the matching tool_call ids.
|
|
messages.push(ChatMessage::assistant(
|
|
response.content.clone(),
|
|
response.tool_calls.clone(),
|
|
));
|
|
|
|
match response.tool_calls {
|
|
Some(calls) if !calls.is_empty() => {
|
|
println!(" {} tool call(s) requested", calls.len());
|
|
for call in &calls {
|
|
let result = execute_tool(call);
|
|
println!(" → {} chars returned", result.len());
|
|
// Each tool result is a separate message, keyed by call.id
|
|
messages.push(ChatMessage::tool_result(&call.id, &result));
|
|
}
|
|
// Continue: LLM will see the results and decide whether to call
|
|
// more tools or produce a final answer.
|
|
}
|
|
_ => {
|
|
// No tool calls → this is the terminal response.
|
|
let answer = response.content.unwrap_or_default();
|
|
println!("\n── done in {} iteration(s) ───────────────────", i + 1);
|
|
return Ok(answer);
|
|
}
|
|
}
|
|
}
|
|
|
|
Err(format!(
|
|
"Agent did not converge within {MAX_ITERATIONS} iterations. \
|
|
Consider increasing the limit or checking the system prompt."
|
|
)
|
|
.into())
|
|
}
|
|
|
|
// ── Entry Point ───────────────────────────────────────────────────────────────
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
let client = reqwest::Client::new();
|
|
let tools = build_tools();
|
|
|
|
let messages = vec![
|
|
ChatMessage::system(
|
|
"You are a concise system assistant. Use the provided tools to gather \
|
|
real-time information when needed. Do not guess at live data.",
|
|
),
|
|
ChatMessage::user(
|
|
"Check if wttr.in/Delhi is accessible and give me a one-line weather summary.",
|
|
),
|
|
];
|
|
|
|
match run_agent(&client, &tools, messages).await {
|
|
Ok(answer) => println!("\n── final answer ──\n{answer}"),
|
|
Err(e) => eprintln!("\n── agent error ──\n{e}"),
|
|
}
|
|
|
|
Ok(())
|
|
}
|