1use http::{HeaderValue, Method};
2use tower_http::cors::{AllowOrigin, Any, CorsLayer};
3
4#[derive(Debug, thiserror::Error)]
6pub enum CorsDomainError {
7 #[error("{domain} is an invalid header value")]
9 InvalidHeader {
10 domain: String,
12 },
13
14 #[error("wildcard origin (`*`) cannot be passed as part of a list: {input}")]
16 WildCardNotAllowed {
17 input: String,
19 },
20}
21
22pub(crate) fn create_cors_layer(http_cors_domains: &str) -> Result<CorsLayer, CorsDomainError> {
24 let cors = match http_cors_domains.trim() {
25 "*" => CorsLayer::new()
26 .allow_methods([Method::GET, Method::POST])
27 .allow_origin(Any)
28 .allow_headers(Any),
29 _ => {
30 let iter = http_cors_domains.split(',').map(str::trim);
31 if iter.clone().any(|o| o == "*") {
32 return Err(CorsDomainError::WildCardNotAllowed {
33 input: http_cors_domains.to_string(),
34 })
35 }
36
37 let origins = iter
38 .map(|domain| {
39 domain
40 .parse::<HeaderValue>()
41 .map_err(|_| CorsDomainError::InvalidHeader { domain: domain.to_string() })
42 })
43 .collect::<Result<Vec<HeaderValue>, _>>()?;
44
45 let origin = AllowOrigin::list(origins);
46 CorsLayer::new()
47 .allow_methods([Method::GET, Method::POST])
48 .allow_origin(origin)
49 .allow_headers(Any)
50 }
51 };
52 Ok(cors)
53}
54
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wildcard_with_spaces_rejected() {
        let result = create_cors_layer("http://example.com, *");
        assert!(matches!(
            result,
            Err(CorsDomainError::WildCardNotAllowed { input }) if input == "http://example.com, *"
        ));
    }

    #[test]
    fn test_lone_wildcard_accepted() {
        assert!(create_cors_layer("*").is_ok());
        assert!(create_cors_layer("  *  ").is_ok());
    }

    #[test]
    fn test_origin_list_accepted() {
        assert!(create_cors_layer("http://example.com, https://example.org").is_ok());
        assert!(create_cors_layer("http://example.com").is_ok());
    }

    #[test]
    fn test_trailing_comma_accepted() {
        assert!(create_cors_layer("http://example.com,").is_ok());
    }

    #[test]
    fn test_invalid_origin_reported() {
        // U+0001 is a control character, which is not a valid header value byte.
        let result = create_cors_layer("http://example.com, bad\u{1}origin");
        assert!(matches!(
            result,
            Err(CorsDomainError::InvalidHeader { domain }) if domain == "bad\u{1}origin"
        ));
    }

    #[test]
    fn test_wildcard_reported_before_invalid_origin() {
        let result = create_cors_layer("bad\u{1}origin, *");
        assert!(matches!(result, Err(CorsDomainError::WildCardNotAllowed { .. })));
    }
}