// website/api/src/handlers.rs
//
// 394 lines
// 11 KiB
// Rust

use std::sync::Arc;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::Json;
use base64::Engine;
use serde::{Deserialize, Serialize};
use crate::email;
use crate::serve::AppState;
/// An answered question, as returned (in a JSON array) by `get_questions`.
///
/// Fields are serialized in declaration order.
#[derive(Serialize)]
pub struct Question {
    // Row id of the question in the `questions` table.
    id: i64,
    // The question text as submitted.
    question: String,
    // The stored answer text; `get_questions` only selects rows whose answer is not NULL.
    answer: String,
    // Raw value of the `created_at` column (format not set in this file — TODO confirm schema default).
    created_at: String,
    // Raw value of the `answered_at` column; `webhook` writes it as `%Y-%m-%dT%H:%M:%SZ`.
    answered_at: String,
}
/// Lists every answered question, most recently answered first.
///
/// # Errors
///
/// Returns `500 Internal Server Error` if the database mutex is poisoned or if
/// preparing the statement, running the query, or decoding any row fails.
pub async fn get_questions(
    State(state): State<Arc<AppState>>,
) -> Result<Json<Vec<Question>>, StatusCode> {
    // Every failure in this handler collapses to the same opaque 500.
    fn internal<E>(_err: E) -> StatusCode {
        StatusCode::INTERNAL_SERVER_ERROR
    }

    const SQL: &str = "SELECT id, question, answer, created_at, answered_at \
                       FROM questions WHERE answer IS NOT NULL \
                       ORDER BY answered_at DESC";

    let conn = state.db.lock().map_err(internal)?;
    let mut statement = conn.prepare(SQL).map_err(internal)?;
    let rows = statement
        .query_map([], |row| {
            Ok(Question {
                id: row.get(0)?,
                question: row.get(1)?,
                answer: row.get(2)?,
                created_at: row.get(3)?,
                answered_at: row.get(4)?,
            })
        })
        .map_err(internal)?;

    // Stop at the first row that fails to decode, like `collect::<Result<_, _>>()`.
    let mut questions = Vec::new();
    for row in rows {
        questions.push(row.map_err(internal)?);
    }
    Ok(Json(questions))
}
/// JSON body accepted by `post_question`: `{"question": "..."}`.
#[derive(Deserialize)]
pub struct SubmitQuestion {
    // The question text; length is validated by `post_question`.
    question: String,
}
/// Accepts a new question from the public form.
///
/// The question is trimmed and validated: it must be 1-200 characters long,
/// counted as Unicode scalar values rather than bytes. The request is then
/// rate-limited per client IP and the question is inserted into `questions`
/// without an answer. A notification email is sent on a blocking worker
/// thread; email failures are only logged and do not affect the response.
///
/// # Errors
///
/// * `400 Bad Request`: the trimmed question is empty or longer than 200 characters.
/// * `429 Too Many Requests`: `state.rate_limiter` rejected this client.
/// * `500 Internal Server Error`: the database mutex is poisoned or the insert failed.
pub async fn post_question(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Json(body): Json<SubmitQuestion>,
) -> Result<StatusCode, (StatusCode, String)> {
    let question = validate_question(&body.question)?;

    let ip = client_ip(&headers);
    if !state.rate_limiter.check(&ip) {
        return Err((
            StatusCode::TOO_MANY_REQUESTS,
            "Too many questions. Try again later.".to_string(),
        ));
    }

    // Scoped so the DB lock is released before the notification task is spawned.
    let id: i64 = {
        let db = state
            .db
            .lock()
            .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "db error".to_string()))?;
        db.execute(
            "INSERT INTO questions (question) VALUES (?1)",
            rusqlite::params![question],
        )
        .map_err(|e| {
            eprintln!("Failed to insert question: {e}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "insert error".to_string(),
            )
        })?;
        db.last_insert_rowid()
    };

    // The blocking task needs owned copies of everything it uses.
    let notify_email = state.notify_email.clone();
    let mail_domain = state.mail_domain.clone();
    let qa_reply_domain = state.qa_reply_domain.clone();
    let question_text = question.to_owned();
    tokio::task::spawn_blocking(move || {
        if let Err(e) = email::send_notification(
            id,
            &question_text,
            &notify_email,
            &mail_domain,
            &qa_reply_domain,
        ) {
            eprintln!("Failed to send notification: {e}");
        }
    });
    Ok(StatusCode::CREATED)
}

