diff --git a/Cargo.lock b/Cargo.lock index 32aa1c7..42b0ab1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -591,16 +591,6 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" -[[package]] -name = "llama_api_client" -version = "0.1.0" -dependencies = [ - "reqwest", - "serde", - "serde_json", - "tokio", -] - [[package]] name = "lock_api" version = "0.4.14" @@ -828,6 +818,16 @@ dependencies = [ "bitflags", ] +[[package]] +name = "ren" +version = "0.1.0" +dependencies = [ + "reqwest", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "reqwest" version = "0.13.4" diff --git a/Cargo.toml b/Cargo.toml index 7cd8a4f..18c94b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "llama_api_client" +name = "ren" version = "0.1.0" edition = "2024" diff --git a/src/main.rs b/src/main.rs index 51b0ebf..a3888fe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,64 +1,86 @@ use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::process::Command; -#[derive(Serialize)] +// ── Message Types ───────────────────────────────────────────────────────────── +// +// A single unified struct covers all four roles (system, user, assistant, tool). +// Optional fields are skipped during serialization so the wire format stays clean. + +#[derive(Serialize, Deserialize, Debug, Clone)] struct ChatMessage { role: String, - content: String, -} -#[derive(Serialize)] -struct JsonSchemaObject { - name: String, - schema: serde_json::Value, -} - -#[derive(Serialize)] -struct ResponseFormat { - #[serde(rename = "type")] - format_type: String, // "json_object" or "json_schema" #[serde(skip_serializing_if = "Option::is_none")] - json_schema: Option, -} + content: Option, -#[derive(Serialize)] -struct ChatCompletionRequest { - model: String, - messages: Vec, - temperature: f32, - #[serde(skip_serializing_if = "Option::is_none")] // <-- Add this line - response_format: Option, // <-- Wrap in Option + // Only present on role: "tool" messages #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, -} + tool_call_id: Option, -#[derive(Deserialize, Debug)] -struct Choice { - message: ResponseMessage, -} - -#[derive(Deserialize, Debug)] -struct ChatCompletionResponse { - choices: Vec, -} - -#[derive(Deserialize, Debug)] -struct ResponseMessage { - content: String, + // Only present on role: "assistant" messages that invoke tools + #[serde(skip_serializing_if = "Option::is_none")] tool_calls: Option>, } -#[derive(Deserialize, Debug)] +impl ChatMessage { + fn system(content: &str) -> Self { + Self { + role: "system".into(), + content: Some(content.into()), + tool_call_id: None, + tool_calls: None, + } + } + + fn user(content: &str) -> Self { + Self { + role: "user".into(), + content: Some(content.into()), + tool_call_id: None, + tool_calls: None, + } + } + + // Captures the assistant turn including any tool_calls it made. + // This must be pushed into history before the tool result messages, + // or most backends will reject the conversation. + fn assistant(content: Option, tool_calls: Option>) -> Self { + Self { + role: "assistant".into(), + content, + tool_call_id: None, + tool_calls, + } + } + + // One tool result message per tool call, keyed by the call's id. + fn tool_result(tool_call_id: &str, content: &str) -> Self { + Self { + role: "tool".into(), + content: Some(content.into()), + tool_call_id: Some(tool_call_id.into()), + tool_calls: None, + } + } +} + +// ── API Wire Types ──────────────────────────────────────────────────────────── + +// Needs both Serialize + Deserialize: the model returns ToolCalls in its +// response, and we push them straight back into the next request's messages. +#[derive(Serialize, Deserialize, Debug, Clone)] struct ToolCall { id: String, #[serde(rename = "type")] - call_type: String, // Always "function" + call_type: String, // always "function" function: FunctionCall, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] struct FunctionCall { name: String, - arguments: String, + arguments: String, // JSON-encoded string from the model } #[derive(Serialize)] @@ -71,99 +93,261 @@ struct ToolFunction { #[derive(Serialize)] struct Tool { #[serde(rename = "type")] - tool_type: String, // Always "function" + tool_type: String, function: ToolFunction, } -use serde_json::json; -use std::process::Command; +#[derive(Serialize)] +struct ChatCompletionRequest<'a> { + model: String, + messages: &'a [ChatMessage], + temperature: f32, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option<&'a [Tool]>, +} + +#[derive(Deserialize, Debug)] +struct ChatCompletionResponse { + choices: Vec, +} + +#[derive(Deserialize, Debug)] +struct Choice { + message: ResponseMessage, +} + +// content is Option because some models return null content +// on assistant turns that contain only tool_calls. +#[derive(Deserialize, Debug)] +struct ResponseMessage { + content: Option, + tool_calls: Option>, +} + +// ── Tool Registry ───────────────────────────────────────────────────────────── +// +// This will eventually be loaded from a manifest file / plugin system. +// For now it's a plain Vec that gets passed into every LLM request. + +fn build_tools() -> Vec { + vec![ + Tool { + tool_type: "function".into(), + function: ToolFunction { + name: "run_curl".into(), + description: "Execute a network request to a public URL and return the response body.".into(), + parameters: json!({ + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The full HTTP/HTTPS URL to query" + }, + "method": { + "type": "string", + "enum": ["GET", "POST"], + "default": "GET" + } + }, + "required": ["url"] + }), + }, + }, + Tool { + tool_type: "function".into(), + function: ToolFunction { + name: "run_shell".into(), + description: "Run a single shell command and return stdout. Use for filesystem ops, grepping, listing files, etc.".into(), + parameters: json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute" + } + }, + "required": ["command"] + }), + }, + }, + ] +} + +// ── Tool Execution ──────────────────────────────────────────────────────────── +// +// Each tool returns a String that gets fed back as a tool_result message. +// Truncation is intentional: we don't want one noisy tool call consuming +// the entire remaining context budget. + +const MAX_TOOL_OUTPUT: usize = 2_000; // chars + +fn truncate(s: String) -> String { + if s.len() > MAX_TOOL_OUTPUT { + format!( + "{}…[truncated {} chars]", + &s[..MAX_TOOL_OUTPUT], + s.len() - MAX_TOOL_OUTPUT + ) + } else { + s + } +} + +fn execute_tool(call: &ToolCall) -> String { + let args: serde_json::Value = + serde_json::from_str(&call.function.arguments).unwrap_or(json!({})); + + match call.function.name.as_str() { + "run_curl" => { + let url = args["url"].as_str().unwrap_or(""); + let method = args["method"].as_str().unwrap_or("GET"); + + println!(" [tool::run_curl] {} {}", method, url); + + match Command::new("curl") + .args(["-s", "-X", method, url]) + .output() + { + Ok(out) => truncate(String::from_utf8_lossy(&out.stdout).into_owned()), + Err(e) => format!("[error] curl failed: {e}"), + } + } + + "run_shell" => { + let command = args["command"].as_str().unwrap_or(""); + + println!(" [tool::run_shell] {}", command); + + match Command::new("sh").args(["-c", command]).output() { + Ok(out) => { + let stdout = String::from_utf8_lossy(&out.stdout).into_owned(); + let stderr = String::from_utf8_lossy(&out.stderr).into_owned(); + let combined = if stderr.is_empty() { + stdout + } else { + format!("{stdout}\n[stderr]\n{stderr}") + }; + truncate(combined) + } + Err(e) => format!("[error] shell failed: {e}"), + } + } + + unknown => format!("[error] unknown tool: {unknown}"), + } +} + +// ── LLM Client ─────────────────────────────────────────────────────────────── + +async fn chat( + client: &reqwest::Client, + messages: &[ChatMessage], + tools: &[Tool], +) -> Result> { + let payload = ChatCompletionRequest { + model: "local-model".to_string(), + messages, + temperature: 0.1, + tools: Some(tools), + }; + + let response = client + .post("http://localhost:8080/v1/chat/completions") + .json(&payload) + .send() + .await? + .json::() + .await?; + + response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| "No choices in response".into()) +} + +// ── Agent Loop ──────────────────────────────────────────────────────────────── +// +// Iteration contract: +// 1. Send messages → LLM +// 2. Push the assistant turn into history (always, even if it has tool_calls) +// 3a. If tool_calls present: execute each, push tool_result messages, repeat. +// 3b. If no tool_calls: the content is the final answer, return it. +// +// MAX_ITERATIONS is a hard safety rail. A well-prompted small model should +// resolve most tasks in 2-4 iterations. If we hit the cap, surface the error +// rather than silently returning a partial result. + +const MAX_ITERATIONS: usize = 10; + +async fn run_agent( + client: &reqwest::Client, + tools: &[Tool], + mut messages: Vec, +) -> Result> { + for i in 0..MAX_ITERATIONS { + println!("\n── iteration {} ──────────────────────────────", i + 1); + + let response = chat(client, &messages, tools).await?; + + // Step 2: push the full assistant turn before adding tool results. + // Backends that validate conversation structure will reject the request + // if tool_result messages appear without a preceding assistant message + // that contains the matching tool_call ids. + messages.push(ChatMessage::assistant( + response.content.clone(), + response.tool_calls.clone(), + )); + + match response.tool_calls { + Some(calls) if !calls.is_empty() => { + println!(" {} tool call(s) requested", calls.len()); + for call in &calls { + let result = execute_tool(call); + println!(" → {} chars returned", result.len()); + // Each tool result is a separate message, keyed by call.id + messages.push(ChatMessage::tool_result(&call.id, &result)); + } + // Continue: LLM will see the results and decide whether to call + // more tools or produce a final answer. + } + _ => { + // No tool calls → this is the terminal response. + let answer = response.content.unwrap_or_default(); + println!("\n── done in {} iteration(s) ───────────────────", i + 1); + return Ok(answer); + } + } + } + + Err(format!( + "Agent did not converge within {MAX_ITERATIONS} iterations. \ + Consider increasing the limit or checking the system prompt." + ) + .into()) +} + +// ── Entry Point ─────────────────────────────────────────────────────────────── #[tokio::main] async fn main() -> Result<(), Box> { let client = reqwest::Client::new(); + let tools = build_tools(); - // Define the tool signature - let curl_tool = Tool { - tool_type: "function".to_string(), - function: ToolFunction { - name: "run_curl".to_string(), - description: - "Executes a curl network request command to fetch data from a public URL endpoint" - .to_string(), - parameters: json!({ - "type": "object", - "properties": { - "url": { "type": "string", "description": "The full HTTP/HTTPS URL endpoint to query" }, - "method": { "type": "string", "enum": ["GET", "POST"], "default": "GET" } - }, - "required": ["url"] - }), - }, - }; + let messages = vec![ + ChatMessage::system( + "You are a concise system assistant. Use the provided tools to gather \ + real-time information when needed. Do not guess at live data.", + ), + ChatMessage::user( + "Check if wttr.in/Delhi is accessible and give me a one-line weather summary.", + ), + ]; - let payload = ChatCompletionRequest { - model: "local-model".to_string(), - messages: vec![ - ChatMessage { - role: "system".to_string(), - content: "You are a helpful system assistant. You have access to a run_curl tool. Use it whenever you want to real-time web information or status codes.".to_string(), - }, - ChatMessage { - role: "user".to_string(), - content: "Check if wttr.in/Delhi is accessible right now using a GET request.".to_string(), - }, - ], - temperature: 0.1, - response_format: None, // <-- Pass None here since we are using tools right now - tools: Some(vec![curl_tool]), - }; - - let http_response = client - .post("http://localhost:8080/v1/chat/completions") - .json(&payload) - .send() - .await?; - - let response: ChatCompletionResponse = http_response.json().await?; - - let message = &response - .choices - .first() - .ok_or("No choices returned")? - .message; - - if let Some(tool_calls) = &message.tool_calls { - for call in tool_calls { - if call.function.name == "run_curl" { - // Parse the arguments back into JSON - let args: serde_json::Value = serde_json::from_str(&call.function.arguments)?; - let url = args["url"].as_str().unwrap_or(""); - let method = args["method"].as_str().unwrap_or("GET"); - - println!( - "\n[Assistant Triggered Tool]: Executing curl {} to {}", - method, url - ); - - // Run the actual system shell command safely via standard libraries - let output = Command::new("curl") - .arg("-s") // Silent mode - .arg("-X") - .arg(method) - .arg(url) - .output()?; - - let result_string = String::from_utf8_lossy(&output.stdout); - - println!( - "\n[System Execution Output]:\n{}", - result_string.lines().take(5).collect::>().join("\n") - ); - println!("... (truncated)"); - } - } - } else if !message.content.is_empty() { - println!("\nAI Response: {}", message.content); + match run_agent(&client, &tools, messages).await { + Ok(answer) => println!("\n── final answer ──\n{answer}"), + Err(e) => eprintln!("\n── agent error ──\n{e}"), } Ok(())