website/api/src/handlers.rs

692 lines
20 KiB
Rust

use std::sync::Arc;
use axum::extract::State;
use axum::http::header::CONTENT_TYPE;
use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::Json;
use base64::Engine;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::email;
use crate::serve::AppState;
/// An answered question as returned by the Q+A read endpoints in this file
/// (both queries filter on `answer IS NOT NULL`).
#[derive(Serialize)]
pub struct Question {
    /// Row id in the `questions` table; also used to build RSS item links.
    id: i64,
    /// The question text as it was submitted.
    question: String,
    /// The stored answer text.
    answer: String,
    /// Raw `created_at` column value. NOTE(review): its format comes from the
    /// table schema, which is not visible here — confirm.
    created_at: String,
    /// Raw `answered_at` column value; the webhook handler in this file writes
    /// it as `%Y-%m-%dT%H:%M:%SZ` (UTC).
    answered_at: String,
}
/// Aggregate counts returned by the question-stats endpoint.
#[derive(Serialize)]
pub struct QuestionStats {
    /// Total number of rows in `questions`, answered or not.
    asked: i64,
    /// Number of rows whose `answer` column is non-NULL.
    answered: i64,
}
/// Public base URL of the site, used for the RSS channel link and item links.
const SITE_URL: &str = "https://jetpham.com";
/// Escapes `text` for use as XML character data or a quoted attribute value.
///
/// The five XML special characters (`&`, `<`, `>`, `"`, `'`) become their
/// predefined entities. Characters that XML 1.0 does not allow anywhere in a
/// document (most C0 control characters, U+FFFE, U+FFFF) are dropped, because
/// escaping them would still leave the document ill-formed. This matters
/// because questions are user-submitted and end up in the RSS feed.
///
/// The escaping is done in a single pass, so the entities it produces are
/// never re-escaped. An empty input yields an empty string.
fn xml_escape(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            // The XML 1.0 `Char` production (§2.2).
            '\t'
            | '\n'
            | '\r'
            | '\u{20}'..='\u{D7FF}'
            | '\u{E000}'..='\u{FFFD}'
            | '\u{10000}'..='\u{10FFFF}' => escaped.push(ch),
            // Not representable in XML 1.0 even as a character reference.
            _ => {}
        }
    }
    escaped
}
/// Converts an RFC 3339 timestamp into the RFC 2822 format used by RSS
/// `<pubDate>`. If `timestamp` does not parse, the current UTC time is
/// formatted instead.
fn rss_pub_date(timestamp: &str) -> String {
    match DateTime::parse_from_rfc3339(timestamp) {
        Ok(parsed) => parsed.to_rfc2822(),
        Err(_) => Utc::now().to_rfc2822(),
    }
}
/// Lists every answered question, most recently answered first.
///
/// # Errors
/// Responds `500 Internal Server Error` if the database mutex is poisoned or
/// any SQLite step (prepare, query, or row decoding) fails.
pub async fn get_questions(
    State(state): State<Arc<AppState>>,
) -> Result<Json<Vec<Question>>, StatusCode> {
    // Every failure in this handler maps to the same status code.
    fn internal_error<E>(_err: E) -> StatusCode {
        StatusCode::INTERNAL_SERVER_ERROR
    }

    const ANSWERED_QUERY: &str = "SELECT id, question, answer, created_at, answered_at \
                                  FROM questions WHERE answer IS NOT NULL \
                                  ORDER BY answered_at DESC";

    let conn = state.db.lock().map_err(internal_error)?;
    let mut statement = conn.prepare(ANSWERED_QUERY).map_err(internal_error)?;
    let rows = statement
        .query_map([], |row| {
            Ok(Question {
                id: row.get(0)?,
                question: row.get(1)?,
                answer: row.get(2)?,
                created_at: row.get(3)?,
                answered_at: row.get(4)?,
            })
        })
        .map_err(internal_error)?;

    // Stop at the first row that fails to decode.
    let mut answered = Vec::new();
    for row in rows {
        answered.push(row.map_err(internal_error)?);
    }
    Ok(Json(answered))
}
/// Reports how many questions have been asked and how many are answered.
///
/// # Errors
/// Responds `500 Internal Server Error` if the database mutex is poisoned or
/// either count query fails.
pub async fn get_question_stats(
    State(state): State<Arc<AppState>>,
) -> Result<Json<QuestionStats>, StatusCode> {
    let conn = match state.db.lock() {
        Ok(guard) => guard,
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };

    // Runs a single-value `COUNT(*)` query.
    let count = |sql: &str| -> Result<i64, StatusCode> {
        conn.query_row(sql, [], |row| row.get(0))
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
    };

    let asked = count("SELECT COUNT(*) FROM questions")?;
    let answered = count("SELECT COUNT(*) FROM questions WHERE answer IS NOT NULL")?;
    Ok(Json(QuestionStats { asked, answered }))
}
/// Serves answered questions as an RSS 2.0 feed, newest answer first.
///
/// For each item:
/// - the title is the question;
/// - link and guid are both `{SITE_URL}/qa#question-{id}`;
/// - `pubDate` is derived from `answered_at`;
/// - the description combines the question and the answer.
///
/// Every interpolated value passes through `xml_escape`.
///
/// # Errors
/// Responds `500 Internal Server Error` if the database mutex is poisoned or
/// any SQLite step fails.
pub async fn get_question_rss(
    State(state): State<Arc<AppState>>,
) -> Result<impl IntoResponse, StatusCode> {
    // Every failure in this handler maps to the same status code.
    fn internal_error<E>(_err: E) -> StatusCode {
        StatusCode::INTERNAL_SERVER_ERROR
    }

    let conn = state.db.lock().map_err(internal_error)?;
    let mut statement = conn
        .prepare(
            "SELECT id, question, answer, created_at, answered_at \
             FROM questions WHERE answer IS NOT NULL \
             ORDER BY answered_at DESC",
        )
        .map_err(internal_error)?;
    let rows = statement
        .query_map([], |row| {
            Ok(Question {
                id: row.get(0)?,
                question: row.get(1)?,
                answer: row.get(2)?,
                created_at: row.get(3)?,
                answered_at: row.get(4)?,
            })
        })
        .map_err(internal_error)?;
    let mut answered = Vec::new();
    for row in rows {
        answered.push(row.map_err(internal_error)?);
    }

    let items: String = answered
        .iter()
        .map(|entry| {
            let permalink = format!("{SITE_URL}/qa#question-{}", entry.id);
            let plain_description =
                format!("Question: {}\n\nAnswer: {}", entry.question, entry.answer);
            format!(
                "<item><title>{title}</title><link>{link}</link><guid>{link}</guid><pubDate>{date}</pubDate><description>{body}</description></item>",
                title = xml_escape(&entry.question),
                link = xml_escape(&permalink),
                date = xml_escape(&rss_pub_date(&entry.answered_at)),
                body = xml_escape(&plain_description),
            )
        })
        .collect();

    let xml = format!(
        "<?xml version=\"1.0\" encoding=\"UTF-8\"?><rss version=\"2.0\"><channel><title>Jet Pham Q+A</title><link>{SITE_URL}/qa</link><description>Answered questions from Jet Pham&apos;s site</description><language>en-us</language>{items}</channel></rss>"
    );
    Ok(([(CONTENT_TYPE, "application/rss+xml; charset=utf-8")], xml))
}
/// JSON request body accepted by the question-submission handler.
#[derive(Deserialize)]
pub struct SubmitQuestion {
    /// The question text; its length is validated by the handler.
    question: String,
}
/// Upper bound on a submitted question's length, counted in `char`s (Unicode
/// scalar values) so that it matches the "characters" wording of the error.
const MAX_QUESTION_CHARS: usize = 200;

