Add websocket handler on server, connect from client
Additionally add /test handler that triggers server->client WS message
This commit is contained in:
@@ -17,13 +17,16 @@ anyhow = "1.0.79"
|
||||
async-graphql = { version = "7", features = ["log"] }
|
||||
async-graphql-axum = "7.0.15"
|
||||
async-trait = "0.1.81"
|
||||
axum = "0.8.1"
|
||||
axum = { version = "0.8.3", features = ["ws"] }
|
||||
axum-extra = { version = "0.10.1", features = ["typed-header"] }
|
||||
axum-macros = "0.5.0"
|
||||
build-info = "0.0.40"
|
||||
cacher = { version = "0.2.0", registry = "xinu" }
|
||||
chrono = "0.4.39"
|
||||
clap = { version = "4.5.23", features = ["derive"] }
|
||||
css-inline = "0.14.0"
|
||||
futures = "0.3.31"
|
||||
headers = "0.4.0"
|
||||
html-escape = "0.2.13"
|
||||
letterbox-notmuch = { version = "0.12.1", path = "../notmuch", registry = "xinu" }
|
||||
letterbox-shared = { version = "0.12.1", path = "../shared", registry = "xinu" }
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
// Rocket generates a lot of warnings for handlers
|
||||
// TODO: figure out why
|
||||
#![allow(unreachable_patterns)]
|
||||
use std::{error::Error, io::Cursor, str::FromStr};
|
||||
use std::{error::Error, io::Cursor, net::SocketAddr, str::FromStr, sync::Arc};
|
||||
|
||||
use async_graphql::{extensions, http::GraphiQLSource, Schema};
|
||||
use async_graphql_axum::{GraphQL, GraphQLSubscription};
|
||||
//allows to extract the IP of connecting user
|
||||
use axum::extract::connect_info::ConnectInfo;
|
||||
use axum::{
|
||||
extract::{ws::WebSocketUpgrade, State},
|
||||
response::{self, IntoResponse},
|
||||
routing::get,
|
||||
routing::{any, get},
|
||||
Router,
|
||||
};
|
||||
use axum_extra::TypedHeader;
|
||||
use cacher::FilesystemCacher;
|
||||
use letterbox_notmuch::{Notmuch, NotmuchError, ThreadSet};
|
||||
#[cfg(feature = "tantivy")]
|
||||
@@ -19,10 +23,13 @@ use letterbox_server::{
|
||||
error::ServerError,
|
||||
graphql::{Attachment, GraphqlSchema, MutationRoot, QueryRoot, SubscriptionRoot},
|
||||
nm::{attachment_bytes, cid_attachment_bytes},
|
||||
ws::ConnectionTracker,
|
||||
};
|
||||
use letterbox_shared::WebsocketMessage;
|
||||
use sqlx::postgres::PgPool;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tokio::{net::TcpListener, sync::Mutex};
|
||||
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
|
||||
use tracing::{error, info};
|
||||
|
||||
/*
|
||||
#[get("/show/<query>/pretty")]
|
||||
@@ -245,36 +252,85 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
||||
async fn graphiql() -> impl IntoResponse {
|
||||
response::Html(
|
||||
GraphiQLSource::build()
|
||||
.endpoint("/api/")
|
||||
.subscription_endpoint("/api/ws")
|
||||
.endpoint("/api/graphql/")
|
||||
.subscription_endpoint("/api/graphql/ws")
|
||||
.finish(),
|
||||
)
|
||||
}
|
||||
|
||||
async fn start_ws(
|
||||
ws: WebSocketUpgrade,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
State(connection_tracker): State<Arc<Mutex<ConnectionTracker>>>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(async move |socket| connection_tracker.lock().await.add_peer(socket, addr))
|
||||
}
|
||||
#[axum_macros::debug_handler]
|
||||
async fn test_handler(
|
||||
State(connection_tracker): State<Arc<Mutex<ConnectionTracker>>>,
|
||||
) -> impl IntoResponse {
|
||||
connection_tracker
|
||||
.lock()
|
||||
.await
|
||||
.send_message_all(WebsocketMessage::RefreshMessages)
|
||||
.await;
|
||||
"test triggered"
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _guard = xtracing::init(env!("CARGO_BIN_NAME"))?;
|
||||
build_info::build_info!(fn bi);
|
||||
info!("Build Info: {}", letterbox_shared::build_version(bi));
|
||||
// TODO: move these to config
|
||||
let port = 9345;
|
||||
let config = Config {
|
||||
newsreader_database_url: "postgres://newsreader@nixos-07.h.xinu.tv/newsreader".to_string(),
|
||||
newsreader_tantivy_db_path: "../target/database/newsreader".to_string(),
|
||||
slurp_cache_path: "/tmp/letterbox/slurp".to_string(),
|
||||
};
|
||||
if !std::fs::exists(&config.slurp_cache_path)? {
|
||||
info!("Creating slurp cache @ '{}'", &config.slurp_cache_path);
|
||||
std::fs::create_dir_all(&config.slurp_cache_path)?;
|
||||
}
|
||||
let pool = PgPool::connect(&config.newsreader_database_url).await?;
|
||||
sqlx::migrate!("./migrations").run(&pool).await?;
|
||||
#[cfg(feature = "tantivy")]
|
||||
let tantivy_conn = TantivyConnection::new(&config.newsreader_tantivy_db_path)?;
|
||||
|
||||
let cacher = FilesystemCacher::new(&config.slurp_cache_path)?;
|
||||
let schema = Schema::build(QueryRoot, MutationRoot, SubscriptionRoot)
|
||||
//.data(Storage::default())
|
||||
.finish();
|
||||
.data(Notmuch::default())
|
||||
.data(cacher)
|
||||
.data(pool.clone());
|
||||
|
||||
let schema = schema.extension(extensions::Logger).finish();
|
||||
|
||||
let conn_tracker = Arc::new(Mutex::new(ConnectionTracker::default()));
|
||||
|
||||
let app = Router::new()
|
||||
.route("/test", get(test_handler))
|
||||
.route("/api/ws", any(start_ws))
|
||||
.route_service("/api/graphql/ws", GraphQLSubscription::new(schema.clone()))
|
||||
.route(
|
||||
"/api/",
|
||||
"/api/graphql/",
|
||||
get(graphiql).post_service(GraphQL::new(schema.clone())),
|
||||
)
|
||||
.route_service("/api/ws", GraphQLSubscription::new(schema))
|
||||
.with_state(conn_tracker)
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.on_request(tower_http::trace::DefaultOnRequest::new().level(tracing::Level::INFO))
|
||||
.on_response(
|
||||
tower_http::trace::DefaultOnResponse::new().level(tracing::Level::INFO),
|
||||
)
|
||||
.on_failure(tower_http::trace::DefaultOnFailure::new().level(tracing::Level::WARN)),
|
||||
.make_span_with(DefaultMakeSpan::default().include_headers(true)),
|
||||
);
|
||||
|
||||
axum::serve(TcpListener::bind("0.0.0.0:9345").await.unwrap(), app)
|
||||
let listener = TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], port)))
|
||||
.await
|
||||
.unwrap();
|
||||
tracing::info!("listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ pub mod graphql;
|
||||
pub mod mail;
|
||||
pub mod newsreader;
|
||||
pub mod nm;
|
||||
pub mod ws;
|
||||
|
||||
#[cfg(feature = "tantivy")]
|
||||
pub mod tantivy;
|
||||
|
||||
|
||||
32
server/src/ws.rs
Normal file
32
server/src/ws.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use std::{collections::HashMap, net::SocketAddr};
|
||||
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use letterbox_shared::WebsocketMessage;
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ConnectionTracker {
|
||||
peers: HashMap<SocketAddr, WebSocket>,
|
||||
}
|
||||
|
||||
impl ConnectionTracker {
|
||||
pub fn add_peer(&mut self, socket: WebSocket, who: SocketAddr) {
|
||||
warn!("adding {who:?} to connection tracker");
|
||||
self.peers.insert(who, socket);
|
||||
}
|
||||
pub async fn send_message_all(&mut self, msg: WebsocketMessage) {
|
||||
let m = serde_json::to_string(&msg).expect("failed to json encode WebsocketMessage");
|
||||
let mut bad_peers = Vec::new();
|
||||
for (who, socket) in &mut self.peers.iter_mut() {
|
||||
if let Err(e) = socket.send(Message::Text(m.clone().into())).await {
|
||||
warn!("{:?} is bad, scheduling for removal: {e}", who);
|
||||
bad_peers.push(who.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for b in bad_peers {
|
||||
info!("removing bad peer {b:?}");
|
||||
self.peers.remove(&b);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user