noisebell/remote/discord-bot/src/main.rs

264 lines
8.9 KiB
Rust

use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use axum::extract::State as AxumState;
use axum::http::{HeaderMap, StatusCode};
use axum::routing::{get, post};
use axum::{Json, Router};
use noisebell_common::{validate_bearer, CacheStatusResponse, DoorStatus, WebhookPayload};
use serenity::all::{
ChannelId, Colour, CommandInteraction, CreateCommand, CreateEmbed, CreateInteractionResponse,
CreateInteractionResponseMessage, CreateMessage, GatewayIntents, Interaction,
};
use serenity::async_trait;
use tower_http::trace::TraceLayer;
use tracing::{error, info, warn, Level};
/// State shared by the `/webhook` HTTP handler and the Discord slash-command handler.
struct AppState {
    // --- Webhook authentication and Discord delivery ---
    /// Secret that `validate_bearer` checks incoming `/webhook` requests against.
    webhook_secret: String,
    /// Channel that webhook-triggered embeds are sent to.
    channel_id: ChannelId,
    /// Discord REST client used to send those embeds.
    http: Arc<serenity::all::Http>,

    // --- Status lookup and presentation ---
    /// Base URL of the cache service; `/status` is appended when querying it.
    cache_url: String,
    /// HTTP client used to query the cache service.
    client: reqwest::Client,
    /// Base URL under which the open/closed/offline thumbnail images live.
    image_base_url: String,
}
/// Build the door-status embed posted to Discord.
///
/// Each [`DoorStatus`] variant gets its own colour, title, description and thumbnail
/// (`{image_base_url}/open.png`, `closed.png` or `offline.png`). `timestamp` is a Unix time in
/// seconds. It is rendered both in the "Since" field and as the embed's own timestamp.
///
/// The embed timestamp falls back to the current time when `timestamp` cannot be represented
/// as a Discord timestamp. This includes values above `i64::MAX`, which a plain `as` cast
/// would have wrapped into a negative, pre-1970 date.
fn build_embed(status: DoorStatus, timestamp: u64, image_base_url: &str) -> CreateEmbed {
    let (colour, title, description, image_file) = match status {
        DoorStatus::Open => (
            Colour::from_rgb(0, 255, 0),
            "Noisebridge is Open!",
            "It's time to start hacking.",
            "open.png",
        ),
        DoorStatus::Closed => (
            Colour::from_rgb(255, 0, 0),
            "Noisebridge is Closed!",
            "We'll see you again soon.",
            "closed.png",
        ),
        DoorStatus::Offline => (
            Colour::from_rgb(153, 170, 181),
            "Noisebridge is Offline",
            "The Noisebridge Pi is not responding.",
            "offline.png",
        ),
    };
    let image_url = format!("{image_base_url}/{image_file}");

    // Checked conversion: out-of-range or unrepresentable values fall back to "now"
    // instead of being wrapped into a bogus date.
    let embed_timestamp = i64::try_from(timestamp)
        .ok()
        .and_then(|secs| serenity::model::Timestamp::from_unix_timestamp(secs).ok())
        .unwrap_or_else(serenity::model::Timestamp::now);

    CreateEmbed::new()
        .title(title)
        .description(description)
        .colour(colour)
        .field("Since", format_timestamp(timestamp), true)
        .thumbnail(image_url)
        .timestamp(embed_timestamp)
}
/// `POST /webhook`: forward a door-status change to the configured Discord channel.
///
/// Responds with:
/// - `401 Unauthorized` when `validate_bearer` rejects the request headers against the
///   configured webhook secret.
/// - `500 Internal Server Error` when sending the embed to Discord fails.
/// - `200 OK` otherwise.
async fn post_webhook(
    AxumState(state): AxumState<Arc<AppState>>,
    headers: HeaderMap,
    Json(body): Json<WebhookPayload>,
) -> StatusCode {
    // Authenticate before doing any Discord work.
    if !validate_bearer(&headers, &state.webhook_secret) {
        return StatusCode::UNAUTHORIZED;
    }

    info!(status = %body.status, timestamp = body.timestamp, "received webhook");

    let message = CreateMessage::new().embed(build_embed(
        body.status,
        body.timestamp,
        &state.image_base_url,
    ));

    let outcome = state.channel_id.send_message(&state.http, message).await;
    if let Err(e) = outcome {
        error!(error = %e, "failed to send embed to Discord");
        return StatusCode::INTERNAL_SERVER_ERROR;
    }

    info!(status = %body.status, "embed sent to Discord");
    StatusCode::OK
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` instead of panicking if the system clock reports a time before the epoch.
/// This function is called from the slash-command path, where a panic would leave the
/// interaction unanswered.
fn unix_now() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or_default()
}
/// Render a Unix timestamp (seconds) as Discord timestamp markup in relative style
/// (`R`), e.g. `<t:0:R>`.
fn format_timestamp(ts: u64) -> String {
    format!("<t:{ts}:R>")
}
async fn handle_status(
state: &AppState,
_command: &CommandInteraction,
) -> CreateInteractionResponse {
let url = format!("{}/status", state.cache_url);
let resp = state.client.get(&url).send().await;
let embed = match resp {
Ok(resp) if resp.status().is_success() => match resp.json::<CacheStatusResponse>().await {
Ok(data) => {
build_embed(data.status, data.since.unwrap_or(unix_now()), &state.image_base_url)
}
Err(e) => {
error!(error = %e, "failed to parse status response");
CreateEmbed::new()
.title("Error")
.description("Failed to parse status response.")
.colour(Colour::from_rgb(255, 0, 0))
}
},
_ => CreateEmbed::new()
.title("Error")
.description("Failed to reach the cache service.")
.colour(Colour::from_rgb(255, 0, 0)),
};
CreateInteractionResponse::Message(CreateInteractionResponseMessage::new().embed(embed))
}
struct Handler {
state: Arc<AppState>,
}
#[async_trait]
impl serenity::all::EventHandler for Handler {
async fn ready(&self, ctx: serenity::all::Context, ready: serenity::model::gateway::Ready) {
info!(user = %ready.user.name, "Discord bot connected");
let commands =
vec![CreateCommand::new("status").description("Show the current door status")];
if let Err(e) = serenity::all::Command::set_global_commands(&ctx.http, commands).await {
error!(error = %e, "failed to register slash commands");
} else {
info!("slash commands registered");
}
}
async fn interaction_create(&self, ctx: serenity::all::Context, interaction: Interaction) {
if let Interaction::Command(command) = interaction {
let response = match command.data.name.as_str() {
"status" => handle_status(&self.state, &command).await,
_ => CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new().content("Unknown command."),
),
};
if let Err(e) = command.create_response(&ctx.http, response).await {
error!(error = %e, command = %command.data.name, "failed to respond to slash command");
}
}
}
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.init();
let discord_token =
std::env::var("NOISEBELL_DISCORD_TOKEN").context("NOISEBELL_DISCORD_TOKEN is required")?;
let channel_id: u64 = std::env::var("NOISEBELL_DISCORD_CHANNEL_ID")
.context("NOISEBELL_DISCORD_CHANNEL_ID is required")?
.parse()
.context("NOISEBELL_DISCORD_CHANNEL_ID must be a valid u64")?;
let webhook_secret = std::env::var("NOISEBELL_DISCORD_WEBHOOK_SECRET")
.context("NOISEBELL_DISCORD_WEBHOOK_SECRET is required")?;
let port: u16 = std::env::var("NOISEBELL_DISCORD_PORT")
.unwrap_or_else(|_| "3001".into())
.parse()
.context("NOISEBELL_DISCORD_PORT must be a valid u16")?;
let image_base_url = std::env::var("NOISEBELL_DISCORD_IMAGE_BASE_URL")
.unwrap_or_else(|_| "https://noisebell.extremist.software/image".into())
.trim_end_matches('/')
.to_string();
let cache_url = std::env::var("NOISEBELL_DISCORD_CACHE_URL")
.context("NOISEBELL_DISCORD_CACHE_URL is required")?
.trim_end_matches('/')
.to_string();
info!(port, channel_id, "starting noisebell-discord");
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.context("failed to build HTTP client")?;
let intents = GatewayIntents::empty();
let mut discord_client = serenity::Client::builder(&discord_token, intents)
.event_handler_arc(Arc::new(Handler {
state: Arc::new(AppState {
http: Arc::new(serenity::all::Http::new(&discord_token)),
channel_id: ChannelId::new(channel_id),
webhook_secret: webhook_secret.clone(),
image_base_url: image_base_url.clone(),
cache_url: cache_url.clone(),
client: client.clone(),
}),
}))
.await
.context("failed to create Discord client")?;
let http = discord_client.http.clone();
let app_state = Arc::new(AppState {
http,
channel_id: ChannelId::new(channel_id),
webhook_secret,
image_base_url,
cache_url,
client,
});
let app = Router::new()
.route("/health", get(|| async { StatusCode::OK }))
.route("/webhook", post(post_webhook))
.layer(
TraceLayer::new_for_http()
.make_span_with(tower_http::trace::DefaultMakeSpan::new().level(Level::INFO))
.on_response(tower_http::trace::DefaultOnResponse::new().level(Level::INFO)),
)
.with_state(app_state);
let listener = tokio::net::TcpListener::bind(("0.0.0.0", port))
.await
.context(format!("failed to bind to 0.0.0.0:{port}"))?;
info!(port, "webhook listener ready");
// Spawn gateway connection for slash commands
tokio::spawn(async move {
loop {
if let Err(e) = discord_client.start().await {
error!(error = %e, "Discord gateway disconnected");
}
warn!("reconnecting to Discord gateway in 5s");
tokio::time::sleep(Duration::from_secs(5)).await;
}
});
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.context("failed to register SIGTERM handler")?;
axum::serve(listener, app)
.with_graceful_shutdown(async move {
sigterm.recv().await;
})
.await
.context("server error")?;
info!("shutdown complete");
Ok(())
}