diff --git a/src/auth.rs b/src/auth.rs index 0fb89d0..6414ddf 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,4 +1,8 @@ -use axum::{http::HeaderMap, http::StatusCode, response::IntoResponse, response::Json}; +use axum::body::Body; +use axum::{ + http::HeaderMap, http::Request, http::StatusCode, middleware::Next, response::IntoResponse, + response::Json, response::Response, +}; use base64::{Engine, engine::general_purpose}; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode}; use pam::Client; @@ -114,6 +118,14 @@ pub fn verify_system_credentials(username: &str, password: &str) -> bool { client.authenticate().is_ok() } +pub async fn require_auth(headers: HeaderMap, request: Request
, next: Next) -> Response { + if verify_token(&headers) { + next.run(request).await + } else { + (StatusCode::UNAUTHORIZED, "Unauthorized").into_response() + } +} + // POST /auth/login pub async fn post_login(headers: HeaderMap) -> impl IntoResponse { let (username, password) = match decode_basic_auth(&headers) { diff --git a/src/main.rs b/src/main.rs index 8913c00..d81c9c3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use axum::middleware; use axum::response::Redirect; use axum::{Router, routing::get, routing::post}; @@ -8,10 +9,8 @@ mod routes; #[tokio::main] async fn main() { - let app = Router::new() - .route("/", get(|| async { Redirect::permanent("/stats") })) + let protected = Router::new() .route("/stats", get(routes::stats::get_stats)) - .route("/auth/login", post(auth::post_login)) .route( "/services/{service}/restart", post(routes::services::restart_service), @@ -28,7 +27,13 @@ async fn main() { "/services/{service}/logs", get(routes::services::service_logs), ) - .route("/system/reboot", post(routes::system::system_reboot)); + .route("/system/reboot", post(routes::system::system_reboot)) + .route_layer(middleware::from_fn(auth::require_auth)); + + let app = Router::new() + .route("/", get(|| async { Redirect::permanent("/stats") })) + .route("/auth/login", post(auth::post_login)) + .merge(protected); let listener = tokio::net::TcpListener::bind("127.0.0.1:3001") .await diff --git a/src/routes/services.rs b/src/routes/services.rs index 06e140f..0675ece 100644 --- a/src/routes/services.rs +++ b/src/routes/services.rs @@ -1,6 +1,5 @@ use axum::{ - extract::Path, http::HeaderMap, http::StatusCode, response::IntoResponse, - response::Json, + extract::Path, http::HeaderMap, http::StatusCode, response::IntoResponse, response::Json, }; use tokio::process::Command; use zbus::Connection; @@ -46,9 +45,6 @@ async fn systemd_action(action: &str, service: &str) -> (StatusCode, Json