Files
bootstrap-auth-server/src/handlers.rs

205 lines
7.5 KiB
Rust

use axum::{extract::State, http::StatusCode, Json, response::{IntoResponse, Response}};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
use rand::{distributions::Alphanumeric, Rng};
use serde::{Deserialize, Serialize};
use ssh_key::PublicKey;
use std::str::FromStr;
use std::io::Write;
use crate::{decrypt_secret, AppState};
/// JSON body accepted by [`register`].
#[derive(Deserialize)]
pub struct RegisterReq {
    /// Hostname reported by the registering device; stored as-is in `devices`.
    pub hostname: String,
    /// Operating-system string reported by the device; stored as-is in `devices`.
    pub os: String,
    /// The device's public key as text, parsed with `ssh_key::PublicKey::from_str`.
    pub public_key: String,
}
/// JSON body returned by [`register`] once the pending request has been created.
#[derive(Serialize)]
pub struct RegisterRes {
    /// Short random code naming the pending request; passed back as `user_code`
    /// when approving and polling.
    pub user_code: String,
    /// Random string the device must sign with its own key when polling.
    pub challenge_nonce: String,
    /// Lifetime of the pending request, in seconds.
    pub expires_in: u64,
}
pub async fn register(
State(state): State<AppState>,
Json(payload): Json<RegisterReq>,
) -> Result<Response, (StatusCode, String)> {
let _key = PublicKey::from_str(&payload.public_key)
.map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid SSH key: {}", e)))?;
let device_id: i64 = sqlx::query_scalar(
"INSERT INTO devices (hostname, os, public_key, created_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP) RETURNING id"
)
.bind(&payload.hostname)
.bind(&payload.os)
.bind(&payload.public_key)
.fetch_one(&state.pool)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let (user_code, challenge_nonce) = {
let mut rng = rand::thread_rng();
let user_code: String = (&mut rng)
.sample_iter(&Alphanumeric)
.take(6)
.map(char::from)
.collect();
let challenge_nonce: String = (&mut rng)
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
(user_code, challenge_nonce)
};
let expires_in = 300;
sqlx::query(
"INSERT INTO pending_requests (user_code, device_id, challenge_nonce, expires_at) VALUES (?, ?, ?, datetime('now', '+5 minutes'))"
)
.bind(&user_code)
.bind(device_id)
.bind(&challenge_nonce)
.execute(&state.pool)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Ok(Json(RegisterRes {
user_code,
challenge_nonce,
expires_in,
}).into_response())
}
/// JSON body accepted by [`approve`].
#[derive(Deserialize)]
pub struct ApproveReq {
    /// Code of the pending request being approved, as returned by [`register`].
    pub user_code: String,
    /// Identifies which already-approved device's key is used to check `signature`.
    pub approver_public_key_fingerprint: String,
    /// Base64-encoded Ed25519 signature by the approver over the pending device's
    /// stored public-key string.
    pub signature: String,
}
/// Approves a pending device registration on behalf of an already-approved device.
///
/// The approver is looked up among devices whose `approved_at` is set. It is
/// identified by `approver_public_key_fingerprint`, which must *exactly* equal one of:
///
/// * the key's fingerprint as rendered by `ssh_key` with the default hash algorithm,
/// * a full OpenSSH public-key line (compared by key material, so the comment is ignored), or
/// * the bare base64 key blob (the second whitespace-separated field of the stored key).
///
/// The approver must have signed the pending device's stored public-key string with its
/// Ed25519 key. On success the pending device's `approved_at` is set.
///
/// # Errors
///
/// * `401 Unauthorized` if no approver matches or the signature does not verify.
/// * `404 Not Found` if `user_code` is unknown or its request has expired.
/// * `400 Bad Request` for a malformed signature or a non-Ed25519/invalid approver key.
/// * `500 Internal Server Error` on database failure.
pub async fn approve(
    State(state): State<AppState>,
    Json(payload): Json<ApproveReq>,
) -> Result<Response, (StatusCode, String)> {
    /// Maps an unexpected backend error to a 500 carrying its message.
    fn internal<E: std::fmt::Display>(e: E) -> (StatusCode, String) {
        (StatusCode::INTERNAL_SERVER_ERROR, e.to_string())
    }
    fn approver_not_found() -> (StatusCode, String) {
        (StatusCode::UNAUTHORIZED, "Approver not found".to_string())
    }

    let wanted = payload.approver_public_key_fingerprint.trim();
    // With the previous substring search an empty identifier matched every key.
    if wanted.is_empty() {
        return Err(approver_not_found());
    }
    // If the client sent a whole OpenSSH key line, compare it by key material.
    let wanted_key = PublicKey::from_str(wanted).ok();

    let approver_devices: Vec<String> =
        sqlx::query_scalar("SELECT public_key FROM devices WHERE approved_at IS NOT NULL")
            .fetch_all(&state.pool)
            .await
            .map_err(internal)?;

    // Keep the parsed key of the first exact match so it never needs re-parsing.
    let approver_key = approver_devices
        .iter()
        .find_map(|stored| {
            let key = PublicKey::from_str(stored).ok()?;
            let by_fingerprint = key.fingerprint(Default::default()).to_string() == wanted;
            let by_key = wanted_key
                .as_ref()
                .map_or(false, |w| w.key_data() == key.key_data());
            let by_blob = stored.split_whitespace().nth(1) == Some(wanted);
            if by_fingerprint || by_key || by_blob {
                Some(key)
            } else {
                None
            }
        })
        .ok_or_else(approver_not_found)?;

    // Expired requests are treated the same as unknown ones.
    let (pending_id, pending_pubkey): (i64, String) = sqlx::query_as(
        "SELECT d.id, d.public_key FROM pending_requests p JOIN devices d ON p.device_id = d.id \
         WHERE p.user_code = ? AND p.expires_at > datetime('now')",
    )
    .bind(&payload.user_code)
    .fetch_optional(&state.pool)
    .await
    .map_err(internal)?
    .ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            "Pending request not found or expired".to_string(),
        )
    })?;

    let sig_bytes = BASE64
        .decode(&payload.signature)
        .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid b64".to_string()))?;
    let sig = Signature::from_slice(&sig_bytes)
        .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid sig len".to_string()))?;
    let key_data = approver_key
        .key_data()
        .ed25519()
        .ok_or_else(|| (StatusCode::BAD_REQUEST, "Not Ed25519".to_string()))?;
    // A 32-byte blob that is not a valid curve point must not panic the handler.
    let verifying_key = VerifyingKey::from_bytes(key_data.as_ref())
        .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid Ed25519 key".to_string()))?;
    verifying_key
        .verify(pending_pubkey.as_bytes(), &sig)
        .map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid sig".to_string()))?;

    sqlx::query("UPDATE devices SET approved_at = CURRENT_TIMESTAMP WHERE id = ?")
        .bind(pending_id)
        .execute(&state.pool)
        .await
        .map_err(internal)?;

    Ok(Json("Approved!").into_response())
}
/// JSON body accepted by [`poll`].
#[derive(Deserialize)]
pub struct PollReq {
    /// Code of the device's own pending request, as returned by [`register`].
    pub user_code: String,
    /// Base64-encoded Ed25519 signature by the device over its `challenge_nonce`.
    pub signature: String,
}
/// JSON body returned by [`poll`] once the device has been approved.
#[derive(Serialize)]
pub struct PollRes {
    /// Base64 of an age ciphertext, encrypted to the device's SSH key, whose plaintext
    /// is one `KEY=value` line per stored secret.
    pub encrypted_secrets: String,
}
pub async fn poll(
State(state): State<AppState>,
Json(payload): Json<PollReq>,
) -> Result<Response, (StatusCode, String)> {
let req: (String, String, Option<String>) = sqlx::query_as(
"SELECT p.challenge_nonce, d.public_key, d.approved_at FROM pending_requests p JOIN devices d ON p.device_id = d.id WHERE p.user_code = ?"
)
.bind(&payload.user_code)
.fetch_optional(&state.pool)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "Not found".to_string()))?;
let challenge_nonce = req.0;
let pub_key_str = req.1;
let approved_at = req.2;
let sig_bytes = BASE64.decode(&payload.signature).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid b64".to_string()))?;
let sig = Signature::from_slice(&sig_bytes).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid sig len".to_string()))?;
let pub_key = PublicKey::from_str(&pub_key_str).unwrap();
let key_data = pub_key.key_data().ed25519().ok_or((StatusCode::BAD_REQUEST, "Not Ed25519".to_string()))?;
let verifying_key = VerifyingKey::from_bytes(key_data.as_ref()).unwrap();
verifying_key.verify(challenge_nonce.as_bytes(), &sig).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid sig".to_string()))?;
if approved_at.is_none() {
return Err((StatusCode::ACCEPTED, "Pending Approval".to_string()));
}
let secrets: Vec<(String, String)> = sqlx::query_as("SELECT key_name, encrypted_value FROM secrets")
.fetch_all(&state.pool)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let mut secrets_env = String::new();
for (k, v) in secrets {
let plaintext = decrypt_secret(&state.master_key, &v);
secrets_env.push_str(&format!("{}={}\n", k, plaintext));
}
let recipient: age::ssh::Recipient = pub_key_str.parse().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse age recipient: {:?}", e)))?;
let r: &dyn age::Recipient = &recipient;
let encryptor = age::Encryptor::with_recipients(vec![r].into_iter())
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let mut encrypted = vec![];
{
let mut writer = encryptor.wrap_output(&mut encrypted).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
writer.write_all(secrets_env.as_bytes()).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
writer.finish().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
}
sqlx::query("DELETE FROM pending_requests WHERE user_code = ?").bind(&payload.user_code).execute(&state.pool).await.ok();
Ok(Json(PollRes {
encrypted_secrets: BASE64.encode(encrypted),
}).into_response())
}