/// Trims `raw` and checks that it holds 1-200 characters, returning the trimmed text.
///
/// Characters are counted with `str::chars`, so multi-byte UTF-8 text is not
/// penalised as it would be by a byte-length check.
fn validate_question(raw: &str) -> Result<&str, (StatusCode, String)> {
    const MAX_QUESTION_CHARS: usize = 200;
    let question = raw.trim();
    let len = question.chars().count();
    if len == 0 || len > MAX_QUESTION_CHARS {
        return Err((
            StatusCode::BAD_REQUEST,
            "Question must be 1-200 characters".to_string(),
        ));
    }
    Ok(question)
}

/// Derives the rate-limiting key from the first `X-Forwarded-For` entry.
///
/// Falls back to `"unknown"` when:
/// * the header is absent,
/// * the header value is not a valid string (`HeaderValue::to_str` fails), or
/// * the first entry is blank.
///
/// NOTE(review): unless the reverse proxy overwrites `X-Forwarded-For`, the
/// first entry is whatever the client sent. A client can rotate it to evade the
/// rate limiter. Confirm the proxy configuration, or key on the entry appended
/// by the trusted proxy instead.
fn client_ip(headers: &HeaderMap) -> String {
    headers
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.split(',').next())
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(ToOwned::to_owned)
        .unwrap_or_else(|| "unknown".to_string())
}
// --- MTA Hook webhook types ---
/// Top-level JSON body of an MTA-hook request; every field defaults when absent.
///
/// Two shapes are handled:
/// * a batch of messages under `messages` (the "legacy batch payload" in the tests), and
/// * a single message described by the top-level `envelope` + `message`.
///
/// `extract_qa_reply` uses `messages` when it is non-empty and the top-level fields otherwise.
#[derive(Deserialize)]
pub struct MtaHookPayload {
    // Batch form: one entry per message.
    #[serde(default)]
    pub messages: Vec<MtaHookMessage>,
    // Single-message form: envelope of the message.
    #[serde(default)]
    pub envelope: Envelope,
    // Single-message form: subject/headers/contents of the message.
    #[serde(default)]
    pub message: MtaHookBody,
}
/// One entry of `MtaHookPayload::messages` (batch form); every field defaults when absent.
#[derive(Deserialize)]
pub struct MtaHookMessage {
    // Envelope of this message (recipients).
    #[serde(default)]
    pub envelope: Envelope,
    // Subject/headers/contents of this message.
    #[serde(default)]
    pub message: MtaHookBody,
    // Message text at the entry level. `extract_qa_reply` uses it only when
    // `message.contents` is empty.
    #[serde(default)]
    pub contents: String,
}
/// Envelope of a hooked message; only the recipient list is deserialized.
#[derive(Deserialize, Default)]
pub struct Envelope {
    // Recipients, each either a bare string or an object with an `address` field.
    #[serde(default)]
    pub to: Vec<Recipient>,
}
/// One envelope recipient.
///
/// Accepts either a bare JSON string (`"qa@example.com"`) or an object with an
/// `address` field (`{"address": "qa@example.com"}`). Because of
/// `#[serde(untagged)]`, serde tries the variants in declaration order.
#[derive(Deserialize)]
#[serde(untagged)]
pub enum Recipient {
    Address(String),
    WithAddress { address: String },
}

impl Default for Recipient {
    /// An empty bare-string address.
    fn default() -> Self {
        Recipient::Address(String::default())
    }
}

