use std::{collections::HashMap, str::FromStr, string::ToString}; use color_eyre::Report; use poem::{http::StatusCode, Error, FromRequest, Request, RequestBody, Result}; mod de; #[poem::async_trait] impl<'a> FromRequest<'a> for Authentication { async fn from_request(req: &'a Request, _: &mut RequestBody) -> Result { let query = req.uri().query().unwrap_or_default(); if query.is_empty() { return Err(Error::from_string("Empty query", StatusCode::BAD_REQUEST)); } let query = url_escape::decode(query); let query = query .split('&') .filter_map(|q| q.split_once('=')) .collect::>(); let user = { let user = query.get("u").map(ToString::to_string); if user.is_none() { return Err(Error::from_string( "Missing username", StatusCode::BAD_REQUEST, )); } user.expect("Missing username") }; let password = query.get("p").map(ToString::to_string); if password.is_some() { return Err(Error::from_string( "Password authentication is not supported", StatusCode::BAD_REQUEST, )); } let token = query.get("t").map(ToString::to_string); let salt = query.get("s").map(ToString::to_string); if token.is_none() || salt.is_none() { return Err(Error::from_string( "Missing token or salt", StatusCode::BAD_REQUEST, )); } let token = token.expect("Missing token"); let salt = salt.expect("Missing salt"); let version = { let version = query.get("v").map(ToString::to_string); if version.is_none() { return Err(Error::from_string( "Missing version", StatusCode::BAD_REQUEST, )); } version .expect("Missing version") .parse::() .map_err(|e| { Error::from_string(format!("Invalid version: {e}"), StatusCode::BAD_REQUEST) }) }?; if version < VersionTriple(1, 13, 0) { return Err(Error::from_string( "Unsupported version. We only support 1.13.0 and above", StatusCode::BAD_REQUEST, )); } let client = { let client = query.get("c").map(ToString::to_string); if client.is_none() { return Err(Error::from_string( "Missing client", StatusCode::BAD_REQUEST, )); } client.expect("Missing client") }; let format = query .get("f") .map_or_else(|| "xml".to_string(), ToString::to_string); Ok(Self { username: user, token, salt, version, client, format, }) } } #[derive(Debug, Clone)] pub struct Authentication { pub username: String, pub token: String, pub salt: String, pub version: VersionTriple, pub client: String, pub format: String, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct VersionTriple(pub u32, pub u32, pub u32); impl PartialOrd for VersionTriple { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for VersionTriple { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.0 .cmp(&other.0) .then(self.1.cmp(&other.1)) .then(self.2.cmp(&other.2)) } } impl FromStr for VersionTriple { type Err = Report; fn from_str(s: &str) -> Result { let parts = s.split('.').collect::>(); if parts.len() != 3 { return Err(Report::msg("Invalid version string")); } let major = parts[0].parse::()?; let minor = parts[1].parse::()?; let patch = parts[2].parse::()?; Ok(Self(major, minor, patch)) } } #[cfg(test)] mod tests { use super::*; #[test] fn parse_version_triple_from_str() { assert_eq!( VersionTriple::from_str("1.2.3").expect("Failed to parse version triple"), VersionTriple(1, 2, 3) ); } #[test] fn parse_version_triple_from_str_with_invalid_string() { assert!(VersionTriple::from_str("1.2").is_err()); assert!(VersionTriple::from_str("").is_err()); } #[test] fn parse_version_triple_from_str_with_invalid_number() { assert!(VersionTriple::from_str("1.2.a").is_err()); } #[test] fn serialize_version_triple_to_string() { assert_eq!( serde_json::to_string(&VersionTriple(1, 2, 3)) .expect("Failed to serialize version triple"), "\"1.2.3\"" ); } #[test] fn deserialize_version_triple_from_string() { assert_eq!( serde_json::from_str::("\"1.2.3\"") .expect("Failed to deserialize version triple"), VersionTriple(1, 2, 3) ); } }