rave/app/src/authentication.rs

234 lines
6.9 KiB
Rust

use std::{collections::HashMap, fmt::Display, str::FromStr, string::ToString};
use color_eyre::Report;
use poem::{Error, FromRequest, IntoResponse, Request, RequestBody, Result};
use tracing::trace;
use crate::subsonic::{self, SubsonicResponse};
mod de;
#[poem::async_trait]
impl<'a> FromRequest<'a> for Authentication {
async fn from_request(req: &'a Request, _: &mut RequestBody) -> Result<Self> {
let query = req.uri().query().unwrap_or_default();
if query.is_empty() {
return Err(Error::from_response(
SubsonicResponse::new_error(subsonic::Error::RequiredParameterMissing(Some(
"please provide a `u` parameter".to_string(),
)))
.into_response(),
));
}
let query = url_escape::decode(query);
let query = query
.split('&')
.filter_map(|q| q.split_once('='))
.collect::<HashMap<_, _>>();
trace!("Query: {query:?}");
let user = {
let user = query.get("u").map(ToString::to_string);
if user.is_none() {
return Err(Error::from_response(
SubsonicResponse::new_error(subsonic::Error::RequiredParameterMissing(Some(
"please provide a `u` parameter".to_string(),
)))
.into_response(),
));
}
user.expect("Missing username")
};
trace!("User: {user}");
let password = query.get("p").map(ToString::to_string);
if password.is_some() {
return Err(Error::from_response(
SubsonicResponse::new_error(subsonic::Error::Generic(Some(
"password authentication is not supported".to_string(),
)))
.into_response(),
));
}
trace!("Password: {password:?}");
let token = query.get("t").map(ToString::to_string);
let salt = query.get("s").map(ToString::to_string);
if token.is_none() || salt.is_none() {
return Err(Error::from_response(
SubsonicResponse::new_error(subsonic::Error::RequiredParameterMissing(Some(
"please provide both `t` and `s` parameters".to_string(),
)))
.into_response(),
));
}
let token = token.expect("Missing token");
trace!("Token: {token}");
let salt = salt.expect("Missing salt");
trace!("Salt: {salt}");
let version = {
let version = query.get("v").map(ToString::to_string);
if version.is_none() {
return Err(Error::from_response(
SubsonicResponse::new_error(subsonic::Error::RequiredParameterMissing(Some(
"please provide a `v` parameter".to_string(),
)))
.into_response(),
));
}
version
.expect("Missing version")
.parse::<VersionTriple>()
.map_err(|e| {
Error::from_response(
SubsonicResponse::new_error(subsonic::Error::Generic(Some(format!(
"invalid version parameter: {e}"
))))
.into_response(),
)
})
}?;
trace!("Version: {version}");
let client = {
let client = query.get("c").map(ToString::to_string);
if client.is_none() {
return Err(Error::from_response(
SubsonicResponse::new_error(subsonic::Error::RequiredParameterMissing(Some(
"please provide a `c` parameter".to_string(),
)))
.into_response(),
));
}
client.expect("Missing client")
};
trace!("Client: {client}");
let format = query
.get("f")
.map_or_else(|| "xml".to_string(), ToString::to_string);
if format != "xml" {
return Err(Error::from_response(
SubsonicResponse::new_error(subsonic::Error::Generic(Some(
"only xml format is supported".to_string(),
)))
.into_response(),
));
}
trace!("Format: {format}");
Ok(Self {
username: user,
token,
salt,
version,
client,
format,
})
}
}
/// Credentials and client metadata taken from the query string of an
/// incoming Subsonic request.
///
/// Only token-based authentication is represented: there is no password
/// field.
#[derive(Clone, Debug)]
pub struct Authentication {
    /// Decoded value of the `u` query parameter.
    pub username: String,
    /// Decoded value of the `t` query parameter.
    pub token: String,
    /// Decoded value of the `s` query parameter.
    pub salt: String,
    /// The `v` query parameter, parsed as `major.minor.patch`.
    pub version: VersionTriple,
    /// Decoded value of the `c` query parameter.
    pub client: String,
    /// Decoded value of the `f` query parameter (`xml` when it is absent).
    pub format: String,
}
/// A `major.minor.patch` protocol version, e.g. the Subsonic `v` parameter.
///
/// Ordering is lexicographic over (major, minor, patch), i.e. the derived
/// ordering on the tuple fields in declaration order, so `1.16.0 > 1.2.9`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct VersionTriple(pub u32, pub u32, pub u32);

/// Formats the version as `major.minor.patch`, e.g. `1.16.1`.
impl Display for VersionTriple {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(major, minor, patch) = self;
        write!(f, "{major}.{minor}.{patch}")
    }
}
/// Parses a dotted `major.minor.patch` string such as `"1.16.1"`.
///
/// # Errors
///
/// Fails with `"Invalid version string"` unless there are exactly three
/// `.`-separated components. Otherwise fails with the `u32` parse error of
/// the first component (left to right) that is not a valid `u32`.
impl FromStr for VersionTriple {
    type Err = Report;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut components = s.split('.');
        // Pull one component beyond the expected three: a fourth `Some` means
        // the input had too many parts.
        match (
            components.next(),
            components.next(),
            components.next(),
            components.next(),
        ) {
            (Some(major), Some(minor), Some(patch), None) => Ok(Self(
                major.parse::<u32>()?,
                minor.parse::<u32>()?,
                patch.parse::<u32>()?,
            )),
            _ => Err(Report::msg("Invalid version string")),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed triple parses into its three numeric components.
    #[test]
    fn parse_version_triple_from_str() {
        let parsed = VersionTriple::from_str("1.2.3").expect("Failed to parse version triple");
        assert_eq!(parsed, VersionTriple(1, 2, 3));
    }

    /// Inputs without exactly three components are rejected.
    #[test]
    fn parse_version_triple_from_str_with_invalid_string() {
        for input in ["1.2", ""] {
            assert!(VersionTriple::from_str(input).is_err());
        }
    }

    /// A non-numeric component is rejected.
    #[test]
    fn parse_version_triple_from_str_with_invalid_number() {
        let outcome = VersionTriple::from_str("1.2.a");
        assert!(outcome.is_err());
    }

    /// Serialization produces the dotted string form.
    #[test]
    fn serialize_version_triple_to_string() {
        let json = serde_json::to_string(&VersionTriple(1, 2, 3))
            .expect("Failed to serialize version triple");
        assert_eq!(json, "\"1.2.3\"");
    }

    /// Deserialization accepts the dotted string form.
    #[test]
    fn deserialize_version_triple_from_string() {
        let decoded: VersionTriple = serde_json::from_str("\"1.2.3\"")
            .expect("Failed to deserialize version triple");
        assert_eq!(decoded, VersionTriple(1, 2, 3));
    }
}