impl Recipient {
    /// The recipient's address, whichever JSON shape it arrived in.
    fn address(&self) -> &str {
        match self {
            Recipient::Address(addr) | Recipient::WithAddress { address: addr } => addr,
        }
    }
}
/// Message details carried in an MTA-hook request; every field defaults when absent.
#[derive(Deserialize, Default)]
pub struct MtaHookBody {
    // Subject given directly on the message; preferred over `headers.subject`.
    #[serde(default)]
    pub subject: Option<String>,
    // Parsed headers; only the subject is read.
    #[serde(default)]
    pub headers: MessageHeaders,
    // Message text. `extract_qa_reply_from_message` parses the body from it and
    // may also scan it for a `Subject:` line.
    #[serde(default)]
    pub contents: String,
}
/// Subset of message headers read by the webhook.
#[derive(Deserialize, Default)]
pub struct MessageHeaders {
    // Subject header value, used when `MtaHookBody::subject` is absent.
    #[serde(default)]
    pub subject: Option<String>,
}
/// JSON reply to the MTA hook.
///
/// `webhook` sets `action` to `"discard"` for handled Q&A replies and to
/// `"accept"` for every other message.
#[derive(Serialize)]
pub struct MtaHookResponse {
    pub action: &'static str,
}
fn webhook_secret_matches(headers: &HeaderMap, expected_secret: &str) -> bool {
let header_secret = headers
.get("X-Webhook-Secret")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if header_secret == expected_secret {
return true;
}
let auth_header = match headers.get(axum::http::header::AUTHORIZATION) {
Some(value) => value,
None => return false,
};
let auth_header = match auth_header.to_str() {
Ok(value) => value,
Err(_) => return false,
};
let encoded = match auth_header.strip_prefix("Basic ") {
Some(value) => value,
None => return false,
};
let decoded = match base64::engine::general_purpose::STANDARD.decode(encoded) {
Ok(value) => value,
Err(_) => return false,
};
let credentials = match std::str::from_utf8(&decoded) {
Ok(value) => value,
Err(_) => return false,
};
let (_, password) = match credentials.split_once(':') {
Some(parts) => parts,
None => return false,
};
password == expected_secret
}
fn extract_qa_reply(payload: &MtaHookPayload, expected_domain: &str) -> Option<(i64, String)> {
if !payload.messages.is_empty() {
for message in &payload.messages {
if let Some(reply) = extract_qa_reply_from_message(
&message.envelope.to,
expected_domain,
message.message.subject.as_deref().or(message.message.headers.subject.as_deref()),
if message.message.contents.is_empty() {
&message.contents
} else {
&message.message.contents
},
) {
return Some(reply);
}
}
return None;
}
extract_qa_reply_from_message(
&payload.envelope.to,
expected_domain,
payload
.message
.subject
.as_deref()
.or(payload.message.headers.subject.as_deref()),
&payload.message.contents,
)
}
/// Tries to interpret one message as a reply to a Q&A notification.
///
/// Returns `Some((question_id, answer_text))` only when all of these hold:
/// * some recipient is `qa@<expected_domain>`. Both parts are compared
///   ASCII-case-insensitively, and surrounding whitespace or angle brackets
///   are ignored.
/// * a subject is available: the `subject` argument when it is non-blank,
///   otherwise the first `Subject:` line found in `contents`.
/// * `email::extract_id_from_subject` parses a question id from that subject.
/// * `email::extract_plain_text_body(contents)` yields a non-empty body.
///
/// Otherwise returns `None`.
fn extract_qa_reply_from_message(
    recipients: &[Recipient],
    expected_domain: &str,
    subject: Option<&str>,
    contents: &str,
) -> Option<(i64, String)> {
    let addressed_to_qa = recipients
        .iter()
        .any(|recipient| is_qa_address(recipient.address(), expected_domain));
    if !addressed_to_qa {
        return None;
    }
    // A present-but-blank subject field is treated like a missing one.
    let subject = match subject.filter(|s| !s.trim().is_empty()) {
        Some(subject) => subject,
        None => subject_header(contents)?,
    };
    let id = email::extract_id_from_subject(subject).ok()?;
    let body = email::extract_plain_text_body(contents);
    if body.is_empty() {
        return None;
    }
    Some((id, body))
}

/// Returns `true` if `address` is `qa@<expected_domain>`.
///
/// ASCII case, surrounding whitespace and enclosing angle brackets
/// (`<qa@example.com>`) are ignored.
fn is_qa_address(address: &str, expected_domain: &str) -> bool {
    let address = address
        .trim()
        .trim_start_matches('<')
        .trim_end_matches('>');
    match address.rsplit_once('@') {
        Some((local, domain)) => {
            local.eq_ignore_ascii_case("qa") && domain.eq_ignore_ascii_case(expected_domain)
        }
        None => false,
    }
}

