feat: parse query parameters

This commit is contained in:
Lys 2023-10-08 21:38:36 +03:00
parent a7cad8493a
commit 281e98c2c8
Signed by: lyssieth
GPG key ID: C9CF3D614FAA3940
5 changed files with 2334 additions and 2 deletions

1982
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -7,3 +7,30 @@ publish = ["crates-io"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
color-eyre = "0.6.2"
poem = { version = "1.3.58", features = [
"compression",
"cookie",
"session",
"static-files",
"xml",
] }
poem-openapi = { version = "3.0.5", features = [
"time",
"openapi-explorer",
"url",
"static-files",
] }
quick-xml = { version = "0.30.0", features = ["serialize"] }
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107"
tokio = { version = "1.32.0", features = ["full"] }
tracing = { version = "0.1.37", features = ["async-await"] }
tracing-subscriber = { version = "0.3.17", features = [
"env-filter",
"tracing",
"parking_lot",
"time",
] }
url = { version = "2.4.1", features = ["serde"] }
url-escape = "0.1.1"

189
src/authentication.rs Normal file
View file

@ -0,0 +1,189 @@
use std::{collections::HashMap, str::FromStr, string::ToString};
use color_eyre::Report;
use poem::{http::StatusCode, Error, FromRequest, Request, RequestBody, Result};
mod de;
#[poem::async_trait]
impl<'a> FromRequest<'a> for Authentication {
async fn from_request(req: &'a Request, _: &mut RequestBody) -> Result<Self> {
let query = req.uri().query().unwrap_or_default();
if query.is_empty() {
return Err(Error::from_string("Empty query", StatusCode::BAD_REQUEST));
}
let query = url_escape::decode(query);
let query = query
.split('&')
.filter_map(|q| q.split_once('='))
.collect::<HashMap<_, _>>();
let user = {
let user = query.get("u").map(ToString::to_string);
if user.is_none() {
return Err(Error::from_string(
"Missing username",
StatusCode::BAD_REQUEST,
));
}
user.expect("Missing username")
};
let password = query.get("p").map(ToString::to_string);
if password.is_some() {
return Err(Error::from_string(
"Password authentication is not supported",
StatusCode::BAD_REQUEST,
));
}
let token = query.get("t").map(ToString::to_string);
let salt = query.get("s").map(ToString::to_string);
if token.is_none() || salt.is_none() {
return Err(Error::from_string(
"Missing token or salt",
StatusCode::BAD_REQUEST,
));
}
let token = token.expect("Missing token");
let salt = salt.expect("Missing salt");
let version = {
let version = query.get("v").map(ToString::to_string);
if version.is_none() {
return Err(Error::from_string(
"Missing version",
StatusCode::BAD_REQUEST,
));
}
version
.expect("Missing version")
.parse::<VersionTriple>()
.map_err(|e| {
Error::from_string(format!("Invalid version: {e}"), StatusCode::BAD_REQUEST)
})
}?;
if version < VersionTriple(1, 13, 0) {
return Err(Error::from_string(
"Unsupported version. We only support 1.13.0 and above",
StatusCode::BAD_REQUEST,
));
}
let client = {
let client = query.get("c").map(ToString::to_string);
if client.is_none() {
return Err(Error::from_string(
"Missing client",
StatusCode::BAD_REQUEST,
));
}
client.expect("Missing client")
};
let format = query
.get("f")
.map_or_else(|| "xml".to_string(), ToString::to_string);
Ok(Self {
username: user,
token,
salt,
version,
client,
format,
})
}
}
#[derive(Debug, Clone)]
pub struct Authentication {
pub username: String,
pub token: String,
pub salt: String,
pub version: VersionTriple,
pub client: String,
pub format: String,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VersionTriple(pub u32, pub u32, pub u32);
impl PartialOrd for VersionTriple {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for VersionTriple {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0
.cmp(&other.0)
.then(self.1.cmp(&other.1))
.then(self.2.cmp(&other.2))
}
}
impl FromStr for VersionTriple {
type Err = Report;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let parts = s.split('.').collect::<Vec<_>>();
if parts.len() != 3 {
return Err(Report::msg("Invalid version string"));
}
let major = parts[0].parse::<u32>()?;
let minor = parts[1].parse::<u32>()?;
let patch = parts[2].parse::<u32>()?;
Ok(Self(major, minor, patch))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_version_triple_from_str() {
assert_eq!(
VersionTriple::from_str("1.2.3").expect("Failed to parse version triple"),
VersionTriple(1, 2, 3)
);
}
#[test]
fn parse_version_triple_from_str_with_invalid_string() {
assert!(VersionTriple::from_str("1.2").is_err());
assert!(VersionTriple::from_str("").is_err());
}
#[test]
fn parse_version_triple_from_str_with_invalid_number() {
assert!(VersionTriple::from_str("1.2.a").is_err());
}
#[test]
fn serialize_version_triple_to_string() {
assert_eq!(
serde_json::to_string(&VersionTriple(1, 2, 3))
.expect("Failed to serialize version triple"),
"\"1.2.3\""
);
}
#[test]
fn deserialize_version_triple_from_string() {
assert_eq!(
serde_json::from_str::<VersionTriple>("\"1.2.3\"")
.expect("Failed to deserialize version triple"),
VersionTriple(1, 2, 3)
);
}
}

57
src/authentication/de.rs Normal file
View file

@ -0,0 +1,57 @@
use std::str::FromStr;
use serde::{
de::{Deserialize, Visitor},
Deserializer, Serialize,
};
use super::VersionTriple;
impl<'de> Deserialize<'de> for VersionTriple {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(VersionTripleVisitor)
}
}
struct VersionTripleVisitor;
impl Visitor<'_> for VersionTripleVisitor {
type Value = VersionTriple;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a version string `major.minor.patch`")
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
VersionTriple::from_str(&v).map_err(serde::de::Error::custom)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
VersionTriple::from_str(v).map_err(serde::de::Error::custom)
}
fn visit_borrowed_str<E>(self, v: &'_ str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
VersionTriple::from_str(v).map_err(serde::de::Error::custom)
}
}
impl Serialize for VersionTriple {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
format!("{}.{}.{}", self.0, self.1, self.2).serialize(serializer)
}
}

