diff --git a/Cargo.lock b/Cargo.lock index f01377c..6bac7ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -292,7 +292,7 @@ dependencies = [ "base64 0.22.1", "dotenvy", "ed25519-dalek", - "rand 0.10.1", + "rand 0.8.6", "serde", "serde_json", "sha2 0.11.0", diff --git a/Cargo.toml b/Cargo.toml index 05ff729..6ae4005 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ axum = "0.8.9" base64 = "0.22.1" dotenvy = "0.15.7" ed25519-dalek = "2.2.0" -rand = "0.10.1" +rand = "0.8" serde = "1.0.228" serde_json = "1.0.150" sha2 = "0.11.0" diff --git a/src/handlers.rs b/src/handlers.rs new file mode 100644 index 0000000..2f9b4de --- /dev/null +++ b/src/handlers.rs @@ -0,0 +1,204 @@ +use axum::{extract::State, http::StatusCode, Json, response::{IntoResponse, Response}}; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; +use ed25519_dalek::{Signature, Verifier, VerifyingKey}; +use rand::{distributions::Alphanumeric, Rng}; +use serde::{Deserialize, Serialize}; +use ssh_key::PublicKey; +use std::str::FromStr; +use std::io::Write; + +use crate::{decrypt_secret, AppState}; + +#[derive(Deserialize)] +pub struct RegisterReq { + pub hostname: String, + pub os: String, + pub public_key: String, +} + +#[derive(Serialize)] +pub struct RegisterRes { + pub user_code: String, + pub challenge_nonce: String, + pub expires_in: u64, +} + +pub async fn register( + State(state): State, + Json(payload): Json, +) -> Result { + let _key = PublicKey::from_str(&payload.public_key) + .map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid SSH key: {}", e)))?; + + let device_id: i64 = sqlx::query_scalar( + "INSERT INTO devices (hostname, os, public_key, created_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP) RETURNING id" + ) + .bind(&payload.hostname) + .bind(&payload.os) + .bind(&payload.public_key) + .fetch_one(&state.pool) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + let (user_code, challenge_nonce) = { + let mut rng = rand::thread_rng(); + let user_code: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(6) + .map(char::from) + .collect(); + let challenge_nonce: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); + (user_code, challenge_nonce) + }; + let expires_in = 300; + + sqlx::query( + "INSERT INTO pending_requests (user_code, device_id, challenge_nonce, expires_at) VALUES (?, ?, ?, datetime('now', '+5 minutes'))" + ) + .bind(&user_code) + .bind(device_id) + .bind(&challenge_nonce) + .execute(&state.pool) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + Ok(Json(RegisterRes { + user_code, + challenge_nonce, + expires_in, + }).into_response()) +} + +#[derive(Deserialize)] +pub struct ApproveReq { + pub user_code: String, + pub approver_public_key_fingerprint: String, + pub signature: String, +} + +pub async fn approve( + State(state): State, + Json(payload): Json, +) -> Result { + let approver_devices: Vec = sqlx::query_scalar("SELECT public_key FROM devices WHERE approved_at IS NOT NULL") + .fetch_all(&state.pool) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + let mut approver_pubkey_str = None; + for pub_key in approver_devices { + if let Ok(key) = PublicKey::from_str(&pub_key) { + let fingerprint = key.fingerprint(Default::default()).to_string(); + if fingerprint == payload.approver_public_key_fingerprint || pub_key.contains(&payload.approver_public_key_fingerprint) { + approver_pubkey_str = Some(pub_key); + break; + } + } + } + + let approver_pubkey_str = approver_pubkey_str.ok_or((StatusCode::UNAUTHORIZED, "Approver not found".to_string()))?; + let approver_key = PublicKey::from_str(&approver_pubkey_str).unwrap(); + + let pending: (i64, String) = sqlx::query_as( + "SELECT d.id, d.public_key FROM pending_requests p JOIN devices d ON p.device_id = d.id WHERE p.user_code = ?" + ) + .bind(&payload.user_code) + .fetch_optional(&state.pool) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::NOT_FOUND, "Pending request not found".to_string()))?; + + let pending_id = pending.0; + let pending_pubkey = pending.1; + + let sig_bytes = BASE64.decode(&payload.signature).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid b64".to_string()))?; + let sig = Signature::from_slice(&sig_bytes).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid sig len".to_string()))?; + + let key_data = approver_key.key_data().ed25519().ok_or((StatusCode::BAD_REQUEST, "Not Ed25519".to_string()))?; + let verifying_key = VerifyingKey::from_bytes(key_data.as_ref()).unwrap(); + + verifying_key.verify(pending_pubkey.as_bytes(), &sig).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid sig".to_string()))?; + + sqlx::query("UPDATE devices SET approved_at = CURRENT_TIMESTAMP WHERE id = ?") + .bind(pending_id) + .execute(&state.pool) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + Ok(Json("Approved!").into_response()) +} + +#[derive(Deserialize)] +pub struct PollReq { + pub user_code: String, + pub signature: String, +} + +#[derive(Serialize)] +pub struct PollRes { + pub encrypted_secrets: String, +} + +pub async fn poll( + State(state): State, + Json(payload): Json, +) -> Result { + let req: (String, String, Option) = sqlx::query_as( + "SELECT p.challenge_nonce, d.public_key, d.approved_at FROM pending_requests p JOIN devices d ON p.device_id = d.id WHERE p.user_code = ?" + ) + .bind(&payload.user_code) + .fetch_optional(&state.pool) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::NOT_FOUND, "Not found".to_string()))?; + + let challenge_nonce = req.0; + let pub_key_str = req.1; + let approved_at = req.2; + + let sig_bytes = BASE64.decode(&payload.signature).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid b64".to_string()))?; + let sig = Signature::from_slice(&sig_bytes).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid sig len".to_string()))?; + + let pub_key = PublicKey::from_str(&pub_key_str).unwrap(); + let key_data = pub_key.key_data().ed25519().ok_or((StatusCode::BAD_REQUEST, "Not Ed25519".to_string()))?; + let verifying_key = VerifyingKey::from_bytes(key_data.as_ref()).unwrap(); + + verifying_key.verify(challenge_nonce.as_bytes(), &sig).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid sig".to_string()))?; + + if approved_at.is_none() { + return Err((StatusCode::ACCEPTED, "Pending Approval".to_string())); + } + + let secrets: Vec<(String, String)> = sqlx::query_as("SELECT key_name, encrypted_value FROM secrets") + .fetch_all(&state.pool) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + let mut secrets_env = String::new(); + for (k, v) in secrets { + let plaintext = decrypt_secret(&state.master_key, &v); + secrets_env.push_str(&format!("{}={}\n", k, plaintext)); + } + + let recipient: age::ssh::Recipient = pub_key_str.parse().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse age recipient: {:?}", e)))?; + let r: &dyn age::Recipient = &recipient; + let encryptor = age::Encryptor::with_recipients(vec![r].into_iter()) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + let mut encrypted = vec![]; + { + let mut writer = encryptor.wrap_output(&mut encrypted).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + writer.write_all(secrets_env.as_bytes()).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + writer.finish().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + } + + sqlx::query("DELETE FROM pending_requests WHERE user_code = ?").bind(&payload.user_code).execute(&state.pool).await.ok(); + + Ok(Json(PollRes { + encrypted_secrets: BASE64.encode(encrypted), + }).into_response()) +} diff --git a/src/main.rs b/src/main.rs index 0daf601..775e57c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,10 @@ +pub mod handlers; + use aes_gcm::{ aead::{Aead, AeadCore, KeyInit, OsRng}, Aes256Gcm, Key, Nonce, }; -use axum::{Router, routing::get}; +use axum::{routing::post, routing::get, Router}; use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use sha2::{Digest, Sha256}; use sqlx::{sqlite::SqlitePoolOptions, SqlitePool}; @@ -11,19 +13,19 @@ use std::net::SocketAddr; use tracing_subscriber::EnvFilter; #[derive(Clone)] -struct AppState { - pool: SqlitePool, - master_key: String, +pub struct AppState { + pub pool: SqlitePool, + pub master_key: String, } -fn get_master_key(master_key: &str) -> Key { +pub fn get_master_key(master_key: &str) -> Key { let mut hasher = Sha256::new(); hasher.update(master_key.as_bytes()); let result = hasher.finalize(); *Key::::from_slice(&result) } -fn encrypt_secret(master_key: &str, plaintext: &str) -> String { +pub fn encrypt_secret(master_key: &str, plaintext: &str) -> String { let key = get_master_key(master_key); let cipher = Aes256Gcm::new(&key); let nonce = Aes256Gcm::generate_nonce(&mut OsRng); @@ -36,6 +38,16 @@ fn encrypt_secret(master_key: &str, plaintext: &str) -> String { BASE64.encode(payload) } +pub fn decrypt_secret(master_key: &str, encrypted_b64: &str) -> String { + let key = get_master_key(master_key); + let payload = BASE64.decode(encrypted_b64).expect("Invalid base64 in DB"); + let nonce = Nonce::from_slice(&payload[0..12]); + let ciphertext = &payload[12..]; + let cipher = Aes256Gcm::new(&key); + let plaintext = cipher.decrypt(nonce, ciphertext).expect("decryption failure"); + String::from_utf8(plaintext).expect("invalid utf8 in secret") +} + #[tokio::main] async fn main() { tracing_subscriber::fmt() @@ -99,6 +111,9 @@ async fn main() { let app = Router::new() .route("/health", get(|| async { "OK" })) + .route("/api/register", post(handlers::register)) + .route("/api/approve", post(handlers::approve)) + .route("/api/challenge/poll", post(handlers::poll)) .with_state(state); let addr: SocketAddr = format!("0.0.0.0:{}", port).parse().unwrap(); diff --git a/test_axum_age.rs b/test_axum_age.rs new file mode 100644 index 0000000..9d8178b --- /dev/null +++ b/test_axum_age.rs @@ -0,0 +1,22 @@ +use axum::{extract::State, http::StatusCode, Json, response::IntoResponse, routing::post, Router}; +use std::str::FromStr; + +#[derive(Clone)] +struct AppState {} + +async fn handler(State(_s): State) -> Result, (StatusCode, String)> { + Ok(Json(())) +} + +fn test_age() { + let pk = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH0b9c/A... user@host"; + let recipient: age::ssh::Recipient = pk.parse().unwrap(); + let r: &dyn age::Recipient = &recipient; + let encryptor = age::Encryptor::with_recipients(vec![r].into_iter()); + let _ = encryptor.unwrap(); +} + +fn main() { + let state = AppState {}; + let app = Router::new().route("/", post(handler)).with_state(state); +}