diff --git a/README.md b/README.md index 1b39e74..6716b09 100644 --- a/README.md +++ b/README.md @@ -357,6 +357,9 @@ whisparr: port: 6969 api_token: someApiToken1234567890 ssl_cert_path: /path/to/whisparr.crt + custom_headers: # Example of adding custom headers to all requests to the Servarr instance + traefik-auth-bypass-key: someBypassKey1234567890 + SOME-OTHER-CUSTOM-HEADER: ${MY_CUSTOM_HEADER_VALUE} bazarr: - host: 192.168.0.67 port: 6767 diff --git a/src/app/app_tests.rs b/src/app/app_tests.rs index 99697b8..bda0636 100644 --- a/src/app/app_tests.rs +++ b/src/app/app_tests.rs @@ -2,6 +2,8 @@ mod tests { use anyhow::anyhow; use pretty_assertions::{assert_eq, assert_str_eq}; + use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; + use serde_json::Value; use serial_test::serial; use tokio::sync::mpsc; @@ -340,6 +342,43 @@ mod tests { assert_eq!(servarr_config.api_token, Some(String::new())); assert_eq!(servarr_config.api_token_file, None); assert_eq!(servarr_config.ssl_cert_path, None); + assert_eq!(servarr_config.custom_headers, None); + } + + #[test] + fn serialize_header_map_basic() { + let mut header_map = HeaderMap::new(); + header_map.insert( + HeaderName::from_static("x-api-key"), + HeaderValue::from_static("abc123"), + ); + header_map.insert( + HeaderName::from_static("header-1"), + HeaderValue::from_static("test"), + ); + + let config = ServarrConfig { + custom_headers: Some(header_map), + ..ServarrConfig::default() + }; + + let v: Value = serde_json::to_value(&config).expect("serialize ok"); + let custom = v.get("custom_headers").unwrap(); + assert!(custom.is_object()); + let obj = custom.as_object().unwrap(); + + assert_eq!(obj.get("x-api-key").unwrap(), "abc123"); + assert_eq!(obj.get("header-1").unwrap(), "test"); + + assert!(obj.get("X-Api-Key").is_none()); + assert!(obj.get("HEADER-1").is_none()); + } + + #[test] + fn serialize_header_map_none_is_null() { + let config = ServarrConfig::default(); + let v: Value = serde_json::to_value(&config).expect("serialize ok"); + assert!(v.get("custom_headers").unwrap().is_null()); } #[test] @@ -383,6 +422,66 @@ mod tests { assert_eq!(config.port, None); } + #[test] + #[serial] + fn test_deserialize_optional_env_var_header_map_is_present() { + unsafe { std::env::set_var("TEST_VAR_DESERIALIZE_HEADER_OPTION", "localhost") }; + let expected_custom_headers = { + let mut headers = HeaderMap::new(); + headers.insert("X-Api-Host", "localhost".parse().unwrap()); + headers.insert("api-token", "test123".parse().unwrap()); + headers + }; + let yaml_data = r#" + custom_headers: + X-Api-Host: ${TEST_VAR_DESERIALIZE_HEADER_OPTION} + api-token: "test123" + "#; + + let config: ServarrConfig = serde_yaml::from_str(yaml_data).unwrap(); + + assert_eq!(config.custom_headers, Some(expected_custom_headers)); + unsafe { std::env::remove_var("TEST_VAR_DESERIALIZE_HEADER_OPTION") }; + } + + #[test] + #[serial] + fn test_deserialize_optional_env_var_header_map_does_not_overwrite_non_env_value() { + unsafe { + std::env::set_var( + "TEST_VAR_DESERIALIZE_HEADER_OPTION_NO_OVERWRITE", + "localhost", + ) + }; + let expected_custom_headers = { + let mut headers = HeaderMap::new(); + headers.insert("X-Api-Host", "www.example.com".parse().unwrap()); + headers.insert("api-token", "test123".parse().unwrap()); + headers + }; + let yaml_data = r#" + custom_headers: + X-Api-Host: www.example.com + api-token: "test123" + "#; + + let config: ServarrConfig = serde_yaml::from_str(yaml_data).unwrap(); + + assert_eq!(config.custom_headers, Some(expected_custom_headers)); + unsafe { std::env::remove_var("TEST_VAR_DESERIALIZE_HEADER_OPTION_NO_OVERWRITE") }; + } + + #[test] + fn test_deserialize_optional_env_var_header_map_empty() { + let yaml_data = r#" + api_token: "test123" + "#; + + let config: ServarrConfig = serde_yaml::from_str(yaml_data).unwrap(); + + assert_eq!(config.custom_headers, None); + } + #[test] #[serial] fn test_deserialize_optional_u16_env_var_is_present() { @@ -496,7 +595,9 @@ mod tests { let api_token = "thisisatest".to_owned(); let api_token_file = "/root/.config/api_token".to_owned(); let ssl_cert_path = "/some/path".to_owned(); - let expected_str = format!("ServarrConfig {{ name: Some(\"{name}\"), host: Some(\"{host}\"), port: Some({port}), uri: Some(\"{uri}\"), weight: Some({weight}), api_token: Some(\"***********\"), api_token_file: Some(\"{api_token_file}\"), ssl_cert_path: Some(\"{ssl_cert_path}\") }}"); + let mut custom_headers = HeaderMap::new(); + custom_headers.insert("X-Custom-Header", "value".parse().unwrap()); + let expected_str = format!("ServarrConfig {{ name: Some(\"{name}\"), host: Some(\"{host}\"), port: Some({port}), uri: Some(\"{uri}\"), weight: Some({weight}), api_token: Some(\"***********\"), api_token_file: Some(\"{api_token_file}\"), ssl_cert_path: Some(\"{ssl_cert_path}\"), custom_headers: Some({{\"x-custom-header\": \"value\"}}) }}"); let servarr_config = ServarrConfig { name: Some(name), host: Some(host), @@ -506,6 +607,7 @@ mod tests { api_token: Some(api_token), api_token_file: Some(api_token_file), ssl_cert_path: Some(ssl_cert_path), + custom_headers: Some(custom_headers), }; assert_str_eq!(format!("{servarr_config:?}"), expected_str); diff --git a/src/app/mod.rs b/src/app/mod.rs index 5707f76..7ff3a48 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -3,7 +3,9 @@ use colored::Colorize; use itertools::Itertools; use log::{debug, error}; use regex::Regex; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::path::PathBuf; use std::{fs, process}; use tokio::sync::mpsc::Sender; @@ -335,6 +337,12 @@ pub struct ServarrConfig { pub api_token_file: Option, #[serde(default, deserialize_with = "deserialize_optional_env_var")] pub ssl_cert_path: Option, + #[serde( + default, + deserialize_with = "deserialize_optional_env_var_header_map", + serialize_with = "serialize_header_map" + )] + pub custom_headers: Option, } impl ServarrConfig { @@ -380,6 +388,7 @@ impl Default for ServarrConfig { api_token: Some(String::new()), api_token_file: None, ssl_cert_path: None, + custom_headers: None, } } } @@ -389,6 +398,27 @@ pub fn log_and_print_error(error: String) { eprintln!("error: {}", error.red()); } +fn serialize_header_map(headers: &Option, serializer: S) -> Result +where + S: serde::Serializer, +{ + if let Some(headers) = headers { + let mut map = HashMap::new(); + for (name, value) in headers.iter() { + let name_str = name.as_str().to_string(); + let value_str = value + .to_str() + .map_err(serde::ser::Error::custom)? + .to_string(); + + map.insert(name_str, value_str); + } + map.serialize(serializer) + } else { + serializer.serialize_none() + } +} + fn deserialize_optional_env_var<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, @@ -403,6 +433,28 @@ where } } +fn deserialize_optional_env_var_header_map<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let opt: Option> = Option::deserialize(deserializer)?; + match opt { + Some(map) => { + let mut header_map = HeaderMap::new(); + for (k, v) in map.iter() { + let name = HeaderName::from_bytes(k.as_bytes()).map_err(serde::de::Error::custom)?; + let value_str = interpolate_env_vars(v); + let value = HeaderValue::from_str(&value_str).map_err(serde::de::Error::custom)?; + header_map.insert(name, value); + } + Ok(Some(header_map)) + } + None => Ok(None), + } +} + fn deserialize_u16_env_var<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, diff --git a/src/network/mod.rs b/src/network/mod.rs index 164a976..8f28b1f 100644 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -19,6 +19,7 @@ use crate::models::Serdeable; use crate::network::radarr_network::RadarrEvent; #[cfg(test)] use mockall::automock; +use reqwest::header::HeaderMap; pub mod radarr_network; pub mod sonarr_network; @@ -167,28 +168,36 @@ impl<'a, 'b> Network<'a, 'b> { method, body, api_token, + custom_headers, .. } = request_props; debug!("Creating RequestBuilder for resource: {uri:?}"); debug!("Sending {method:?} request to {uri} with body {body:?}"); match method { - RequestMethod::Get => self.client.get(uri).header("X-Api-Key", api_token), + RequestMethod::Get => self + .client + .get(uri) + .header("X-Api-Key", api_token) + .headers(custom_headers), RequestMethod::Post => self .client .post(uri) .json(&body.unwrap_or_default()) - .header("X-Api-Key", api_token), + .header("X-Api-Key", api_token) + .headers(custom_headers), RequestMethod::Put => self .client .put(uri) .json(&body.unwrap_or_default()) - .header("X-Api-Key", api_token), + .header("X-Api-Key", api_token) + .headers(custom_headers), RequestMethod::Delete => self .client .delete(uri) .json(&body.unwrap_or_default()) - .header("X-Api-Key", api_token), + .header("X-Api-Key", api_token) + .headers(custom_headers), } } @@ -212,6 +221,7 @@ impl<'a, 'b> Network<'a, 'b> { uri, api_token, ssl_cert_path, + custom_headers: custom_headers_option, .. } = app .server_tabs @@ -245,12 +255,15 @@ impl<'a, 'b> Network<'a, 'b> { uri = format!("{uri}?{params}"); } + let custom_headers = custom_headers_option.clone().unwrap_or_default(); + RequestProps { uri, method, body, api_token: api_token.as_ref().expect("API token not found").clone(), ignore_status_code: false, + custom_headers, } } } @@ -270,4 +283,5 @@ pub struct RequestProps { pub body: Option, pub api_token: String, pub ignore_status_code: bool, + pub custom_headers: HeaderMap, } diff --git a/src/network/network_tests.rs b/src/network/network_tests.rs index 58cef91..6329fb1 100644 --- a/src/network/network_tests.rs +++ b/src/network/network_tests.rs @@ -6,6 +6,7 @@ mod tests { use mockito::{Mock, Server, ServerGuard}; use pretty_assertions::assert_str_eq; + use reqwest::header::HeaderMap; use reqwest::Client; use rstest::rstest; use serde::{Deserialize, Serialize}; @@ -81,6 +82,7 @@ mod tests { }), api_token: "test1234".to_owned(), ignore_status_code: false, + custom_headers: HeaderMap::new(), }, |_, _| (), ) @@ -105,6 +107,7 @@ mod tests { body: None, api_token: "test1234".to_owned(), ignore_status_code: false, + custom_headers: HeaderMap::new(), }, |response, mut app| app.error = HorizontallyScrollableText::from(response.value), ) @@ -138,6 +141,7 @@ mod tests { body: None, api_token: "test1234".to_owned(), ignore_status_code: true, + custom_headers: HeaderMap::new(), }, |response, _app| test_result = response.value, ) @@ -176,6 +180,7 @@ mod tests { body: None, api_token: "test1234".to_owned(), ignore_status_code: false, + custom_headers: HeaderMap::new(), }, |_, _| (), ) @@ -229,6 +234,7 @@ mod tests { body: None, api_token: "test1234".to_owned(), ignore_status_code: false, + custom_headers: HeaderMap::new(), }, |response, mut app| app.error = HorizontallyScrollableText::from(response.value), ) @@ -261,6 +267,7 @@ mod tests { body: None, api_token: "test1234".to_owned(), ignore_status_code: false, + custom_headers: HeaderMap::new(), }, |response, mut app| app.error = HorizontallyScrollableText::from(response.value), ) @@ -301,6 +308,7 @@ mod tests { body: None, api_token: "test1234".to_owned(), ignore_status_code: false, + custom_headers: HeaderMap::new(), }, |response, mut app| app.error = HorizontallyScrollableText::from(response.value), ) @@ -331,6 +339,7 @@ mod tests { body: None, api_token: "test1234".to_owned(), ignore_status_code: false, + custom_headers: HeaderMap::new(), }, |response, mut app| app.error = HorizontallyScrollableText::from(response.value), ) @@ -363,8 +372,11 @@ mod tests { let mut async_server = server .mock(&request_method.to_string().to_uppercase(), "/test") .match_header("X-Api-Key", "test1234") + .match_header("X-Custom-Header", "CustomValue") .with_status(200); let mut body = None::; + let mut custom_headers = HeaderMap::new(); + custom_headers.insert("X-Custom-Header", "CustomValue".parse().unwrap()); if request_method == RequestMethod::Post { async_server = async_server.with_body( @@ -388,6 +400,7 @@ mod tests { body, api_token: "test1234".to_owned(), ignore_status_code: false, + custom_headers, }) .await .send() @@ -440,6 +453,7 @@ mod tests { assert_eq!(request_props.method, RequestMethod::Get); assert_eq!(request_props.body, None); assert!(request_props.api_token.is_empty()); + assert!(request_props.custom_headers.is_empty()); } #[rstest] @@ -476,6 +490,47 @@ mod tests { assert_eq!(request_props.method, RequestMethod::Get); assert_eq!(request_props.body, None); assert_str_eq!(request_props.api_token, api_token); + assert!(request_props.custom_headers.is_empty()); + } + + #[rstest] + #[tokio::test] + async fn test_request_props_from_custom_config_custom_headers( + #[values(RadarrEvent::GetMovies, SonarrEvent::ListSeries)] network_event: impl Into + + NetworkResource, + ) { + let api_token = "testToken1234".to_owned(); + let app_arc = Arc::new(Mutex::new(App::test_default())); + let resource = network_event.resource(); + let mut header_map = HeaderMap::new(); + header_map.insert("X-Custom-Header", "CustomValue".parse().unwrap()); + let servarr_config = ServarrConfig { + host: Some("192.168.0.123".to_owned()), + port: Some(8080), + api_token: Some(api_token.clone()), + ssl_cert_path: Some("/test/cert.crt".to_owned()), + custom_headers: Some(header_map.clone()), + ..ServarrConfig::default() + }; + { + let mut app = app_arc.lock().await; + app.server_tabs.tabs[0].config = Some(servarr_config.clone()); + app.server_tabs.tabs[1].config = Some(servarr_config); + } + let network = Network::new(&app_arc, CancellationToken::new(), Client::new()); + + let request_props = network + .request_props_from(network_event, RequestMethod::Get, None::<()>, None, None) + .await; + + assert_str_eq!( + request_props.uri, + format!("https://192.168.0.123:8080/api/v3{resource}") + ); + assert_eq!(request_props.method, RequestMethod::Get); + assert_eq!(request_props.body, None); + assert_str_eq!(request_props.api_token, api_token); + assert_eq!(request_props.custom_headers, header_map); } #[rstest] @@ -510,6 +565,7 @@ mod tests { assert_eq!(request_props.method, RequestMethod::Get); assert_eq!(request_props.body, None); assert_str_eq!(request_props.api_token, api_token); + assert!(request_props.custom_headers.is_empty()); } #[rstest] @@ -546,6 +602,7 @@ mod tests { assert_eq!(request_props.method, RequestMethod::Get); assert_eq!(request_props.body, None); assert!(request_props.api_token.is_empty()); + assert!(request_props.custom_headers.is_empty()); } #[rstest] @@ -588,6 +645,7 @@ mod tests { assert_eq!(request_props.method, RequestMethod::Get); assert_eq!(request_props.body, None); assert_str_eq!(request_props.api_token, api_token); + assert!(request_props.custom_headers.is_empty()); } #[rstest] @@ -628,6 +686,7 @@ mod tests { assert_eq!(request_props.method, RequestMethod::Get); assert_eq!(request_props.body, None); assert_str_eq!(request_props.api_token, api_token); + assert!(request_props.custom_headers.is_empty()); } #[derive(Clone, Serialize, Deserialize, Debug, Default, PartialEq, Eq)]