diff --git a/src/app/app_tests.rs b/src/app/app_tests.rs index 59cfb8e..e7a5084 100644 --- a/src/app/app_tests.rs +++ b/src/app/app_tests.rs @@ -2,12 +2,16 @@ mod tests { use crate::models::Route; use anyhow::anyhow; - use pretty_assertions::assert_eq; + use pretty_assertions::{assert_eq, assert_str_eq}; use rstest::rstest; + use serde::de::value::StringDeserializer; + use serde::de::IntoDeserializer; use tokio::sync::mpsc; use crate::app::context_clues::{build_context_clue_string, SERVARR_CONTEXT_CLUES}; - use crate::app::{App, AppConfig, Data, ServarrConfig}; + use crate::app::{ + deserialize_env_var, interpolate_env_vars, App, AppConfig, Data, ServarrConfig, + }; use crate::models::servarr_data::radarr::radarr_data::{ActiveRadarrBlock, RadarrData}; use crate::models::servarr_data::sonarr::sonarr_data::{ActiveSonarrBlock, SonarrData}; use crate::models::{HorizontallyScrollableText, TabRoute}; @@ -347,4 +351,62 @@ mod tests { assert!(servarr_config.api_token.is_empty()); assert_eq!(servarr_config.ssl_cert_path, None); } + + #[test] + fn test_deserialize_env_var() { + std::env::set_var("TEST_VAR_DESERIALIZE", "testing"); + let deserializer: StringDeserializer = + "${TEST_VAR_DESERIALIZE}".to_owned().into_deserializer(); + + let env_var: Result = deserialize_env_var(deserializer); + + assert!(env_var.is_ok()); + assert_str_eq!(env_var.unwrap(), "testing"); + std::env::remove_var("TEST_VAR_DESERIALIZE"); + } + + #[test] + fn test_interpolate_env_vars() { + std::env::set_var("TEST_VAR_INTERPOLATION", "testing"); + + let var = interpolate_env_vars("${TEST_VAR_INTERPOLATION}"); + + assert_str_eq!(var, "testing"); + std::env::remove_var("TEST_VAR_INTERPOLATION"); + } + + #[test] + fn test_interpolate_env_vars_defaults_to_original_string_if_not_in_yaml_interpolation_format() { + let var = interpolate_env_vars("TEST_VAR_INTERPOLATION"); + + assert_str_eq!(var, "TEST_VAR_INTERPOLATION"); + } + + #[test] + fn test_interpolate_env_vars_scrubs_all_unnecessary_characters() { + std::env::set_var( + "TEST_VAR_INTERPOLATION_UNNECESSARY_CHARACTERS", + r#""" + `"'https://dontdo:this@testing.com/query?test=%20query#results'"` {([\|$!])} + """#, + ); + + let var = interpolate_env_vars("${TEST_VAR_INTERPOLATION_UNNECESSARY_CHARACTERS}"); + + assert_str_eq!( + var, + "https://dontdo:this@testing.com/query?test=%20query#results" + ); + std::env::remove_var("TEST_VAR_INTERPOLATION_UNNECESSARY_CHARACTERS"); + } + + #[test] + fn test_interpolate_env_vars_scrubs_all_unnecessary_characters_from_non_environment_variable() { + let var = interpolate_env_vars("https://dontdo:this@testing.com/query?test=%20query#results"); + + assert_str_eq!( + var, + "https://dontdo:this@testing.com/query?test=%20query#results" + ); + } } diff --git a/src/app/mod.rs b/src/app/mod.rs index 41aa3e5..760bc7e 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -3,6 +3,7 @@ use std::process; use anyhow::{anyhow, Error}; use colored::Colorize; use log::{debug, error}; +use regex::Regex; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::Sender; use tokio_util::sync::CancellationToken; @@ -213,71 +214,6 @@ pub struct Data<'a> { pub sonarr_data: SonarrData<'a>, } -fn deserialize_env_var<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - let s: String = String::deserialize(deserializer)?; - let interpolated = interpolate_env_vars(&s); - Ok(interpolated) -} - -fn deserialize_optional_env_var<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - let s: Option = Option::deserialize(deserializer)?; - match s { - Some(value) => { - let interpolated = interpolate_env_vars(&value); - Ok(Some(interpolated)) - } - None => Ok(None), - } -} - -fn deserialize_u16_env_var<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - let s: Option = Option::deserialize(deserializer)?; - match s { - Some(value) => { - let interpolated = interpolate_env_vars(&value); - interpolated - .parse::() - .map(Some) - .map_err(serde::de::Error::custom) - } - None => Ok(None), - } -} - -fn interpolate_env_vars(s: &str) -> String { - let mut result = s.to_string(); - let start = "${"; - let end = "}"; - - while let Some(start_index) = result.find(start) { - let end_index = result[start_index..] - .find(end) - .map(|i| start_index + i + end.len()); - if let Some(end_index) = end_index { - let var_name = &result[start_index + start.len()..end_index - end.len()]; - if let Ok(value) = std::env::var(var_name) { - // Match found; interpolate - result.replace_range(start_index..end_index, &value); - } else { - break; // No var match found; interpret it literally - } - } else { - break; // No closing brace found - } - } - - result -} - #[derive(Debug, Deserialize, Serialize, Default, Clone)] pub struct AppConfig { pub radarr: Option, @@ -325,15 +261,15 @@ impl AppConfig { #[derive(Debug, Deserialize, Serialize, Clone)] pub struct ServarrConfig { - #[serde(deserialize_with = "deserialize_optional_env_var")] + #[serde(default, deserialize_with = "deserialize_optional_env_var")] pub host: Option, - #[serde(deserialize_with = "deserialize_u16_env_var")] + #[serde(default, deserialize_with = "deserialize_u16_env_var")] pub port: Option, - #[serde(deserialize_with = "deserialize_optional_env_var")] + #[serde(default, deserialize_with = "deserialize_optional_env_var")] pub uri: Option, - #[serde(deserialize_with = "deserialize_env_var")] + #[serde(default, deserialize_with = "deserialize_env_var")] pub api_token: String, - #[serde(deserialize_with = "deserialize_optional_env_var")] + #[serde(default, deserialize_with = "deserialize_optional_env_var")] pub ssl_cert_path: Option, } @@ -362,3 +298,61 @@ pub fn log_and_print_error(error: String) { error!("{}", error); eprintln!("error: {}", error.red()); } + +fn deserialize_env_var<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let s: String = String::deserialize(deserializer)?; + let interpolated = interpolate_env_vars(&s); + Ok(interpolated) +} + +fn deserialize_optional_env_var<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let s: Option = Option::deserialize(deserializer)?; + match s { + Some(value) => { + let interpolated = interpolate_env_vars(&value); + Ok(Some(interpolated)) + } + None => Ok(None), + } +} + +fn deserialize_u16_env_var<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let s: Option = Option::deserialize(deserializer)?; + match s { + Some(value) => { + let interpolated = interpolate_env_vars(&value); + interpolated + .parse::() + .map(Some) + .map_err(serde::de::Error::custom) + } + None => Ok(None), + } +} + +fn interpolate_env_vars(s: &str) -> String { + let result = s.to_string(); + let scrubbing_regex = Regex::new(r#"[\s\{\}!\$^\(\)\[\]\\\|`'"]+"#).unwrap(); + let var_regex = Regex::new(r"\$\{(.*?)\}").unwrap(); + + var_regex + .replace_all(s, |caps: ®ex::Captures<'_>| { + if let Some(mat) = caps.get(1) { + if let Ok(value) = std::env::var(mat.as_str()) { + return scrubbing_regex.replace_all(&value, "").to_string(); + } + } + + scrubbing_regex.replace_all(&result, "").to_string() + }) + .to_string() +}