init commit
This commit is contained in:
170
src/main.rs
Normal file
170
src/main.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct JsonSchemaObject {
|
||||
name: String,
|
||||
schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ResponseFormat {
|
||||
#[serde(rename = "type")]
|
||||
format_type: String, // "json_object" or "json_schema"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
json_schema: Option<JsonSchemaObject>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
temperature: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] // <-- Add this line
|
||||
response_format: Option<ResponseFormat>, // <-- Wrap in Option
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<Tool>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ResponseMessage {
|
||||
content: String,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ToolCall {
|
||||
id: String,
|
||||
#[serde(rename = "type")]
|
||||
call_type: String, // Always "function"
|
||||
function: FunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct FunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolFunction {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
tool_type: String, // Always "function"
|
||||
function: ToolFunction,
|
||||
}
|
||||
|
||||
use serde_json::json;
|
||||
use std::process::Command;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Define the tool signature
|
||||
let curl_tool = Tool {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: "run_curl".to_string(),
|
||||
description:
|
||||
"Executes a curl network request command to fetch data from a public URL endpoint"
|
||||
.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": { "type": "string", "description": "The full HTTP/HTTPS URL endpoint to query" },
|
||||
"method": { "type": "string", "enum": ["GET", "POST"], "default": "GET" }
|
||||
},
|
||||
"required": ["url"]
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
let payload = ChatCompletionRequest {
|
||||
model: "local-model".to_string(),
|
||||
messages: vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: "You are a helpful system assistant. You have access to a run_curl tool. Use it whenever you want to real-time web information or status codes.".to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: "Check if wttr.in/Delhi is accessible right now using a GET request.".to_string(),
|
||||
},
|
||||
],
|
||||
temperature: 0.1,
|
||||
response_format: None, // <-- Pass None here since we are using tools right now
|
||||
tools: Some(vec![curl_tool]),
|
||||
};
|
||||
|
||||
let http_response = client
|
||||
.post("http://localhost:8080/v1/chat/completions")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let response: ChatCompletionResponse = http_response.json().await?;
|
||||
|
||||
let message = &response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or("No choices returned")?
|
||||
.message;
|
||||
|
||||
if let Some(tool_calls) = &message.tool_calls {
|
||||
for call in tool_calls {
|
||||
if call.function.name == "run_curl" {
|
||||
// Parse the arguments back into JSON
|
||||
let args: serde_json::Value = serde_json::from_str(&call.function.arguments)?;
|
||||
let url = args["url"].as_str().unwrap_or("");
|
||||
let method = args["method"].as_str().unwrap_or("GET");
|
||||
|
||||
println!(
|
||||
"\n[Assistant Triggered Tool]: Executing curl {} to {}",
|
||||
method, url
|
||||
);
|
||||
|
||||
// Run the actual system shell command safely via standard libraries
|
||||
let output = Command::new("curl")
|
||||
.arg("-s") // Silent mode
|
||||
.arg("-X")
|
||||
.arg(method)
|
||||
.arg(url)
|
||||
.output()?;
|
||||
|
||||
let result_string = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
println!(
|
||||
"\n[System Execution Output]:\n{}",
|
||||
result_string.lines().take(5).collect::<Vec<_>>().join("\n")
|
||||
);
|
||||
println!("... (truncated)");
|
||||
}
|
||||
}
|
||||
} else if !message.content.is_empty() {
|
||||
println!("\nAI Response: {}", message.content);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user