diff --git a/src/env.rs b/src/env.rs index 564fdd69..743fb75c 100644 --- a/src/env.rs +++ b/src/env.rs @@ -34,9 +34,9 @@ pub struct Environment { /// Optional directive to translate collected keys into a form that matches what serializers /// that the configuration would expect. For example if you have the `kebab-case` attribute - /// for your serde config types, you may want to pass `Case::Kebab` here. + /// for your serde config types, you may want to pass `ConversionStrategy::All(Case::Kebab)` here. #[cfg(feature = "convert-case")] - convert_case: Option, + convert_case: Option, /// Optional character sequence that separates each env value into a vector. only works when `try_parsing` is set to true /// Once set, you cannot have type String on the same environment, unless you set `list_parse_keys`. @@ -90,6 +90,19 @@ pub struct Environment { source: Option>, } +/// Strategy to translate collected keys into a form that matches what serializers +/// that the configuration would expect. +#[cfg(feature = "convert-case")] +#[derive(Clone, Debug)] +enum ConversionStrategy { + /// Apply the conversion to all collected keys + All(Case), + /// Exclude the specified keys from conversion + Exclude(Case, Vec), + /// Only convert the specified keys + Only(Case, Vec), +} + impl Environment { /// Optional prefix that will limit access to the environment to only keys that /// begin with the defined prefix. @@ -118,7 +131,33 @@ impl Environment { #[cfg(feature = "convert-case")] pub fn convert_case(mut self, tt: Case) -> Self { - self.convert_case = Some(tt); + self.convert_case = Some(ConversionStrategy::All(tt)); + self + } + + #[cfg(feature = "convert-case")] + pub fn convert_case_exclude_keys( + mut self, + tt: Case, + keys: impl IntoIterator>, + ) -> Self { + self.convert_case = Some(ConversionStrategy::Exclude( + tt, + keys.into_iter().map(|k| k.into()).collect(), + )); + self + } + + #[cfg(feature = "convert-case")] + pub fn convert_case_for_keys( + mut self, + tt: Case, + keys: impl IntoIterator>, + ) -> Self { + self.convert_case = Some(ConversionStrategy::Only( + tt, + keys.into_iter().map(|k| k.into()).collect(), + )); self } @@ -270,8 +309,20 @@ impl Source for Environment { } #[cfg(feature = "convert-case")] - if let Some(convert_case) = convert_case { - key = key.to_case(*convert_case); + if let Some(strategy) = convert_case { + match strategy { + ConversionStrategy::All(convert_case) => key = key.to_case(*convert_case), + ConversionStrategy::Exclude(convert_case, keys) => { + if !keys.contains(&key) { + key = key.to_case(*convert_case); + } + } + ConversionStrategy::Only(convert_case, keys) => { + if keys.contains(&key) { + key = key.to_case(*convert_case); + } + } + } } let value = if self.try_parsing { diff --git a/tests/testsuite/env.rs b/tests/testsuite/env.rs index 00a59e8e..448c8e25 100644 --- a/tests/testsuite/env.rs +++ b/tests/testsuite/env.rs @@ -574,6 +574,60 @@ fn test_parse_nested_kebab() { ); } +#[test] +#[cfg(feature = "convert-case")] +fn test_parse_kebab_case_with_exclude_keys() { + use config::Case; + #[derive(Deserialize, Debug)] + struct TestConfig { + value_a: String, + #[serde(rename = "value-b")] + value_b: String, + } + + temp_env::with_vars( + vec![("VALUE_A", Some("value1")), ("VALUE_B", Some("value2"))], + || { + let environment = + Environment::default().convert_case_exclude_keys(Case::Kebab, ["value_a"]); + + let config = Config::builder().add_source(environment).build().unwrap(); + + let config: TestConfig = config.try_deserialize().unwrap(); + + assert_eq!(config.value_a, "value1"); + assert_eq!(config.value_b, "value2"); + }, + ); +} + +#[test] +#[cfg(feature = "convert-case")] +fn test_parse_kebab_case_for_keys() { + use config::Case; + #[derive(Deserialize, Debug)] + struct TestConfig { + value_a: String, + #[serde(rename = "value-b")] + value_b: String, + } + + temp_env::with_vars( + vec![("VALUE_A", Some("value1")), ("VALUE_B", Some("value2"))], + || { + let environment = + Environment::default().convert_case_for_keys(Case::Kebab, ["value_b"]); + + let config = Config::builder().add_source(environment).build().unwrap(); + + let config: TestConfig = config.try_deserialize().unwrap(); + + assert_eq!(config.value_a, "value1"); + assert_eq!(config.value_b, "value2"); + }, + ); +} + #[test] fn test_parse_string() { // using a struct in an enum here to make serde use `deserialize_any`