feat: Create API endpoints for register, approve and poll
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -292,7 +292,7 @@ dependencies = [
|
||||
"base64 0.22.1",
|
||||
"dotenvy",
|
||||
"ed25519-dalek",
|
||||
"rand 0.10.1",
|
||||
"rand 0.8.6",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2 0.11.0",
|
||||
|
||||
@@ -10,7 +10,7 @@ axum = "0.8.9"
|
||||
base64 = "0.22.1"
|
||||
dotenvy = "0.15.7"
|
||||
ed25519-dalek = "2.2.0"
|
||||
rand = "0.10.1"
|
||||
rand = "0.8"
|
||||
serde = "1.0.228"
|
||||
serde_json = "1.0.150"
|
||||
sha2 = "0.11.0"
|
||||
|
||||
204
src/handlers.rs
Normal file
204
src/handlers.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
use axum::{extract::State, http::StatusCode, Json, response::{IntoResponse, Response}};
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
|
||||
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ssh_key::PublicKey;
|
||||
use std::str::FromStr;
|
||||
use std::io::Write;
|
||||
|
||||
use crate::{decrypt_secret, AppState};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct RegisterReq {
|
||||
pub hostname: String,
|
||||
pub os: String,
|
||||
pub public_key: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct RegisterRes {
|
||||
pub user_code: String,
|
||||
pub challenge_nonce: String,
|
||||
pub expires_in: u64,
|
||||
}
|
||||
|
||||
pub async fn register(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<RegisterReq>,
|
||||
) -> Result<Response, (StatusCode, String)> {
|
||||
let _key = PublicKey::from_str(&payload.public_key)
|
||||
.map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid SSH key: {}", e)))?;
|
||||
|
||||
let device_id: i64 = sqlx::query_scalar(
|
||||
"INSERT INTO devices (hostname, os, public_key, created_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP) RETURNING id"
|
||||
)
|
||||
.bind(&payload.hostname)
|
||||
.bind(&payload.os)
|
||||
.bind(&payload.public_key)
|
||||
.fetch_one(&state.pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let (user_code, challenge_nonce) = {
|
||||
let mut rng = rand::thread_rng();
|
||||
let user_code: String = (&mut rng)
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(6)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
let challenge_nonce: String = (&mut rng)
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
(user_code, challenge_nonce)
|
||||
};
|
||||
let expires_in = 300;
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO pending_requests (user_code, device_id, challenge_nonce, expires_at) VALUES (?, ?, ?, datetime('now', '+5 minutes'))"
|
||||
)
|
||||
.bind(&user_code)
|
||||
.bind(device_id)
|
||||
.bind(&challenge_nonce)
|
||||
.execute(&state.pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(Json(RegisterRes {
|
||||
user_code,
|
||||
challenge_nonce,
|
||||
expires_in,
|
||||
}).into_response())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ApproveReq {
|
||||
pub user_code: String,
|
||||
pub approver_public_key_fingerprint: String,
|
||||
pub signature: String,
|
||||
}
|
||||
|
||||
pub async fn approve(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<ApproveReq>,
|
||||
) -> Result<Response, (StatusCode, String)> {
|
||||
let approver_devices: Vec<String> = sqlx::query_scalar("SELECT public_key FROM devices WHERE approved_at IS NOT NULL")
|
||||
.fetch_all(&state.pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let mut approver_pubkey_str = None;
|
||||
for pub_key in approver_devices {
|
||||
if let Ok(key) = PublicKey::from_str(&pub_key) {
|
||||
let fingerprint = key.fingerprint(Default::default()).to_string();
|
||||
if fingerprint == payload.approver_public_key_fingerprint || pub_key.contains(&payload.approver_public_key_fingerprint) {
|
||||
approver_pubkey_str = Some(pub_key);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let approver_pubkey_str = approver_pubkey_str.ok_or((StatusCode::UNAUTHORIZED, "Approver not found".to_string()))?;
|
||||
let approver_key = PublicKey::from_str(&approver_pubkey_str).unwrap();
|
||||
|
||||
let pending: (i64, String) = sqlx::query_as(
|
||||
"SELECT d.id, d.public_key FROM pending_requests p JOIN devices d ON p.device_id = d.id WHERE p.user_code = ?"
|
||||
)
|
||||
.bind(&payload.user_code)
|
||||
.fetch_optional(&state.pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "Pending request not found".to_string()))?;
|
||||
|
||||
let pending_id = pending.0;
|
||||
let pending_pubkey = pending.1;
|
||||
|
||||
let sig_bytes = BASE64.decode(&payload.signature).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid b64".to_string()))?;
|
||||
let sig = Signature::from_slice(&sig_bytes).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid sig len".to_string()))?;
|
||||
|
||||
let key_data = approver_key.key_data().ed25519().ok_or((StatusCode::BAD_REQUEST, "Not Ed25519".to_string()))?;
|
||||
let verifying_key = VerifyingKey::from_bytes(key_data.as_ref()).unwrap();
|
||||
|
||||
verifying_key.verify(pending_pubkey.as_bytes(), &sig).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid sig".to_string()))?;
|
||||
|
||||
sqlx::query("UPDATE devices SET approved_at = CURRENT_TIMESTAMP WHERE id = ?")
|
||||
.bind(pending_id)
|
||||
.execute(&state.pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(Json("Approved!").into_response())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct PollReq {
|
||||
pub user_code: String,
|
||||
pub signature: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct PollRes {
|
||||
pub encrypted_secrets: String,
|
||||
}
|
||||
|
||||
pub async fn poll(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<PollReq>,
|
||||
) -> Result<Response, (StatusCode, String)> {
|
||||
let req: (String, String, Option<String>) = sqlx::query_as(
|
||||
"SELECT p.challenge_nonce, d.public_key, d.approved_at FROM pending_requests p JOIN devices d ON p.device_id = d.id WHERE p.user_code = ?"
|
||||
)
|
||||
.bind(&payload.user_code)
|
||||
.fetch_optional(&state.pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "Not found".to_string()))?;
|
||||
|
||||
let challenge_nonce = req.0;
|
||||
let pub_key_str = req.1;
|
||||
let approved_at = req.2;
|
||||
|
||||
let sig_bytes = BASE64.decode(&payload.signature).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid b64".to_string()))?;
|
||||
let sig = Signature::from_slice(&sig_bytes).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid sig len".to_string()))?;
|
||||
|
||||
let pub_key = PublicKey::from_str(&pub_key_str).unwrap();
|
||||
let key_data = pub_key.key_data().ed25519().ok_or((StatusCode::BAD_REQUEST, "Not Ed25519".to_string()))?;
|
||||
let verifying_key = VerifyingKey::from_bytes(key_data.as_ref()).unwrap();
|
||||
|
||||
verifying_key.verify(challenge_nonce.as_bytes(), &sig).map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid sig".to_string()))?;
|
||||
|
||||
if approved_at.is_none() {
|
||||
return Err((StatusCode::ACCEPTED, "Pending Approval".to_string()));
|
||||
}
|
||||
|
||||
let secrets: Vec<(String, String)> = sqlx::query_as("SELECT key_name, encrypted_value FROM secrets")
|
||||
.fetch_all(&state.pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let mut secrets_env = String::new();
|
||||
for (k, v) in secrets {
|
||||
let plaintext = decrypt_secret(&state.master_key, &v);
|
||||
secrets_env.push_str(&format!("{}={}\n", k, plaintext));
|
||||
}
|
||||
|
||||
let recipient: age::ssh::Recipient = pub_key_str.parse().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse age recipient: {:?}", e)))?;
|
||||
let r: &dyn age::Recipient = &recipient;
|
||||
let encryptor = age::Encryptor::with_recipients(vec![r].into_iter())
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let mut encrypted = vec![];
|
||||
{
|
||||
let mut writer = encryptor.wrap_output(&mut encrypted).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
writer.write_all(secrets_env.as_bytes()).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
writer.finish().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM pending_requests WHERE user_code = ?").bind(&payload.user_code).execute(&state.pool).await.ok();
|
||||
|
||||
Ok(Json(PollRes {
|
||||
encrypted_secrets: BASE64.encode(encrypted),
|
||||
}).into_response())
|
||||
}
|
||||
27
src/main.rs
27
src/main.rs
@@ -1,8 +1,10 @@
|
||||
pub mod handlers;
|
||||
|
||||
use aes_gcm::{
|
||||
aead::{Aead, AeadCore, KeyInit, OsRng},
|
||||
Aes256Gcm, Key, Nonce,
|
||||
};
|
||||
use axum::{Router, routing::get};
|
||||
use axum::{routing::post, routing::get, Router};
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::{sqlite::SqlitePoolOptions, SqlitePool};
|
||||
@@ -11,19 +13,19 @@ use std::net::SocketAddr;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
pool: SqlitePool,
|
||||
master_key: String,
|
||||
pub struct AppState {
|
||||
pub pool: SqlitePool,
|
||||
pub master_key: String,
|
||||
}
|
||||
|
||||
fn get_master_key(master_key: &str) -> Key<Aes256Gcm> {
|
||||
pub fn get_master_key(master_key: &str) -> Key<Aes256Gcm> {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(master_key.as_bytes());
|
||||
let result = hasher.finalize();
|
||||
*Key::<Aes256Gcm>::from_slice(&result)
|
||||
}
|
||||
|
||||
fn encrypt_secret(master_key: &str, plaintext: &str) -> String {
|
||||
pub fn encrypt_secret(master_key: &str, plaintext: &str) -> String {
|
||||
let key = get_master_key(master_key);
|
||||
let cipher = Aes256Gcm::new(&key);
|
||||
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
|
||||
@@ -36,6 +38,16 @@ fn encrypt_secret(master_key: &str, plaintext: &str) -> String {
|
||||
BASE64.encode(payload)
|
||||
}
|
||||
|
||||
pub fn decrypt_secret(master_key: &str, encrypted_b64: &str) -> String {
|
||||
let key = get_master_key(master_key);
|
||||
let payload = BASE64.decode(encrypted_b64).expect("Invalid base64 in DB");
|
||||
let nonce = Nonce::from_slice(&payload[0..12]);
|
||||
let ciphertext = &payload[12..];
|
||||
let cipher = Aes256Gcm::new(&key);
|
||||
let plaintext = cipher.decrypt(nonce, ciphertext).expect("decryption failure");
|
||||
String::from_utf8(plaintext).expect("invalid utf8 in secret")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt()
|
||||
@@ -99,6 +111,9 @@ async fn main() {
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(|| async { "OK" }))
|
||||
.route("/api/register", post(handlers::register))
|
||||
.route("/api/approve", post(handlers::approve))
|
||||
.route("/api/challenge/poll", post(handlers::poll))
|
||||
.with_state(state);
|
||||
|
||||
let addr: SocketAddr = format!("0.0.0.0:{}", port).parse().unwrap();
|
||||
|
||||
22
test_axum_age.rs
Normal file
22
test_axum_age.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use axum::{extract::State, http::StatusCode, Json, response::IntoResponse, routing::post, Router};
|
||||
use std::str::FromStr;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {}
|
||||
|
||||
async fn handler(State(_s): State<AppState>) -> Result<Json<()>, (StatusCode, String)> {
|
||||
Ok(Json(()))
|
||||
}
|
||||
|
||||
fn test_age() {
|
||||
let pk = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH0b9c/A... user@host";
|
||||
let recipient: age::ssh::Recipient = pk.parse().unwrap();
|
||||
let r: &dyn age::Recipient = &recipient;
|
||||
let encryptor = age::Encryptor::with_recipients(vec![r].into_iter());
|
||||
let _ = encryptor.unwrap();
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let state = AppState {};
|
||||
let app = Router::new().route("/", post(handler)).with_state(state);
|
||||
}
|
||||
Reference in New Issue
Block a user