/// Derives the rate-limiter key for a request.
///
/// The key is the first comma-separated entry of `X-Forwarded-For`, trimmed.
/// It falls back to `"unknown"` when the header is missing, is not visible
/// ASCII, or its first entry is blank.
///
/// NOTE(review): the first `X-Forwarded-For` entry is client-controlled
/// unless the fronting proxy overwrites the header. If the proxy appends to
/// it instead, clients can rotate this value to evade the rate limit —
/// confirm the proxy configuration.
fn rate_limit_key(headers: &HeaderMap) -> String {
    headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|list| list.split(',').next())
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .unwrap_or("unknown")
        .to_string()
}

/// Accepts a question submitted through the public form.
///
/// After validation and rate limiting, the question is inserted into
/// `questions`. A notification email is then sent on a blocking worker whose
/// result is only logged, so the response does not wait for it.
///
/// # Errors
/// - `400 Bad Request` if the question is blank (empty or whitespace-only)
///   or longer than [`MAX_QUESTION_CHARS`] characters.
/// - `429 Too Many Requests` if the rate limiter rejects the client key.
/// - `500 Internal Server Error` if the database mutex is poisoned or the
///   insert fails.
pub async fn post_question(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Json(body): Json<SubmitQuestion>,
) -> Result<StatusCode, (StatusCode, String)> {
    // `str::len` counts UTF-8 bytes, which rejected non-ASCII questions well
    // under 200 characters. `nth` stops as soon as the limit is exceeded, so
    // oversized inputs are not walked in full.
    let too_long = body.question.chars().nth(MAX_QUESTION_CHARS).is_some();
    if body.question.trim().is_empty() || too_long {
        return Err((
            StatusCode::BAD_REQUEST,
            "Question must be 1-200 characters".to_string(),
        ));
    }

    let ip = rate_limit_key(&headers);
    if !state.rate_limiter.check(&ip) {
        return Err((
            StatusCode::TOO_MANY_REQUESTS,
            "Too many questions. Try again later.".to_string(),
        ));
    }

    // Scope the lock so the guard is released before the notification is
    // spawned.
    let id: i64 = {
        let db = state
            .db
            .lock()
            .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "db error".to_string()))?;
        db.execute(
            "INSERT INTO questions (question) VALUES (?1)",
            rusqlite::params![body.question],
        )
        .map_err(|_| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "insert error".to_string(),
            )
        })?;
        db.last_insert_rowid()
    };

    let notify_email = state.notify_email.clone();
    let mail_domain = state.mail_domain.clone();
    let qa_reply_domain = state.qa_reply_domain.clone();
    let question_text = body.question.clone();
    // Fire-and-forget: the JoinHandle is dropped, and failures are only logged.
    tokio::task::spawn_blocking(move || {
        if let Err(e) = email::send_notification(
            id,
            &question_text,
            &notify_email,
            &mail_domain,
            &qa_reply_domain,
        ) {
            eprintln!("Failed to send notification: {e}");
        }
    });
    Ok(StatusCode::CREATED)
}
// --- MTA Hook webhook types ---
/// Top-level JSON body posted to the mail webhook.
///
/// Two layouts are handled by `extract_qa_reply`:
/// - a batch, with a non-empty `messages` array;
/// - a single message, described by the top-level `envelope` and `message`.
///
/// Every field is optional and defaults to empty. The test names suggest
/// these are Stalwart MTA hook payloads — NOTE(review): confirm the sender.
#[derive(Deserialize)]
pub struct MtaHookPayload {
    /// Batch entries; when non-empty, the top-level fields are not consulted.
    #[serde(default)]
    pub messages: Vec<MtaHookMessage>,
    /// Envelope of the single-message layout.
    #[serde(default)]
    pub envelope: Envelope,
    /// Message of the single-message layout.
    #[serde(default)]
    pub message: MtaHookBody,
}
/// One entry of the batch `messages` array of a webhook payload.
#[derive(Deserialize)]
pub struct MtaHookMessage {
    /// Envelope (recipient list) for this entry.
    #[serde(default)]
    pub envelope: Envelope,
    /// Message fields (subject, headers, contents) for this entry.
    #[serde(default)]
    pub message: MtaHookBody,
    /// Entry-level contents; used as the body only when `message.contents`
    /// is empty.
    #[serde(default)]
    pub contents: String,
}
/// Message envelope; only the recipient list is read (other keys are
/// ignored by serde).
#[derive(Deserialize, Default)]
pub struct Envelope {
    /// Envelope recipients; empty when the field is missing.
    #[serde(default)]
    pub to: Vec<Recipient>,
}
/// A recipient entry from an envelope `to` list.
///
/// With `untagged`, serde tries `Address` (a bare JSON string) first and
/// then `WithAddress` (an object with an `address` field).
#[derive(Deserialize)]
#[serde(untagged)]
pub enum Recipient {
    /// Bare address string, e.g. `"qa@example.com"`.
    Address(String),
    /// Object form, e.g. `{"address": "qa@example.com"}`; other keys are ignored.
    WithAddress { address: String },
}
impl Default for Recipient {
fn default() -> Self {
Self::Address(String::new())
}
}
impl Recipient {
    /// Returns the address string, whichever JSON shape it was parsed from.
    fn address(&self) -> &str {
        match self {
            Self::Address(addr) | Self::WithAddress { address: addr } => addr.as_str(),
        }
    }
}
/// Message portion of a webhook payload; every field defaults when missing.
#[derive(Deserialize, Default)]
pub struct MtaHookBody {
    /// Subject given directly on the message object, if any; preferred over
    /// `headers.subject`.
    #[serde(default)]
    pub subject: Option<String>,
    /// Headers in object form; only `subject` is read. A headers value shaped
    /// as a list of `[name, value]` pairs generally fails this typed parse,
    /// in which case the webhook falls back to untyped JSON extraction.
    #[serde(default)]
    pub headers: MessageHeaders,
    /// Message contents passed to the plain-text body extractor.
    #[serde(default)]
    pub contents: String,
}
/// Object-shaped message headers; only the `subject` key is read.
#[derive(Deserialize, Default)]
pub struct MessageHeaders {
    /// Value of the `subject` key, if present.
    #[serde(default)]
    pub subject: Option<String>,
}
/// JSON reply returned to the mail webhook caller.
#[derive(Serialize)]
pub struct MtaHookResponse {
    /// Action for the caller; this file only ever sends `"accept"` or
    /// `"discard"`.
    pub action: &'static str,
}
fn webhook_secret_matches(headers: &HeaderMap, expected_secret: &str) -> bool {
let expected_secret = expected_secret.trim();
let header_secret = headers
.get("X-Webhook-Secret")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.trim();
if header_secret == expected_secret {
return true;
}
let auth_header = match headers.get(axum::http::header::AUTHORIZATION) {
Some(value) => value,
None => return false,
};
let auth_header = match auth_header.to_str() {
Ok(value) => value,
Err(_) => return false,
};
let encoded = match auth_header.strip_prefix("Basic ") {
Some(value) => value,
None => return false,
};
let decoded = match base64::engine::general_purpose::STANDARD.decode(encoded) {
Ok(value) => value,
Err(_) => return false,
};
let credentials = match std::str::from_utf8(&decoded) {
Ok(value) => value,
Err(_) => return false,
};
let (_, password) = match credentials.split_once(':') {
Some(parts) => parts,
None => return false,
};
password.trim() == expected_secret
}
/// Summarises a rejected webhook request's authentication headers for logs,
/// without echoing any credential material.
///
/// The summary reports:
/// - the byte length of `X-Webhook-Secret`, if present;
/// - whether an `Authorization` header is present;
/// - the length of the decoded Basic password, if one can be decoded.
///
/// Raw header values and decoded credentials are deliberately never
/// included. A near-miss from a legitimate but misconfigured sender (e.g.
/// stray whitespace) would otherwise write the real secret into the logs.
fn webhook_secret_debug(headers: &HeaderMap) -> String {
    let header_secret_len = headers
        .get("X-Webhook-Secret")
        .map(|value| value.as_bytes().len());
    let authorization_present = headers.contains_key(axum::http::header::AUTHORIZATION);
    let basic_password_len = headers
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split_once(' '))
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Basic"))
        .and_then(|(_, encoded)| {
            base64::engine::general_purpose::STANDARD
                .decode(encoded.trim())
                .ok()
        })
        .and_then(|bytes| String::from_utf8(bytes).ok())
        .and_then(|credentials| {
            credentials
                .split_once(':')
                .map(|(_, password)| password.len())
        });
    format!(
        "x-webhook-secret-len={header_secret_len:?}; authorization-present={authorization_present}; basic-password-len={basic_password_len:?}"
    )
}
fn extract_qa_reply(payload: &MtaHookPayload, expected_domain: &str) -> Option<(i64, String)> {
if !payload.messages.is_empty() {
for message in &payload.messages {
if let Some(reply) = extract_qa_reply_from_message(
&message.envelope.to,
expected_domain,
message.message.subject.as_deref().or(message.message.headers.subject.as_deref()),
if message.message.contents.is_empty() {
&message.contents
} else {
&message.message.contents
},
) {
return Some(reply);
}
}
return None;
}
extract_qa_reply_from_message(
&payload.envelope.to,
expected_domain,
payload
.message
.subject
.as_deref()
.or(payload.message.headers.subject.as_deref()),
&payload.message.contents,
)
}
/// Follows `path` through nested JSON objects, key by key.
///
/// Returns `None` as soon as a key is missing or an intermediate value is
/// not an object. An empty path returns `value` itself.
fn value_at_path<'a>(value: &'a Value, path: &[&str]) -> Option<&'a Value> {
    path.iter().try_fold(value, |node, key| node.get(*key))
}
fn string_from_value(value: &Value) -> Option<String> {
match value {
Value::String(s) => Some(s.clone()),
Value::Object(map) => map
.get("address")
.or_else(|| map.get("email"))
.or_else(|| map.get("value"))
.and_then(|v| v.as_str())
.map(ToOwned::to_owned),
_ => None,
}
}
fn recipients_from_value(value: Option<&Value>) -> Vec<Recipient> {
let Some(value) = value else {
return Vec::new();
};
match value {
Value::Array(values) => values
.iter()
.filter_map(|v| string_from_value(v).map(Recipient::Address))
.collect(),
_ => string_from_value(value)
.map(Recipient::Address)
.into_iter()
.collect(),
}
}
fn string_at_paths(value: &Value, paths: &[&[&str]]) -> Option<String> {
paths.iter().find_map(|path| {
value_at_path(value, path)
.and_then(|v| v.as_str())
.map(ToOwned::to_owned)
})
}
fn subject_from_headers_value(value: Option<&Value>) -> Option<String> {
let headers = value?.as_array()?;
headers.iter().find_map(|header| {
let parts = header.as_array()?;
let name = parts.first()?.as_str()?.trim();
if !name.eq_ignore_ascii_case("Subject") {
return None;
}
parts
.get(1)?
.as_str()
.map(|s| s.trim().to_string())
})
}
/// Finds a Q&A reply in an arbitrary JSON payload. This is the fallback for
/// shapes that the typed `MtaHookPayload` parse does not cover.
///
/// Each entry of a top-level `messages` array is tried first (first match
/// wins). If none matches, or there is no such array, the payload itself is
/// tried as a single message.
fn extract_qa_reply_from_value(payload: &Value, expected_domain: &str) -> Option<(i64, String)> {
    // Runs the reply extraction on one message-shaped JSON object.
    //
    // Fields read:
    // - recipients: `envelope.to`;
    // - subject: `message.subject`, `message.headers.subject` or
    //   `headers.subject`, falling back to a `Subject` pair in a
    //   `message.headers` list;
    // - contents: `message.contents`, `contents` or `raw_message`, else empty.
    fn from_single(message: &Value, expected_domain: &str) -> Option<(i64, String)> {
        let recipients = recipients_from_value(value_at_path(message, &["envelope", "to"]));
        let subject = string_at_paths(
            message,
            &[
                &["message", "subject"],
                &["message", "headers", "subject"],
                &["headers", "subject"],
            ],
        )
        .or_else(|| subject_from_headers_value(value_at_path(message, &["message", "headers"])));
        let contents =
            string_at_paths(message, &[&["message", "contents"], &["contents"], &["raw_message"]])
                .unwrap_or_default();
        extract_qa_reply_from_message(&recipients, expected_domain, subject.as_deref(), &contents)
    }

    payload
        .get("messages")
        .and_then(Value::as_array)
        .and_then(|messages| {
            messages
                .iter()
                .find_map(|message| from_single(message, expected_domain))
        })
        .or_else(|| from_single(payload, expected_domain))
}
/// Extracts `(question_id, answer_body)` from one message, provided it is
/// addressed to `qa@<expected_domain>`.
///
/// Steps:
/// 1. Some recipient's local part must be `qa` and its domain
///    `expected_domain`. Both comparisons are ASCII case-insensitive, and the
///    address is split at its last `@`.
/// 2. The subject is `subject` when given. Otherwise it is the value of the
///    first line in `contents` whose header name is `Subject`; the name match
///    is case-insensitive, as RFC 5322 field names are, and the value is
///    trimmed.
/// 3. The question id comes from `email::extract_id_from_subject`, and the
///    answer from `email::extract_plain_text_body(contents)`.
///
/// Returns `None` if any step fails or the extracted body is empty.
fn extract_qa_reply_from_message(
    recipients: &[Recipient],
    expected_domain: &str,
    subject: Option<&str>,
    contents: &str,
) -> Option<(i64, String)> {
    let addressed_to_qa = recipients.iter().any(|recipient| {
        let Some((local, domain)) = recipient.address().rsplit_once('@') else {
            return false;
        };
        local.eq_ignore_ascii_case("qa") && domain.eq_ignore_ascii_case(expected_domain)
    });
    if !addressed_to_qa {
        return None;
    }

    // `str::lines` already splits on both "\n" and "\r\n", so there is no
    // need to copy the whole message just to normalise line endings.
    let subject = match subject {
        Some(subject) => subject.to_owned(),
        None => contents.lines().find_map(|line| {
            let (name, value) = line.split_once(':')?;
            name.eq_ignore_ascii_case("Subject")
                .then(|| value.trim().to_owned())
        })?,
    };

    let id = email::extract_id_from_subject(&subject).ok()?;
    let body = email::extract_plain_text_body(contents);
    if body.is_empty() {
        return None;
    }
    Some((id, body))
}
pub async fn webhook(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
body: String,
) -> Result<Json<MtaHookResponse>, (StatusCode, String)> {
if !webhook_secret_matches(&headers, &state.webhook_secret) {
eprintln!(
"Rejected webhook: invalid secret; expected_len={}; {}",
state.webhook_secret.len(),
webhook_secret_debug(&headers)
);
return Err((StatusCode::UNAUTHORIZED, "invalid secret".to_string()));
}
let payload_value: Value = match serde_json::from_str(&body) {
Ok(payload) => payload,
Err(err) => {
eprintln!("Rejected webhook: invalid JSON payload: {err}; body={body}");
return Ok(Json(MtaHookResponse { action: "accept" }));
}
};
let parsed_reply = serde_json::from_value::<MtaHookPayload>(payload_value.clone())
.ok()
.and_then(|payload| extract_qa_reply(&payload, &state.qa_reply_domain))
.or_else(|| extract_qa_reply_from_value(&payload_value, &state.qa_reply_domain));
if let Some((id, body)) = parsed_reply {
let db = state
.db
.lock()
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "db error".to_string()))?;
db.execute(
"UPDATE questions SET answer = ?1, answered_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') \
WHERE id = ?2 AND answer IS NULL",
rusqlite::params![body, id],
)
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "update error".to_string()))?;
eprintln!("Stored Q&A reply for question #{id}");
return Ok(Json(MtaHookResponse { action: "discard" }));
}
eprintln!("Q&A webhook accepted payload without matched reply: {payload_value}");
Ok(Json(MtaHookResponse { action: "accept" }))
}
#[cfg(test)]
mod tests {
// Unit tests for webhook payload parsing and secret checking. The fixtures
// are raw JSON strings, so their exact contents are test data.
use axum::http::HeaderMap;
use serde_json::Value;
use super::{extract_qa_reply, extract_qa_reply_from_value, webhook_secret_matches, MtaHookPayload};
// Single-message layout: object-form recipient, subject on `message`; the
// quoted text after the answer must not appear in the extracted body.
#[test]
fn extracts_reply_from_current_stalwart_payload() {
let payload: MtaHookPayload = serde_json::from_str(
r#"{
"envelope": {
"to": [
{
"address": "qa@extremist.software"
}
]
},
"message": {
"subject": "Re: 42 - hello",
"contents": "This is the answer.\n\nOn earlier mail wrote:\n> quoted"
}
}"#,
)
.unwrap();
assert_eq!(
extract_qa_reply(&payload, "extremist.software"),
Some((42, "This is the answer.".to_string()))
);
}
// Batch layout: bare-string recipient, and the body taken from the entry's
// own `contents` because `message.contents` is absent.
#[test]
fn extracts_reply_from_legacy_batch_payload() {
let payload: MtaHookPayload = serde_json::from_str(
r#"{
"messages": [
{
"envelope": {
"to": ["qa@extremist.software"]
},
"message": {
"subject": "Re: 7 - legacy"
},
"contents": "Legacy answer"
}
]
}"#,
)
.unwrap();
assert_eq!(
extract_qa_reply(&payload, "extremist.software"),
Some((7, "Legacy answer".to_string()))
);
}
// The shared secret sent in the `X-Webhook-Secret` header is accepted.
#[test]
fn accepts_header_secret() {
let mut headers = HeaderMap::new();
headers.insert("X-Webhook-Secret", "topsecret".parse().unwrap());
assert!(webhook_secret_matches(&headers, "topsecret"));
}
// Basic auth whose password is the secret is accepted; the base64 value
// decodes to "user:topsecret".
#[test]
fn accepts_basic_auth_password() {
let mut headers = HeaderMap::new();
headers.insert(
axum::http::header::AUTHORIZATION,
"Basic dXNlcjp0b3BzZWNyZXQ=".parse().unwrap(),
);
assert!(webhook_secret_matches(&headers, "topsecret"));
}
// Untyped fallback: `envelope.to` is a single string and the subject sits
// in an object-shaped `message.headers`.
#[test]
fn extracts_reply_from_value_with_string_recipient() {
let payload: Value = serde_json::from_str(
r#"{
"envelope": {
"to": "qa@extremist.software"
},
"message": {
"headers": {
"subject": "Re: 9 - hi"
},
"contents": "Answer body"
}
}"#,
)
.unwrap();
assert_eq!(
extract_qa_reply_from_value(&payload, "extremist.software"),
Some((9, "Answer body".to_string()))
);
}
// Untyped fallback: headers arrive as `[name, value]` pairs whose values
// carry leading spaces and trailing CRLF, which must be trimmed.
#[test]
fn extracts_reply_from_value_with_header_pairs() {
let payload: Value = serde_json::from_str(
r#"{
"envelope": {
"to": [{"address":"qa@extremist.software"}]
},
"message": {
"headers": [
["From", " jet@extremist.software\r\n"],
["Subject", " Re: 11 - hi\r\n"]
],
"contents": "Answer from header pairs"
}
}"#,
)
.unwrap();
assert_eq!(
extract_qa_reply_from_value(&payload, "extremist.software"),
Some((11, "Answer from header pairs".to_string()))
);
}
}