Add websocket handler on server, connect from client

Additionally add /test handler that triggers server->client WS message
This commit is contained in:
2025-04-14 20:46:52 -07:00
parent b2c73ffa15
commit f2042f284e
9 changed files with 357 additions and 19 deletions

View File

@@ -1,15 +1,19 @@
// Rocket generates a lot of warnings for handlers
// TODO: figure out why
#![allow(unreachable_patterns)]
use std::{error::Error, io::Cursor, str::FromStr};
use std::{error::Error, io::Cursor, net::SocketAddr, str::FromStr, sync::Arc};
use async_graphql::{extensions, http::GraphiQLSource, Schema};
use async_graphql_axum::{GraphQL, GraphQLSubscription};
//allows to extract the IP of connecting user
use axum::extract::connect_info::ConnectInfo;
use axum::{
extract::{ws::WebSocketUpgrade, State},
response::{self, IntoResponse},
routing::get,
routing::{any, get},
Router,
};
use axum_extra::TypedHeader;
use cacher::FilesystemCacher;
use letterbox_notmuch::{Notmuch, NotmuchError, ThreadSet};
#[cfg(feature = "tantivy")]
@@ -19,10 +23,13 @@ use letterbox_server::{
error::ServerError,
graphql::{Attachment, GraphqlSchema, MutationRoot, QueryRoot, SubscriptionRoot},
nm::{attachment_bytes, cid_attachment_bytes},
ws::ConnectionTracker,
};
use letterbox_shared::WebsocketMessage;
use sqlx::postgres::PgPool;
use tokio::net::TcpListener;
use tower_http::trace::TraceLayer;
use tokio::{net::TcpListener, sync::Mutex};
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
use tracing::{error, info};
/*
#[get("/show/<query>/pretty")]
@@ -245,36 +252,85 @@ async fn main() -> Result<(), Box<dyn Error>> {
async fn graphiql() -> impl IntoResponse {
response::Html(
GraphiQLSource::build()
.endpoint("/api/")
.subscription_endpoint("/api/ws")
.endpoint("/api/graphql/")
.subscription_endpoint("/api/graphql/ws")
.finish(),
)
}
async fn start_ws(
ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(connection_tracker): State<Arc<Mutex<ConnectionTracker>>>,
) -> impl IntoResponse {
ws.on_upgrade(async move |socket| connection_tracker.lock().await.add_peer(socket, addr))
}
#[axum_macros::debug_handler]
async fn test_handler(
State(connection_tracker): State<Arc<Mutex<ConnectionTracker>>>,
) -> impl IntoResponse {
connection_tracker
.lock()
.await
.send_message_all(WebsocketMessage::RefreshMessages)
.await;
"test triggered"
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _guard = xtracing::init(env!("CARGO_BIN_NAME"))?;
build_info::build_info!(fn bi);
info!("Build Info: {}", letterbox_shared::build_version(bi));
// TODO: move these to config
let port = 9345;
let config = Config {
newsreader_database_url: "postgres://newsreader@nixos-07.h.xinu.tv/newsreader".to_string(),
newsreader_tantivy_db_path: "../target/database/newsreader".to_string(),
slurp_cache_path: "/tmp/letterbox/slurp".to_string(),
};
if !std::fs::exists(&config.slurp_cache_path)? {
info!("Creating slurp cache @ '{}'", &config.slurp_cache_path);
std::fs::create_dir_all(&config.slurp_cache_path)?;
}
let pool = PgPool::connect(&config.newsreader_database_url).await?;
sqlx::migrate!("./migrations").run(&pool).await?;
#[cfg(feature = "tantivy")]
let tantivy_conn = TantivyConnection::new(&config.newsreader_tantivy_db_path)?;
let cacher = FilesystemCacher::new(&config.slurp_cache_path)?;
let schema = Schema::build(QueryRoot, MutationRoot, SubscriptionRoot)
//.data(Storage::default())
.finish();
.data(Notmuch::default())
.data(cacher)
.data(pool.clone());
let schema = schema.extension(extensions::Logger).finish();
let conn_tracker = Arc::new(Mutex::new(ConnectionTracker::default()));
let app = Router::new()
.route("/test", get(test_handler))
.route("/api/ws", any(start_ws))
.route_service("/api/graphql/ws", GraphQLSubscription::new(schema.clone()))
.route(
"/api/",
"/api/graphql/",
get(graphiql).post_service(GraphQL::new(schema.clone())),
)
.route_service("/api/ws", GraphQLSubscription::new(schema))
.with_state(conn_tracker)
.layer(
TraceLayer::new_for_http()
.on_request(tower_http::trace::DefaultOnRequest::new().level(tracing::Level::INFO))
.on_response(
tower_http::trace::DefaultOnResponse::new().level(tracing::Level::INFO),
)
.on_failure(tower_http::trace::DefaultOnFailure::new().level(tracing::Level::WARN)),
.make_span_with(DefaultMakeSpan::default().include_headers(true)),
);
axum::serve(TcpListener::bind("0.0.0.0:9345").await.unwrap(), app)
let listener = TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], port)))
.await
.unwrap();
tracing::info!("listening on {}", listener.local_addr().unwrap());
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.unwrap();
Ok(())
}

View File

@@ -4,6 +4,8 @@ pub mod graphql;
pub mod mail;
pub mod newsreader;
pub mod nm;
pub mod ws;
#[cfg(feature = "tantivy")]
pub mod tantivy;

32
server/src/ws.rs Normal file
View File

@@ -0,0 +1,32 @@
use std::{collections::HashMap, net::SocketAddr};
use axum::extract::ws::{Message, WebSocket};
use letterbox_shared::WebsocketMessage;
use tracing::{info, warn};
#[derive(Default)]
pub struct ConnectionTracker {
peers: HashMap<SocketAddr, WebSocket>,
}
impl ConnectionTracker {
pub fn add_peer(&mut self, socket: WebSocket, who: SocketAddr) {
warn!("adding {who:?} to connection tracker");
self.peers.insert(who, socket);
}
pub async fn send_message_all(&mut self, msg: WebsocketMessage) {
let m = serde_json::to_string(&msg).expect("failed to json encode WebsocketMessage");
let mut bad_peers = Vec::new();
for (who, socket) in &mut self.peers.iter_mut() {
if let Err(e) = socket.send(Message::Text(m.clone().into())).await {
warn!("{:?} is bad, scheduling for removal: {e}", who);
bad_peers.push(who.clone());
}
}
for b in bad_peers {
info!("removing bad peer {b:?}");
self.peers.remove(&b);
}
}
}