diff --git a/src/app/app_tests.rs b/src/app/app_tests.rs index f5a4ff5..6be41e6 100644 --- a/src/app/app_tests.rs +++ b/src/app/app_tests.rs @@ -4,10 +4,14 @@ mod tests { use anyhow::anyhow; use pretty_assertions::{assert_eq, assert_str_eq}; use rstest::rstest; + use serde::de::value::StringDeserializer; + use serde::de::IntoDeserializer; use tokio::sync::mpsc; use crate::app::context_clues::{build_context_clue_string, SERVARR_CONTEXT_CLUES}; - use crate::app::{App, AppConfig, Data, ServarrConfig}; + use crate::app::{ + deserialize_env_var, interpolate_env_vars, App, AppConfig, Data, ServarrConfig, + }; use crate::models::servarr_data::radarr::radarr_data::{ActiveRadarrBlock, RadarrData}; use crate::models::servarr_data::sonarr::sonarr_data::{ActiveSonarrBlock, SonarrData}; use crate::models::{HorizontallyScrollableText, TabRoute}; @@ -348,6 +352,155 @@ mod tests { assert_eq!(servarr_config.ssl_cert_path, None); } + #[test] + fn test_deserialize_env_var() { + std::env::set_var("TEST_VAR_DESERIALIZE", "testing"); + let deserializer: StringDeserializer = + "${TEST_VAR_DESERIALIZE}".to_owned().into_deserializer(); + + let env_var: Result = deserialize_env_var(deserializer); + + assert!(env_var.is_ok()); + assert_str_eq!(env_var.unwrap(), "testing"); + std::env::remove_var("TEST_VAR_DESERIALIZE"); + } + + #[test] + fn test_deserialize_optional_env_var_is_present() { + std::env::set_var("TEST_VAR_DESERIALIZE_OPTION", "localhost"); + let yaml_data = r#" + host: ${TEST_VAR_DESERIALIZE_OPTION} + api_token: "test123" + "#; + + let config: ServarrConfig = serde_yaml::from_str(yaml_data).unwrap(); + + assert_eq!(config.host, Some("localhost".to_string())); + std::env::remove_var("TEST_VAR_DESERIALIZE_OPTION"); + } + + #[test] + fn test_deserialize_optional_env_var_does_not_overwrite_non_env_value() { + std::env::set_var("TEST_VAR_DESERIALIZE_OPTION_NO_OVERWRITE", "localhost"); + let yaml_data = r#" + host: www.example.com + api_token: "test123" + "#; + + let config: ServarrConfig = serde_yaml::from_str(yaml_data).unwrap(); + + assert_eq!(config.host, Some("www.example.com".to_string())); + std::env::remove_var("TEST_VAR_DESERIALIZE_OPTION_NO_OVERWRITE"); + } + + #[test] + fn test_deserialize_optional_env_var_empty() { + let yaml_data = r#" + api_token: "test123" + "#; + + let config: ServarrConfig = serde_yaml::from_str(yaml_data).unwrap(); + + assert_eq!(config.port, None); + } + + #[test] + fn test_deserialize_optional_u16_env_var_is_present() { + std::env::set_var("TEST_VAR_DESERIALIZE_OPTION_U16", "1"); + let yaml_data = r#" + port: ${TEST_VAR_DESERIALIZE_OPTION_U16} + api_token: "test123" + "#; + + let config: ServarrConfig = serde_yaml::from_str(yaml_data).unwrap(); + + assert_eq!(config.port, Some(1)); + std::env::remove_var("TEST_VAR_DESERIALIZE_OPTION_U16"); + } + + #[test] + fn test_deserialize_optional_u16_env_var_does_not_overwrite_non_env_value() { + std::env::set_var("TEST_VAR_DESERIALIZE_OPTION_U16_UNUSED", "1"); + let yaml_data = r#" + port: 1234 + api_token: "test123" + "#; + + let config: ServarrConfig = serde_yaml::from_str(yaml_data).unwrap(); + + assert_eq!(config.port, Some(1234)); + std::env::remove_var("TEST_VAR_DESERIALIZE_OPTION_U16_UNUSED"); + } + + #[test] + fn test_deserialize_optional_u16_env_var_invalid_number() { + let yaml_data = r#" + port: "hi" + api_token: "test123" + "#; + let result: Result = serde_yaml::from_str(yaml_data); + + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("invalid digit found in string")); + } + + #[test] + fn test_deserialize_optional_u16_env_var_empty() { + let yaml_data = r#" + api_token: "test123" + "#; + + let config: ServarrConfig = serde_yaml::from_str(yaml_data).unwrap(); + + assert_eq!(config.port, None); + } + + #[test] + fn test_interpolate_env_vars() { + std::env::set_var("TEST_VAR_INTERPOLATION", "testing"); + + let var = interpolate_env_vars("${TEST_VAR_INTERPOLATION}"); + + assert_str_eq!(var, "testing"); + std::env::remove_var("TEST_VAR_INTERPOLATION"); + } + + #[test] + fn test_interpolate_env_vars_defaults_to_original_string_if_not_in_yaml_interpolation_format() { + let var = interpolate_env_vars("TEST_VAR_INTERPOLATION_NON_YAML"); + + assert_str_eq!(var, "TEST_VAR_INTERPOLATION_NON_YAML"); + } + + #[test] + fn test_interpolate_env_vars_scrubs_all_unnecessary_characters() { + std::env::set_var( + "TEST_VAR_INTERPOLATION_UNNECESSARY_CHARACTERS", + r#""" + `"'https://dontdo:this@testing.com/query?test=%20query#results'"` {([\|$!])} + """#, + ); + + let var = interpolate_env_vars("${TEST_VAR_INTERPOLATION_UNNECESSARY_CHARACTERS}"); + + assert_str_eq!( + var, + "https://dontdo:this@testing.com/query?test=%20query#results" + ); + std::env::remove_var("TEST_VAR_INTERPOLATION_UNNECESSARY_CHARACTERS"); + } + + #[test] + fn test_interpolate_env_vars_scrubs_all_unnecessary_characters_from_non_environment_variable() { + let var = interpolate_env_vars("https://dontdo:this@testing.com/query?test=%20query#results"); + + assert_str_eq!( + var, + "https://dontdo:this@testing.com/query?test=%20query#results" + ); + } + #[test] fn test_servarr_config_redacted_debug() { let host = "localhost".to_owned(); diff --git a/src/app/mod.rs b/src/app/mod.rs index bb33a21..744b5be 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -3,6 +3,7 @@ use std::process; use anyhow::{anyhow, Error}; use colored::Colorize; use log::{debug, error}; +use regex::Regex; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::Sender; use tokio_util::sync::CancellationToken; @@ -261,11 +262,16 @@ impl AppConfig { #[derive(Redact, Deserialize, Serialize, Clone)] pub struct ServarrConfig { + #[serde(default, deserialize_with = "deserialize_optional_env_var")] pub host: Option, + #[serde(default, deserialize_with = "deserialize_u16_env_var")] pub port: Option, + #[serde(default, deserialize_with = "deserialize_optional_env_var")] pub uri: Option, + #[serde(default, deserialize_with = "deserialize_env_var")] #[redact] pub api_token: String, + #[serde(default, deserialize_with = "deserialize_optional_env_var")] pub ssl_cert_path: Option, } @@ -294,3 +300,61 @@ pub fn log_and_print_error(error: String) { error!("{}", error); eprintln!("error: {}", error.red()); } + +fn deserialize_env_var<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let s: String = String::deserialize(deserializer)?; + let interpolated = interpolate_env_vars(&s); + Ok(interpolated) +} + +fn deserialize_optional_env_var<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let s: Option = Option::deserialize(deserializer)?; + match s { + Some(value) => { + let interpolated = interpolate_env_vars(&value); + Ok(Some(interpolated)) + } + None => Ok(None), + } +} + +fn deserialize_u16_env_var<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let s: Option = Option::deserialize(deserializer)?; + match s { + Some(value) => { + let interpolated = interpolate_env_vars(&value); + interpolated + .parse::() + .map(Some) + .map_err(serde::de::Error::custom) + } + None => Ok(None), + } +} + +fn interpolate_env_vars(s: &str) -> String { + let result = s.to_string(); + let scrubbing_regex = Regex::new(r#"[\s\{\}!\$^\(\)\[\]\\\|`'"]+"#).unwrap(); + let var_regex = Regex::new(r"\$\{(.*?)\}").unwrap(); + + var_regex + .replace_all(s, |caps: ®ex::Captures<'_>| { + if let Some(mat) = caps.get(1) { + if let Ok(value) = std::env::var(mat.as_str()) { + return scrubbing_regex.replace_all(&value, "").to_string(); + } + } + + scrubbing_regex.replace_all(&result, "").to_string() + }) + .to_string() +}