1use axum::{
4 Json,
5 extract::{
6 Path, Query,
7 ws::{Message, WebSocket, WebSocketUpgrade},
8 },
9 http::StatusCode,
10 response::IntoResponse,
11};
12use futures::StreamExt;
13use manta_backend_dispatcher::{
14 interfaces::console::ConsoleTrait,
15 types::{K8sAuth, K8sDetails},
16};
17use serde::Deserialize;
18use tokio::io::AsyncWriteExt;
19use utoipa::IntoParams;
20
21use super::{
22 ErrorResponse, RequestCtx, SiteHeader, require_k8s_url, require_vault,
23 to_handler_error,
24};
25use crate::service;
26
#[derive(Deserialize, IntoParams)]
/// Query-string parameters accepted by the console WebSocket endpoints.
///
/// Both fields are optional in the request: when one is absent, `serde` calls the
/// matching `default_*` function to fill it in.
pub struct ConsoleQuery {
    /// Requested console width in columns; `default_cols()` supplies the value when omitted.
    #[serde(default = "default_cols")]
    pub cols: u16,
    /// Requested console height in rows; `default_rows()` supplies the value when omitted.
    #[serde(default = "default_rows")]
    pub rows: u16,
}
41
/// Column count used for [`ConsoleQuery::cols`] when the request does not supply one.
fn default_cols() -> u16 {
    const FALLBACK_COLS: u16 = 80;
    FALLBACK_COLS
}
/// Row count used for [`ConsoleQuery::rows`] when the request does not supply one.
fn default_rows() -> u16 {
    const FALLBACK_ROWS: u16 = 24;
    FALLBACK_ROWS
}
48
49#[utoipa::path(get, path = "/nodes/{xname}/console", tag = "console",
51 params(("xname" = String, Path, description = "Node xname"), ConsoleQuery, SiteHeader),
52 security(("bearerAuth" = [])),
53 responses(
54 (status = 101, description = "WebSocket upgrade"),
55 (status = 401, description = "Unauthorized", body = ErrorResponse),
56 (status = 500, description = "Internal error", body = ErrorResponse),
57 (status = 501, description = "Vault or k8s not configured", body = ErrorResponse),
58 )
59)]
60#[tracing::instrument(skip_all, fields(xname = %xname))]
61pub async fn console_node_ws(
62 ctx: RequestCtx,
63 Path(xname): Path<String>,
64 Query(q): Query<ConsoleQuery>,
65 ws: WebSocketUpgrade,
66) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
67 let (k8s_api_url, vault_base_url, timeout) = {
69 let infra = ctx.infra();
70 let k = require_k8s_url(infra.k8s_api_url)?.to_string();
71 let v = require_vault(infra.vault_base_url)?.to_string();
72 (k, v, ctx.state.console_inactivity_timeout)
73 };
74
75 let k8s = K8sDetails {
76 api_url: k8s_api_url,
77 authentication: K8sAuth::Vault {
78 base_url: vault_base_url,
79 },
80 };
81
82 let RequestCtx {
85 state,
86 token,
87 site_name,
88 } = ctx;
89
90 Ok(ws.on_upgrade(move |socket| async move {
91 tracing::info!("WebSocket console opened for node {xname}");
92 if let Some(site) = state.sites.get(&site_name) {
93 match site
94 .backend
95 .attach_to_node_console(
96 &token, &site_name, &xname, q.cols, q.rows, &k8s,
97 )
98 .await
99 {
100 Ok((console_in, console_out)) => {
101 run_console_bridge(socket, console_in, console_out, timeout).await;
102 tracing::info!("WebSocket console closed for node {xname}");
103 }
104 Err(e) => {
105 tracing::error!("Failed to attach to node console {xname}: {e:#}");
106 }
107 }
108 }
109 }))
110}
111
112#[utoipa::path(get, path = "/sessions/{name}/console", tag = "console",
118 params(("name" = String, Path, description = "Session name"), ConsoleQuery, SiteHeader),
119 security(("bearerAuth" = [])),
120 responses(
121 (status = 101, description = "WebSocket upgrade"),
122 (status = 401, description = "Unauthorized", body = ErrorResponse),
123 (status = 500, description = "Internal error", body = ErrorResponse),
124 (status = 501, description = "Vault or k8s not configured", body = ErrorResponse),
125 )
126)]
127#[tracing::instrument(skip_all, fields(session = %name))]
128pub async fn console_session_ws(
129 ctx: RequestCtx,
130 Path(name): Path<String>,
131 Query(q): Query<ConsoleQuery>,
132 ws: WebSocketUpgrade,
133) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
134 let (k8s_api_url, vault_base_url, timeout) = {
137 let infra = ctx.infra();
138 let k = require_k8s_url(infra.k8s_api_url)?.to_string();
139 let v = require_vault(infra.vault_base_url)?.to_string();
140 service::session::validate_console_session(&infra, &ctx.token, &name)
141 .await
142 .map_err(to_handler_error)?;
143 (k, v, ctx.state.console_inactivity_timeout)
144 };
145
146 let k8s = K8sDetails {
147 api_url: k8s_api_url,
148 authentication: K8sAuth::Vault {
149 base_url: vault_base_url,
150 },
151 };
152
153 let RequestCtx {
155 state,
156 token,
157 site_name,
158 } = ctx;
159
160 Ok(ws.on_upgrade(move |socket| async move {
161 tracing::info!("WebSocket console opened for session {name}");
162 if let Some(site) = state.sites.get(&site_name) {
163 match site
164 .backend
165 .attach_to_session_console(
166 &token, &site_name, &name, q.cols, q.rows, &k8s,
167 )
168 .await
169 {
170 Ok((console_in, console_out)) => {
171 run_console_bridge(socket, console_in, console_out, timeout).await;
172 tracing::info!("WebSocket console closed for session {name}");
173 }
174 Err(e) => {
175 tracing::error!("Failed to attach to session console {name}: {e:#}");
176 }
177 }
178 }
179 }))
180}
181
// `async fn` in a trait does not let callers add bounds (such as `Send`) to the returned
// futures. This trait is not `pub` and, in the code visible here, is only used as the
// generic bound of `run_console_bridge`, so the plain `async fn` form is kept.
#[allow(async_fn_in_trait)]
/// The two WebSocket operations the console bridge needs: receive one message, send one.
///
/// Abstracting over them lets `run_console_bridge` be driven by a real
/// [`WebSocket`] in production and by an in-memory stand-in in the unit tests.
trait ConsoleSocket: Send + Unpin {
    /// Waits for the next incoming message; `None` means the peer is gone.
    async fn recv(&mut self) -> Option<Result<Message, axum::Error>>;
    /// Sends one message to the peer.
    async fn send(&mut self, msg: Message) -> Result<(), axum::Error>;
}
192
/// Production implementation: forwards straight to axum's [`WebSocket`].
///
/// The calls are written as `WebSocket::recv(self)` / `WebSocket::send(self, ..)` because
/// the inherent methods share their names with the trait methods; the path form makes it
/// explicit that the inherent method is the one being invoked.
impl ConsoleSocket for WebSocket {
    async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
        WebSocket::recv(self).await
    }
    async fn send(&mut self, msg: Message) -> Result<(), axum::Error> {
        WebSocket::send(self, msg).await
    }
}
201
202async fn run_console_bridge<S: ConsoleSocket>(
212 mut socket: S,
213 mut console_in: Box<dyn tokio::io::AsyncWrite + Unpin + Send>,
214 console_out: Box<dyn tokio::io::AsyncRead + Unpin + Send>,
215 inactivity_timeout: std::time::Duration,
216) {
217 let mut out_stream = tokio_util::io::ReaderStream::new(console_out);
218 let mut deadline = tokio::time::Instant::now() + inactivity_timeout;
219
220 loop {
221 tokio::select! {
222 msg = socket.recv() => {
223 match msg {
224 Some(Ok(Message::Binary(data))) => {
225 deadline = tokio::time::Instant::now() + inactivity_timeout;
226 if console_in.write_all(&data).await.is_err() { break; }
227 }
228 Some(Ok(Message::Text(text))) => {
229 deadline = tokio::time::Instant::now() + inactivity_timeout;
230 if let Ok(v) = serde_json::from_str::<serde_json::Value>(&text)
232 && v.get("type").and_then(|t| t.as_str()) == Some("resize")
233 {
234 continue;
235 }
236 if console_in.write_all(text.as_bytes()).await.is_err() { break; }
237 }
238 Some(Ok(Message::Close(_))) | None => break,
239 Some(Ok(_)) => {} Some(Err(_)) => break,
241 }
242 }
243 chunk = out_stream.next() => {
244 match chunk {
245 Some(Ok(data)) => {
246 if socket.send(Message::Binary(data)).await.is_err() { break; }
247 }
248 Some(Err(_)) | None => break,
249 }
250 }
251 _ = tokio::time::sleep_until(deadline) => {
252 tracing::warn!("Console session idle for {:?}, closing", inactivity_timeout);
253 break;
254 }
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    //! Unit tests for `run_console_bridge`.
    //!
    //! The bridge is driven through an in-memory [`MockSocket`] and stand-in console
    //! streams. Every test starts with the Tokio clock paused, so the multi-second waits
    //! below complete without real delay.

    use super::*;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};
    use std::task::{Context, Poll};
    use std::time::Duration;
    use tokio::sync::mpsc;

    /// Console-input stand-in that appends every written byte to a shared buffer so a
    /// test can inspect what the bridge forwarded.
    struct CaptureWriter(Arc<Mutex<Vec<u8>>>);

    impl tokio::io::AsyncWrite for CaptureWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }
        fn poll_flush(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
        fn poll_shutdown(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Console-output stand-in that never yields data and never reports end-of-stream.
    struct PendingReader;

    impl tokio::io::AsyncRead for PendingReader {
        fn poll_read(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            _: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Channel-backed [`ConsoleSocket`]: the test pushes "client" frames into `rx` and
    /// reads what the bridge sent from the receiver paired with `tx`.
    struct MockSocket {
        rx: mpsc::UnboundedReceiver<Result<Message, axum::Error>>,
        tx: mpsc::UnboundedSender<Message>,
    }

    impl ConsoleSocket for MockSocket {
        async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
            self.rx.recv().await
        }
        async fn send(&mut self, msg: Message) -> Result<(), axum::Error> {
            self.tx.send(msg).map_err(axum::Error::new)
        }
    }

    /// Builds a [`MockSocket`] plus the test-side ends of its two channels:
    /// a sender for client-to-bridge frames and a receiver for bridge-to-client frames.
    ///
    /// Dropping the returned sender makes `recv` yield `None`, which the bridge treats as
    /// a disconnect, so tests keep it bound even when they never send.
    fn new_mock_socket() -> (
        MockSocket,
        mpsc::UnboundedSender<Result<Message, axum::Error>>,
        mpsc::UnboundedReceiver<Message>,
    ) {
        let (in_tx, in_rx) = mpsc::unbounded_channel();
        let (out_tx, out_rx) = mpsc::unbounded_channel();
        (
            MockSocket {
                rx: in_rx,
                tx: out_tx,
            },
            in_tx,
            out_rx,
        )
    }

    /// Returns `true` if the bridge task finishes before `cap` of (virtual) time passes.
    ///
    /// Takes the handle by `&mut` so the same task can be probed several times.
    async fn bridge_exited_within(
        handle: &mut tokio::task::JoinHandle<()>,
        cap: Duration,
    ) -> bool {
        tokio::select! {
            _ = handle => true,
            _ = tokio::time::sleep(cap) => false,
        }
    }

    #[tokio::test(start_paused = true)]
    async fn inactivity_timeout_fires_when_no_traffic() {
        let (socket, _in_tx, _out_rx) = new_mock_socket();
        let console_in = Box::new(tokio::io::sink());
        let console_out = Box::new(PendingReader);

        let mut handle = tokio::spawn(async move {
            run_console_bridge(
                socket,
                console_in,
                console_out,
                Duration::from_secs(60),
            )
            .await
        });

        // Still running at t=59s ...
        assert!(
            !bridge_exited_within(&mut handle, Duration::from_secs(59)).await,
            "bridge exited before the 60s inactivity timeout"
        );
        // ... and gone by t=64s.
        assert!(
            bridge_exited_within(&mut handle, Duration::from_secs(5)).await,
            "bridge did not exit after the inactivity timeout"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn client_binary_message_resets_deadline() {
        let (socket, in_tx, _out_rx) = new_mock_socket();
        let console_in = Box::new(tokio::io::sink());
        let console_out = Box::new(PendingReader);

        let mut handle = tokio::spawn(async move {
            run_console_bridge(
                socket,
                console_in,
                console_out,
                Duration::from_secs(60),
            )
            .await
        });

        // Send input one second before the original deadline.
        tokio::time::sleep(Duration::from_secs(59)).await;
        in_tx
            .send(Ok(Message::Binary(b"hi".to_vec().into())))
            .unwrap();
        // Let the bridge process the frame before virtual time moves on.
        tokio::task::yield_now().await;

        // New deadline is t=119s: still running at t=90s, gone by t=125s.
        assert!(
            !bridge_exited_within(&mut handle, Duration::from_secs(31)).await,
            "deadline was not reset by client binary message"
        );
        assert!(
            bridge_exited_within(&mut handle, Duration::from_secs(35)).await,
            "bridge did not exit after the reset deadline"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn resize_text_resets_deadline_but_is_not_forwarded() {
        let (socket, in_tx, _out_rx) = new_mock_socket();
        let written: Arc<Mutex<Vec<u8>>> = Default::default();
        let console_in = Box::new(CaptureWriter(written.clone()));
        let console_out = Box::new(PendingReader);

        let mut handle = tokio::spawn(async move {
            run_console_bridge(
                socket,
                console_in,
                console_out,
                Duration::from_secs(60),
            )
            .await
        });

        tokio::time::sleep(Duration::from_secs(59)).await;
        in_tx
            .send(Ok(Message::Text(
                r#"{"type":"resize","cols":120,"rows":40}"#.into(),
            )))
            .unwrap();
        // Let the bridge process the frame before virtual time moves on.
        tokio::task::yield_now().await;

        assert!(
            !bridge_exited_within(&mut handle, Duration::from_secs(30)).await,
            "deadline was not reset by resize message"
        );
        assert!(
            written.lock().unwrap().is_empty(),
            "resize text frame was forwarded to console stdin (should be consumed)"
        );

        handle.abort();
    }

    #[tokio::test(start_paused = true)]
    async fn client_close_exits_loop_immediately() {
        let (socket, in_tx, _out_rx) = new_mock_socket();
        let console_in = Box::new(tokio::io::sink());
        let console_out = Box::new(PendingReader);

        // A one-hour timeout ensures an exit here can only be caused by the Close frame.
        let mut handle = tokio::spawn(async move {
            run_console_bridge(
                socket,
                console_in,
                console_out,
                Duration::from_secs(3600),
            )
            .await
        });

        in_tx.send(Ok(Message::Close(None))).unwrap();
        assert!(
            bridge_exited_within(&mut handle, Duration::from_secs(1)).await,
            "bridge did not exit on Close frame"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn server_to_client_data_does_not_reset_deadline() {
        let (socket, _in_tx, mut out_rx) = new_mock_socket();
        let console_in = Box::new(tokio::io::sink());
        // An in-memory pipe lets the test choose *when* console output arrives. Output
        // delivered at t=0 could not distinguish "reset" from "not reset", so it is sent
        // mid-way through the timeout instead. `console_feed` stays alive until the end
        // of the test so the bridge never sees end-of-stream.
        let (mut console_feed, console_pipe) = tokio::io::duplex(64);
        let console_out = Box::new(console_pipe);

        let mut handle = tokio::spawn(async move {
            run_console_bridge(
                socket,
                console_in,
                console_out,
                Duration::from_secs(60),
            )
            .await
        });

        // Console produces output at t=30s, and it must reach the client.
        tokio::time::sleep(Duration::from_secs(30)).await;
        console_feed.write_all(b"chunk").await.unwrap();
        let forwarded = out_rx.recv().await;
        assert!(
            matches!(forwarded, Some(Message::Binary(ref payload)) if payload[..] == b"chunk"[..]),
            "console output was not forwarded to the client"
        );

        // The deadline is still t=60s: running at t=59s, gone by t=64s. Had the output
        // reset the deadline it would be t=90s and the second assertion would fail.
        assert!(
            !bridge_exited_within(&mut handle, Duration::from_secs(29)).await,
            "bridge exited before the original 60s deadline"
        );
        assert!(
            bridge_exited_within(&mut handle, Duration::from_secs(5)).await,
            "server-to-client data should NOT keep the deadline alive"
        );
    }
}