/// Returns the trimmed value of the first `Subject:` line in `contents`.
///
/// Header names are case-insensitive (RFC 5322), so `subject:` matches too.
/// `str::lines` already strips the `\r` of CRLF line endings, so no
/// normalisation copy is needed.
fn subject_header(contents: &str) -> Option<&str> {
    contents.lines().find_map(|line| {
        let (name, value) = line.split_once(':')?;
        name.eq_ignore_ascii_case("Subject").then_some(value.trim())
    })
}
pub async fn webhook(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(payload): Json<MtaHookPayload>,
) -> Result<Json<MtaHookResponse>, (StatusCode, String)> {
if !webhook_secret_matches(&headers, &state.webhook_secret) {
eprintln!("Rejected webhook: invalid secret");
return Err((StatusCode::UNAUTHORIZED, "invalid secret".to_string()));
}
if let Some((id, body)) = extract_qa_reply(&payload, &state.qa_reply_domain) {
let db = state
.db
.lock()
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "db error".to_string()))?;
db.execute(
"UPDATE questions SET answer = ?1, answered_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') \
WHERE id = ?2 AND answer IS NULL",
rusqlite::params![body, id],
)
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "update error".to_string()))?;
eprintln!("Stored Q&A reply for question #{id}");
return Ok(Json(MtaHookResponse { action: "discard" }));
}
// No Q&A recipient matched — let Stalwart deliver normally
Ok(Json(MtaHookResponse { action: "accept" }))
}
/// Unit tests for payload parsing and webhook authentication.
///
/// Each behaviour has positive and negative cases, so a regression that makes
/// the parser or the secret check too permissive is caught.
#[cfg(test)]
mod tests {
    use axum::http::HeaderMap;

    use super::{extract_qa_reply, webhook_secret_matches, MtaHookPayload};

    #[test]
    fn extracts_reply_from_current_stalwart_payload() {
        let payload: MtaHookPayload = serde_json::from_str(
            r#"{
                "envelope": {
                    "to": [
                        {
                            "address": "qa@extremist.software"
                        }
                    ]
                },
                "message": {
                    "subject": "Re: 42 - hello",
                    "contents": "This is the answer.\n\nOn earlier mail wrote:\n> quoted"
                }
            }"#,
        )
        .unwrap();
        assert_eq!(
            extract_qa_reply(&payload, "extremist.software"),
            Some((42, "This is the answer.".to_string()))
        );
    }

    #[test]
    fn extracts_reply_from_legacy_batch_payload() {
        let payload: MtaHookPayload = serde_json::from_str(
            r#"{
                "messages": [
                    {
                        "envelope": {
                            "to": ["qa@extremist.software"]
                        },
                        "message": {
                            "subject": "Re: 7 - legacy"
                        },
                        "contents": "Legacy answer"
                    }
                ]
            }"#,
        )
        .unwrap();
        assert_eq!(
            extract_qa_reply(&payload, "extremist.software"),
            Some((7, "Legacy answer".to_string()))
        );
    }

    #[test]
    fn ignores_mail_not_addressed_to_qa() {
        let payload: MtaHookPayload = serde_json::from_str(
            r#"{
                "envelope": { "to": ["someone@extremist.software"] },
                "message": { "subject": "Re: 42 - hello", "contents": "An answer" }
            }"#,
        )
        .unwrap();
        assert_eq!(extract_qa_reply(&payload, "extremist.software"), None);
    }

    #[test]
    fn ignores_qa_address_on_other_domain() {
        let payload: MtaHookPayload = serde_json::from_str(
            r#"{
                "envelope": { "to": ["qa@example.com"] },
                "message": { "subject": "Re: 42 - hello", "contents": "An answer" }
            }"#,
        )
        .unwrap();
        assert_eq!(extract_qa_reply(&payload, "extremist.software"), None);
    }

    #[test]
    fn accepts_header_secret() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Webhook-Secret", "topsecret".parse().unwrap());
        assert!(webhook_secret_matches(&headers, "topsecret"));
    }

    #[test]
    fn rejects_wrong_header_secret() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Webhook-Secret", "wrong".parse().unwrap());
        assert!(!webhook_secret_matches(&headers, "topsecret"));
    }

    #[test]
    fn rejects_request_without_credentials() {
        assert!(!webhook_secret_matches(&HeaderMap::new(), "topsecret"));
    }

    #[test]
    fn accepts_basic_auth_password() {
        let mut headers = HeaderMap::new();
        headers.insert(
            axum::http::header::AUTHORIZATION,
            "Basic dXNlcjp0b3BzZWNyZXQ=".parse().unwrap(),
        );
        assert!(webhook_secret_matches(&headers, "topsecret"));
    }

    #[test]
    fn rejects_basic_auth_with_wrong_password() {
        let mut headers = HeaderMap::new();
        // base64("user:nope")
        headers.insert(
            axum::http::header::AUTHORIZATION,
            "Basic dXNlcjpub3Bl".parse().unwrap(),
        );
        assert!(!webhook_secret_matches(&headers, "topsecret"));
    }
}