manta_server/server/
auth_middleware.rs1use std::collections::HashMap;
19use std::net::{IpAddr, SocketAddr};
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, Instant};
22
23use axum::{
24 Json,
25 extract::{ConnectInfo, Request, State},
26 http::StatusCode,
27 middleware::Next,
28 response::{IntoResponse, Response},
29};
30
31use super::ServerState;
32use super::handlers::ErrorResponse;
33
/// Rate-limiting state tracked for a single source IP.
struct WindowState {
    /// Instant at which the current window began.
    window_start: Instant,
    /// Requests admitted so far within the current window.
    count: u32,
}

/// Fixed-window request counter keyed by source IP address.
///
/// Each IP may have at most `limit` requests admitted per 60-second window; a new
/// window starts with the first request that arrives after the previous window has
/// fully elapsed.
#[derive(Default)]
pub struct AuthRateLimiter {
    /// Per-IP window state. Entries whose window began two or more windows ago are
    /// evicted by periodic sweeps.
    windows: Mutex<HashMap<IpAddr, WindowState>>,
    /// When `windows` was last swept; `None` before the first check. Only locked
    /// while the `windows` lock is already held, so lock order is always
    /// `windows` first, then `last_sweep`.
    last_sweep: Mutex<Option<Instant>>,
}

impl AuthRateLimiter {
    /// Length of one rate-limiting window.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Creates an empty limiter wrapped in an [`Arc`] so it can be shared.
    pub fn new() -> Arc<Self> {
        Arc::new(Self::default())
    }

    /// Records a request from `ip` at the current time; see [`Self::check_at`].
    fn check(&self, ip: IpAddr, limit: u32) -> bool {
        self.check_at(ip, limit, Instant::now())
    }

    /// Records a request from `ip` at `now` and reports whether it is admitted.
    ///
    /// Returns `true` (and counts the request) when `ip` has had fewer than `limit`
    /// requests admitted in its current window, `false` otherwise. Rejected requests
    /// are not counted. A `limit` of 0 rejects every request.
    fn check_at(&self, ip: IpAddr, limit: u32, now: Instant) -> bool {
        let window = Self::WINDOW;
        // The map only holds counters, so whatever a panicking lock holder left behind
        // is at worst a miscounted window; recover instead of panicking on every
        // subsequent auth request.
        let mut windows = self
            .windows
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        self.sweep_if_due(&mut windows, now);

        let entry = windows.entry(ip).or_insert(WindowState {
            window_start: now,
            count: 0,
        });

        // The previous window has fully elapsed: start a fresh one at `now`.
        if now.duration_since(entry.window_start) >= window {
            entry.window_start = now;
            entry.count = 0;
        }

        if entry.count >= limit {
            return false;
        }
        entry.count += 1;
        true
    }

    /// Drops entries whose window began two or more windows before `now`, running at
    /// most once per window.
    ///
    /// An entry older than one window behaves exactly like an absent one (the next
    /// request from that IP resets it), so eviction only bounds memory. Rate-limiting
    /// the sweep avoids scanning every tracked IP on every request.
    fn sweep_if_due(&self, windows: &mut HashMap<IpAddr, WindowState>, now: Instant) {
        let window = Self::WINDOW;
        let mut last_sweep = self
            .last_sweep
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let due = match *last_sweep {
            Some(previous) => now.duration_since(previous) >= window,
            None => true,
        };
        if due {
            windows.retain(|_, state| now.duration_since(state.window_start) < window * 2);
            *last_sweep = Some(now);
        }
    }
}
90
91pub async fn rate_limit(
95 State(state): State<Arc<ServerState>>,
96 ConnectInfo(peer): ConnectInfo<SocketAddr>,
97 limiter: axum::extract::Extension<Arc<AuthRateLimiter>>,
98 request: Request,
99 next: Next,
100) -> Response {
101 let Some(limit) = state.auth_rate_limit_per_minute else {
102 return next.run(request).await;
103 };
104 if !limiter.check(peer.ip(), limit) {
105 tracing::warn!(
106 "auth: rate limit exceeded for source {} (limit={}/min)",
107 peer.ip(),
108 limit
109 );
110 return (
111 StatusCode::TOO_MANY_REQUESTS,
112 Json(ErrorResponse {
113 error: "rate limit exceeded".to_string(),
114 }),
115 )
116 .into_response();
117 }
118 next.run(request).await
119}
120
121pub async fn strip_body_for_logs(request: Request, next: Next) -> Response {
126 let span = tracing::info_span!("auth_request", body = "<redacted>");
127 let _enter = span.enter();
128 next.run(request).await
129}
130
131#[cfg(test)]
132mod tests;