From 999e0ae36e049309d629c67394aa99ec472760f2 Mon Sep 17 00:00:00 2001 From: Aditya Gupta Date: Thu, 25 Jun 2026 16:32:08 +0530 Subject: [PATCH] test: Added tests for keygen, mock secret prov, registration --- src/main.rs | 19 ++++++--- src/tests.rs | 100 +++++++++++++++++++++++++++++++++++++++++++++++ test_axum_age.rs | 22 ----------- 3 files changed, 113 insertions(+), 28 deletions(-) create mode 100644 src/tests.rs delete mode 100644 test_axum_age.rs diff --git a/src/main.rs b/src/main.rs index 775e57c..9fb40b7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -109,12 +109,7 @@ async fn main() { master_key, }; - let app = Router::new() - .route("/health", get(|| async { "OK" })) - .route("/api/register", post(handlers::register)) - .route("/api/approve", post(handlers::approve)) - .route("/api/challenge/poll", post(handlers::poll)) - .with_state(state); + let app = create_router(state); let addr: SocketAddr = format!("0.0.0.0:{}", port).parse().unwrap(); tracing::info!("Listening on {}", addr); @@ -122,3 +117,15 @@ async fn main() { let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve(listener, app).await.unwrap(); } + +pub fn create_router(state: AppState) -> Router { + Router::new() + .route("/health", get(|| async { "OK" })) + .route("/api/register", post(handlers::register)) + .route("/api/approve", post(handlers::approve)) + .route("/api/challenge/poll", post(handlers::poll)) + .with_state(state) +} + +#[cfg(test)] +mod tests; diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..4c86de9 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,100 @@ +#[cfg(test)] +mod test { + use crate::handlers::{self, ApproveReq, PollReq, RegisterReq}; + use crate::AppState; + use axum::body::to_bytes; + use axum::extract::{Json, State}; + use axum::response::Response; + use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; + use ed25519_dalek::{Signer, SigningKey}; + use rand::rngs::OsRng; + use rand::RngCore; + use sqlx::sqlite::SqlitePoolOptions; + use ssh_key::public::{Ed25519PublicKey, KeyData}; + use ssh_key::PublicKey; + use std::str::FromStr; + + async fn setup_state() -> AppState { + let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await.unwrap(); + sqlx::migrate!("./migrations").run(&pool).await.unwrap(); + AppState { + pool, + master_key: "test_master_key".to_string(), + } + } + + fn generate_keypair() -> (SigningKey, String) { + let mut bytes = [0u8; 32]; + OsRng.fill_bytes(&mut bytes); + let sk = SigningKey::from_bytes(&bytes); + let vk = sk.verifying_key(); + let ed_pk = Ed25519PublicKey(*vk.as_bytes()); + let mut ssh_pk = PublicKey::from(KeyData::Ed25519(ed_pk)); + ssh_pk.set_comment("test@test"); + (sk, ssh_pk.to_string()) + } + + async fn get_json_body(res: Response) -> serde_json::Value { + let body_bytes = to_bytes(res.into_body(), usize::MAX).await.unwrap(); + serde_json::from_slice(&body_bytes).unwrap() + } + + #[tokio::test] + async fn test_full_auth_flow() { + let state = setup_state().await; + + sqlx::query("INSERT INTO secrets (key_name, encrypted_value) VALUES ('TEST_SECRET', ?)") + .bind(crate::encrypt_secret(&state.master_key, "super_secret_value")) + .execute(&state.pool) + .await + .unwrap(); + + let (client_sk, client_pk_str) = generate_keypair(); + let (admin_sk, admin_pk_str) = generate_keypair(); + + sqlx::query("INSERT INTO devices (hostname, os, public_key, approved_at, created_at) VALUES ('admin', 'linux', ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)") + .bind(&admin_pk_str) + .execute(&state.pool) + .await + .unwrap(); + + let reg_req = RegisterReq { + hostname: "client-host".to_string(), + os: "linux".to_string(), + public_key: client_pk_str.clone(), + }; + let reg_res_raw = handlers::register(State(state.clone()), Json(reg_req)).await.unwrap(); + let reg_res = get_json_body(reg_res_raw).await; + let user_code = reg_res["user_code"].as_str().unwrap().to_string(); + let challenge_nonce = reg_res["challenge_nonce"].as_str().unwrap().to_string(); + + let poll_sig1 = client_sk.sign(challenge_nonce.as_bytes()); + let poll_req1 = PollReq { + user_code: user_code.clone(), + signature: BASE64.encode(poll_sig1.to_bytes()), + }; + let poll_err = handlers::poll(State(state.clone()), Json(poll_req1)).await.unwrap_err(); + assert_eq!(poll_err.0, axum::http::StatusCode::ACCEPTED); + + let admin_sig = admin_sk.sign(client_pk_str.as_bytes()); + let admin_pk = PublicKey::from_str(&admin_pk_str).unwrap(); + let approve_req = ApproveReq { + user_code: user_code.clone(), + approver_public_key_fingerprint: admin_pk.fingerprint(Default::default()).to_string(), + signature: BASE64.encode(admin_sig.to_bytes()), + }; + let app_res = handlers::approve(State(state.clone()), Json(approve_req)).await.unwrap(); + assert_eq!(app_res.status(), axum::http::StatusCode::OK); + + let poll_sig2 = client_sk.sign(challenge_nonce.as_bytes()); + let poll_req2 = PollReq { + user_code: user_code.clone(), + signature: BASE64.encode(poll_sig2.to_bytes()), + }; + let poll_res_raw = handlers::poll(State(state.clone()), Json(poll_req2)).await.unwrap(); + let poll_res = get_json_body(poll_res_raw).await; + + let enc_secrets = poll_res["encrypted_secrets"].as_str().unwrap(); + assert!(!enc_secrets.is_empty()); + } +} diff --git a/test_axum_age.rs b/test_axum_age.rs deleted file mode 100644 index 9d8178b..0000000 --- a/test_axum_age.rs +++ /dev/null @@ -1,22 +0,0 @@ -use axum::{extract::State, http::StatusCode, Json, response::IntoResponse, routing::post, Router}; -use std::str::FromStr; - -#[derive(Clone)] -struct AppState {} - -async fn handler(State(_s): State) -> Result, (StatusCode, String)> { - Ok(Json(())) -} - -fn test_age() { - let pk = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIH0b9c/A... user@host"; - let recipient: age::ssh::Recipient = pk.parse().unwrap(); - let r: &dyn age::Recipient = &recipient; - let encryptor = age::Encryptor::with_recipients(vec![r].into_iter()); - let _ = encryptor.unwrap(); -} - -fn main() { - let state = AppState {}; - let app = Router::new().route("/", post(handler)).with_state(state); -}