manta_server/server/
mod.rs1pub mod api_doc;
5pub mod auth_middleware;
6pub mod common;
7pub mod handlers;
8pub mod routes;
9
10use std::collections::HashMap;
11use std::net::SocketAddr;
12use std::sync::Arc;
13
14use axum_server::tls_rustls::RustlsConfig;
15use manta_backend_dispatcher::error::Error;
16use std::time::Duration;
17
18use crate::dispatcher::StaticBackendDispatcher;
19use crate::server::common::app_context::InfraContext;
20use crate::server::common::kafka::Kafka;
21
22pub struct SiteBackend {
26 pub backend: StaticBackendDispatcher,
28 pub shasta_base_url: String,
30 pub shasta_root_cert: Vec<u8>,
32 pub socks5_proxy: Option<String>,
34 pub vault_base_url: Option<String>,
36 pub gitea_base_url: String,
38 pub k8s_api_url: Option<String>,
40}
41
42pub struct ServerState {
49 pub sites: HashMap<String, SiteBackend>,
51 pub console_inactivity_timeout: Duration,
54 pub auditor: Option<Kafka>,
57 pub auth_rate_limit_per_minute: Option<u32>,
60 pub request_timeout: Duration,
65 pub shutdown_grace_period: Duration,
69 pub migrate_backup_root: Option<std::path::PathBuf>,
75}
76
77impl ServerState {
78 pub fn infra_context<'a>(
84 &'a self,
85 site_name: &'a str,
86 ) -> Result<InfraContext<'a>, Error> {
87 let site = self.sites.get(site_name).ok_or_else(|| {
88 Error::NotFound(format!("site '{site_name}' not found"))
89 })?;
90 Ok(InfraContext {
91 backend: &site.backend,
92 site_name,
93 shasta_base_url: &site.shasta_base_url,
94 shasta_root_cert: &site.shasta_root_cert,
95 socks5_proxy: site.socks5_proxy.as_deref(),
96 vault_base_url: site.vault_base_url.as_deref(),
97 gitea_base_url: &site.gitea_base_url,
98 k8s_api_url: site.k8s_api_url.as_deref(),
99 })
100 }
101}
102
103async fn log_requests(
104 request: axum::extract::Request,
105 next: axum::middleware::Next,
106) -> axum::response::Response {
107 let method = request.method().clone();
108 let uri = request.uri().clone();
109 let response = next.run(request).await;
110 tracing::info!("{} {} → {}", method, uri, response.status());
111 response
112}
113
114pub async fn start_server(
119 state: Arc<ServerState>,
120 listen_addr: &str,
121 port: u16,
122 cert_path: Option<&str>,
123 key_path: Option<&str>,
124) -> Result<(), Error> {
125 let shutdown_grace_period = state.shutdown_grace_period;
127
128 let app =
133 routes::build_router(state).layer(axum::middleware::from_fn(log_requests));
134
135 let addr: SocketAddr = format!("{listen_addr}:{port}")
136 .parse()
137 .map_err(|e| Error::BadRequest(format!("Invalid listen address: {e}")))?;
138
139 match (cert_path, key_path) {
140 (Some(cert), Some(key)) => {
141 let tls_config = RustlsConfig::from_pem_file(cert, key).await?;
142 let handle = axum_server::Handle::new();
143 let ready_handle = handle.clone();
144 tokio::spawn(async move {
145 ready_handle.listening().await;
146 tracing::info!(
147 "HTTPS server ready, accepting requests on https://{}",
148 addr
149 );
150 eprintln!("HTTPS server ready, accepting requests on https://{addr}");
151 });
152 install_shutdown_handler(handle.clone(), shutdown_grace_period);
153 axum_server::bind_rustls(addr, tls_config)
154 .handle(handle)
155 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
156 .await?;
157 }
158 (None, None) => {
159 let handle = axum_server::Handle::new();
160 let ready_handle = handle.clone();
161 tokio::spawn(async move {
162 ready_handle.listening().await;
163 tracing::info!(
164 "HTTP server ready, accepting requests on http://{}",
165 addr
166 );
167 eprintln!("HTTP server ready, accepting requests on http://{addr}");
168 });
169 install_shutdown_handler(handle.clone(), shutdown_grace_period);
170 axum_server::bind(addr)
171 .handle(handle)
172 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
173 .await?;
174 }
175 _ => {
176 return Err(Error::BadRequest(
177 "--cert and --key must be provided together".to_string(),
178 ));
179 }
180 }
181
182 Ok(())
183}
184
185fn install_shutdown_handler(
195 handle: axum_server::Handle<SocketAddr>,
196 grace_period: Duration,
197) {
198 tokio::spawn(async move {
199 let mut sigterm = match tokio::signal::unix::signal(
200 tokio::signal::unix::SignalKind::terminate(),
201 ) {
202 Ok(s) => s,
203 Err(e) => {
204 tracing::warn!(
205 "failed to install SIGTERM handler; falling back to Ctrl+C only: {e}"
206 );
207 let _ = tokio::signal::ctrl_c().await;
208 handle.graceful_shutdown(Some(grace_period));
209 return;
210 }
211 };
212 let grace_secs = grace_period.as_secs();
213 tokio::select! {
214 _ = sigterm.recv() => {
215 tracing::info!("SIGTERM received; draining for up to {grace_secs}s");
216 }
217 _ = tokio::signal::ctrl_c() => {
218 tracing::info!("Ctrl+C received; draining for up to {grace_secs}s");
219 }
220 }
221 handle.graceful_shutdown(Some(grace_period));
222 });
223}
224
#[cfg(test)]
mod timeout_layer_tests {
    //! Pins how `tower_http`'s `TimeoutLayer` behaves when configured to answer
    //! `408 Request Timeout`.

    use std::time::Duration;

    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        routing::get,
    };
    use tower::ServiceExt as _;
    use tower_http::timeout::TimeoutLayer;

    /// Builds an empty-bodied `GET` request for `path`.
    fn get_request(path: &str) -> Request<Body> {
        Request::get(path).body(Body::empty()).unwrap()
    }

    /// Handler body: waits for `delay`, then answers `"ok"`.
    async fn respond_after(delay: Duration) -> &'static str {
        tokio::time::sleep(delay).await;
        "ok"
    }

    /// A single-route router whose handler takes `handler_delay`. It is wrapped in a
    /// `TimeoutLayer` that answers `408 Request Timeout` once `limit` elapses.
    fn timed_router(path: &str, handler_delay: Duration, limit: Duration) -> Router {
        let timeout = TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, limit);
        Router::new()
            .route(path, get(move || respond_after(handler_delay)))
            .layer(timeout)
    }

    #[tokio::test]
    async fn global_timeout_returns_408_when_handler_exceeds_limit() {
        let router = timed_router(
            "/slow",
            Duration::from_millis(400),
            Duration::from_millis(50),
        );

        let response = router.oneshot(get_request("/slow")).await.unwrap();
        assert_eq!(response.status(), StatusCode::REQUEST_TIMEOUT);
    }

    #[tokio::test]
    async fn fast_handler_finishes_before_timeout_fires() {
        let router = timed_router(
            "/fast",
            Duration::from_millis(10),
            Duration::from_secs(5),
        );

        let response = router.oneshot(get_request("/fast")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}