View file

@ -1,3 +1,80 @@
fn main() { #![warn(clippy::pedantic, clippy::nursery)]
println!("Hello, world!"); #![deny(clippy::unwrap_used, clippy::panic)]
use std::time::Duration;
use authentication::Authentication;
use color_eyre::Result;
use poem::{
get,
listener::TcpListener,
middleware,
web::{CompressionAlgo, CompressionLevel},
EndpointExt, Route,
};
use tracing::info;
use tracing_subscriber::{fmt, EnvFilter};
mod authentication;
const LISTEN: &str = "0.0.0.0:1234";
#[tokio::main]
async fn main() -> Result<()> {
color_eyre::install()?;
install_tracing()?;
let listener = TcpListener::bind(LISTEN);
let route = Route::new()
.at("/", get(hello_world))
.at("/auth", get(auth_test))
.with(middleware::CatchPanic::new())
.with(
middleware::Compression::new()
.algorithms([CompressionAlgo::BR, CompressionAlgo::GZIP])
.with_quality(CompressionLevel::Default),
)
.with(middleware::Tracing)
.with(middleware::CookieJarManager::new())
.with(middleware::NormalizePath::new(
middleware::TrailingSlash::Trim,
));
info!("Listening on http://{LISTEN}");
let signal_waiter = || async {
let _ = tokio::signal::ctrl_c().await;
};
let server = poem::Server::new(listener)
.name("rave")
.run_with_graceful_shutdown(route, signal_waiter(), Some(Duration::from_secs(5)));
server.await?;
Ok(())
}
fn install_tracing() -> Result<()> {
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "warn,rave=debug".to_string());
fmt()
.pretty()
.with_env_filter(EnvFilter::from(filter))
.try_init()
.map_err(|v| color_eyre::eyre::eyre!("failed to install tracing: {v}"))?;
Ok(())
}
#[poem::handler]
const fn hello_world() -> &'static str {
"Hello, world!"
}
#[allow(clippy::needless_pass_by_value)]
#[poem::handler]
fn auth_test(auth: Authentication) -> String {
format!("{auth:?}")
} }