use axum::{extract::State, http::StatusCode, Json, response::{IntoResponse, Response}}; use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use ed25519_dalek::{Signature, Verifier, VerifyingKey}; use rand::{distributions::Alphanumeric, Rng}; use serde::{Deserialize, Serialize}; use ssh_key::PublicKey; use std::str::FromStr; use std::io::Write; use crate::{decrypt_secret, AppState}; #[derive(Deserialize)] pub struct RegisterReq { pub hostname: String, pub os: String, pub public_key: String, } #[derive(Serialize)] pub struct RegisterRes { pub user_code: String, pub challenge_nonce: String, pub expires_in: u64, } pub async fn register( State(state): State, Json(payload): Json, ) -> Result { let _key = PublicKey::from_str(&payload.public_key) .map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid SSH key: {}", e)))?; let device_id: i64 = sqlx::query_scalar( "INSERT INTO devices (hostname, os, public_key, created_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP) RETURNING id" ) .bind(&payload.hostname) .bind(&payload.os) .bind(&payload.public_key) .fetch_one(&state.pool) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; let (user_code, challenge_nonce) = { let mut rng = rand::thread_rng(); let user_code: String = (&mut rng) .sample_iter(&Alphanumeric) .take(6) .map(char::from) .collect(); let challenge_nonce: String = (&mut rng) .sample_iter(&Alphanumeric) .take(32) .map(char::from) .collect(); (user_code, challenge_nonce) }; let expires_in = 300; sqlx::query( "INSERT INTO pending_requests (user_code, device_id, challenge_nonce, expires_at) VALUES (?, ?, ?, datetime('now', '+5 minutes'))" ) .bind(&user_code) .bind(device_id) .bind(&challenge_nonce) .execute(&state.pool) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; Ok(Json(RegisterRes { user_code, challenge_nonce, expires_in, }).into_response()) } #[derive(Deserialize)] pub struct ApproveReq { pub user_code: String, pub approver_public_key_fingerprint: String, pub signature: String, } pub async fn approve( State(state): State, Json(payload): Json, ) -> Result { let approver_devices: Vec = sqlx::query_scalar("SELECT public_key FROM devices WHERE approved_at IS NOT NULL") .fetch_all(&state.pool) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; let mut approver_pubkey_str = None; for pub_key in approver_devices { if let Ok(key) = PublicKey::from_str(&pub_key) { let fingerprint = key.fingerprint(Default::default()).to_string(); if fingerprint == payload.approver_public_key_fingerprint || pub_key.contains(&payload.approver_public_key_fingerprint) { approver_pubkey_str = Some(pub_key); break; } } } let approver_pubkey_str = approver_pubkey_str.ok_or((StatusCode::UNAUTHORIZED, "Approver not found".to_string()))?; let approver_key = PublicKey::from_str(&approver_pubkey_str).unwrap(); let pending: (i64, String) = sqlx::query_as( "SELECT d.id, d.public_key FROM pending_requests p JOIN devices d ON p.device_id = d.id WHERE p.user_code = ?" ) .bind(&payload.user_code) .fetch_optional(&state.pool) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? .ok_or((StatusCode::NOT_FOUND, "Pending request not found".to_string()))?; let pending_id = pending.0; let pending_pubkey = pending.1; let sig_bytes = BASE64.decode(&payload.signature).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid b64".to_string()))?; let sig = Signature::from_slice(&sig_bytes).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid sig len".to_string()))?; let key_data = approver_key.key_data().ed25519().ok_or((StatusCode::BAD_REQUEST, "Not Ed25519".to_string()))?; let verifying_key = VerifyingKey::from_bytes(key_data.as_ref()).unwrap(); verifying_key.verify(pending_pubkey.as_bytes(), &sig).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid sig".to_string()))?; sqlx::query("UPDATE devices SET approved_at = CURRENT_TIMESTAMP WHERE id = ?") .bind(pending_id) .execute(&state.pool) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; Ok(Json("Approved!").into_response()) } #[derive(Deserialize)] pub struct PollReq { pub user_code: String, pub signature: String, } #[derive(Serialize)] pub struct PollRes { pub encrypted_secrets: String, } pub async fn poll( State(state): State, Json(payload): Json, ) -> Result { let req: (String, String, Option) = sqlx::query_as( "SELECT p.challenge_nonce, d.public_key, d.approved_at FROM pending_requests p JOIN devices d ON p.device_id = d.id WHERE p.user_code = ?" ) .bind(&payload.user_code) .fetch_optional(&state.pool) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? .ok_or((StatusCode::NOT_FOUND, "Not found".to_string()))?; let challenge_nonce = req.0; let pub_key_str = req.1; let approved_at = req.2; let sig_bytes = BASE64.decode(&payload.signature).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid b64".to_string()))?; let sig = Signature::from_slice(&sig_bytes).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid sig len".to_string()))?; let pub_key = PublicKey::from_str(&pub_key_str).unwrap(); let key_data = pub_key.key_data().ed25519().ok_or((StatusCode::BAD_REQUEST, "Not Ed25519".to_string()))?; let verifying_key = VerifyingKey::from_bytes(key_data.as_ref()).unwrap(); verifying_key.verify(challenge_nonce.as_bytes(), &sig).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid sig".to_string()))?; if approved_at.is_none() { return Err((StatusCode::ACCEPTED, "Pending Approval".to_string())); } let secrets: Vec<(String, String)> = sqlx::query_as("SELECT key_name, encrypted_value FROM secrets") .fetch_all(&state.pool) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; let mut secrets_env = String::new(); for (k, v) in secrets { let plaintext = decrypt_secret(&state.master_key, &v); secrets_env.push_str(&format!("{}={}\n", k, plaintext)); } let recipient: age::ssh::Recipient = pub_key_str.parse().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse age recipient: {:?}", e)))?; let r: &dyn age::Recipient = &recipient; let encryptor = age::Encryptor::with_recipients(vec![r].into_iter()) .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; let mut encrypted = vec![]; { let mut writer = encryptor.wrap_output(&mut encrypted).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; writer.write_all(secrets_env.as_bytes()).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; writer.finish().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; } sqlx::query("DELETE FROM pending_requests WHERE user_code = ?").bind(&payload.user_code).execute(&state.pool).await.ok(); Ok(Json(PollRes { encrypted_secrets: BASE64.encode(encrypted), }).into_response()) }