From a51a73170435df3128997c850dde7ded2de16e12 Mon Sep 17 00:00:00 2001 From: Justin Tracey Date: Sun, 6 Oct 2024 05:48:08 -0400 Subject: [PATCH] join: add support for multibyte separators (#6736) * join: add test for multibyte separators * join: implement support for multibyte separators * join: use a trait instead of an enum for separator * join: test whitespace merging --- src/uu/join/src/join.rs | 291 +++++++++++++----- tests/by-util/test_join.rs | 45 ++- tests/fixtures/join/contiguous_separators.txt | 1 + tests/fixtures/join/multibyte_sep.expected | 1 + tests/fixtures/join/multibyte_sep_1.txt | 1 + tests/fixtures/join/multibyte_sep_2.txt | 1 + 6 files changed, 251 insertions(+), 89 deletions(-) create mode 100644 tests/fixtures/join/contiguous_separators.txt create mode 100644 tests/fixtures/join/multibyte_sep.expected create mode 100644 tests/fixtures/join/multibyte_sep_1.txt create mode 100644 tests/fixtures/join/multibyte_sep_2.txt diff --git a/src/uu/join/src/join.rs b/src/uu/join/src/join.rs index 9eecaca3fe6..e4f3cdba224 100644 --- a/src/uu/join/src/join.rs +++ b/src/uu/join/src/join.rs @@ -3,11 +3,11 @@ // For the full copyright and license information, please view the LICENSE // file that was distributed with this source code. -// spell-checker:ignore (ToDO) autoformat FILENUM whitespaces pairable unpairable nocheck +// spell-checker:ignore (ToDO) autoformat FILENUM whitespaces pairable unpairable nocheck memmem use clap::builder::ValueParser; use clap::{crate_version, Arg, ArgAction, Command}; -use memchr::{memchr3_iter, memchr_iter}; +use memchr::{memchr_iter, memmem::Finder, Memchr3}; use std::cmp::Ordering; use std::error::Error; use std::ffi::OsString; @@ -60,13 +60,114 @@ enum FileNum { File2, } -#[derive(Copy, Clone, PartialEq)] -enum Sep { - Char(u8), +#[derive(Clone)] +enum SepSetting { + /// Any single-byte separator. + Byte(u8), + /// A single character more than one byte long. + Char(Vec), + /// No separators, join on the entire line. Line, + /// Whitespace separators. Whitespaces, } +trait Separator: Clone { + /// Using this separator, return the start and end index of all fields in the haystack. + fn field_ranges(&self, haystack: &[u8], len_guess: usize) -> Vec<(usize, usize)>; + /// The separator as it appears when in the output. + fn output_separator(&self) -> &[u8]; +} + +/// Simple separators one byte in length. +#[derive(Copy, Clone)] +struct OneByteSep { + byte: [u8; 1], +} + +impl Separator for OneByteSep { + fn field_ranges(&self, haystack: &[u8], len_guess: usize) -> Vec<(usize, usize)> { + let mut field_ranges = Vec::with_capacity(len_guess); + let mut last_end = 0; + + for i in memchr_iter(self.byte[0], haystack) { + field_ranges.push((last_end, i)); + last_end = i + 1; + } + field_ranges.push((last_end, haystack.len())); + field_ranges + } + + fn output_separator(&self) -> &[u8] { + &self.byte + } +} + +/// Multi-byte (but still single character) separators. +#[derive(Clone)] +struct MultiByteSep<'a> { + finder: Finder<'a>, +} + +impl<'a> Separator for MultiByteSep<'a> { + fn field_ranges(&self, haystack: &[u8], len_guess: usize) -> Vec<(usize, usize)> { + let mut field_ranges = Vec::with_capacity(len_guess); + let mut last_end = 0; + + for i in self.finder.find_iter(haystack) { + field_ranges.push((last_end, i)); + last_end = i + self.finder.needle().len(); + } + field_ranges.push((last_end, haystack.len())); + field_ranges + } + + fn output_separator(&self) -> &[u8] { + self.finder.needle() + } +} + +/// Whole-line separator. +#[derive(Copy, Clone)] +struct LineSep {} + +impl Separator for LineSep { + fn field_ranges(&self, haystack: &[u8], _len_guess: usize) -> Vec<(usize, usize)> { + vec![(0, haystack.len())] + } + + fn output_separator(&self) -> &[u8] { + &[] + } +} + +/// Default whitespace separator. +#[derive(Copy, Clone)] +struct WhitespaceSep {} + +impl Separator for WhitespaceSep { + fn field_ranges(&self, haystack: &[u8], len_guess: usize) -> Vec<(usize, usize)> { + let mut field_ranges = Vec::with_capacity(len_guess); + let mut last_end = 0; + + // GNU join used Bourne shell field splitters by default + // FIXME: but now uses locale-dependent whitespace + for i in Memchr3::new(b' ', b'\t', b'\n', haystack) { + // leading whitespace should be dropped, contiguous whitespace merged + if i > last_end { + field_ranges.push((last_end, i)); + } + last_end = i + 1; + } + field_ranges.push((last_end, haystack.len())); + field_ranges + } + + fn output_separator(&self) -> &[u8] { + b" " + } +} + #[derive(Copy, Clone, PartialEq)] enum CheckOrder { Default, @@ -82,7 +183,7 @@ struct Settings { print_joined: bool, ignore_case: bool, line_ending: LineEnding, - separator: Sep, + separator: SepSetting, autoformat: bool, format: Vec, empty: Vec, @@ -100,7 +201,7 @@ impl Default for Settings { print_joined: true, ignore_case: false, line_ending: LineEnding::Newline, - separator: Sep::Whitespaces, + separator: SepSetting::Whitespaces, autoformat: false, format: vec![], empty: vec![], @@ -111,15 +212,15 @@ impl Default for Settings { } /// Output representation. -struct Repr<'a> { +struct Repr<'a, Sep: Separator> { line_ending: LineEnding, - separator: u8, - format: &'a [Spec], + separator: Sep, + format: Vec, empty: &'a [u8], } -impl<'a> Repr<'a> { - fn new(line_ending: LineEnding, separator: u8, format: &'a [Spec], empty: &'a [u8]) -> Self { +impl<'a, Sep: Separator> Repr<'a, Sep> { + fn new(line_ending: LineEnding, separator: Sep, format: Vec, empty: &'a [u8]) -> Self { Repr { line_ending, separator, @@ -155,7 +256,7 @@ impl<'a> Repr<'a> { ) -> Result<(), std::io::Error> { for i in 0..line.field_ranges.len() { if i != index { - writer.write_all(&[self.separator])?; + writer.write_all(self.separator.output_separator())?; writer.write_all(line.get_field(i).unwrap())?; } } @@ -169,7 +270,7 @@ impl<'a> Repr<'a> { { for i in 0..self.format.len() { if i > 0 { - writer.write_all(&[self.separator])?; + writer.write_all(self.separator.output_separator())?; } let field = match f(&self.format[i]) { @@ -188,13 +289,13 @@ impl<'a> Repr<'a> { } /// Input processing parameters. -struct Input { +struct Input { separator: Sep, ignore_case: bool, check_order: CheckOrder, } -impl Input { +impl Input { fn new(separator: Sep, ignore_case: bool, check_order: CheckOrder) -> Self { Self { separator, @@ -271,24 +372,8 @@ struct Line { } impl Line { - fn new(string: Vec, separator: Sep, len_guess: usize) -> Self { - let mut field_ranges = Vec::with_capacity(len_guess); - let mut last_end = 0; - if separator == Sep::Whitespaces { - // GNU join uses Bourne shell field splitters by default - for i in memchr3_iter(b' ', b'\t', b'\n', &string) { - if i > last_end { - field_ranges.push((last_end, i)); - } - last_end = i + 1; - } - } else if let Sep::Char(sep) = separator { - for i in memchr_iter(sep, &string) { - field_ranges.push((last_end, i)); - last_end = i + 1; - } - } - field_ranges.push((last_end, string.len())); + fn new(string: Vec, separator: &Sep, len_guess: usize) -> Self { + let field_ranges = separator.field_ranges(&string, len_guess); Self { field_ranges, @@ -351,7 +436,12 @@ impl<'a> State<'a> { } /// Skip the current unpaired line. - fn skip_line(&mut self, writer: &mut impl Write, input: &Input, repr: &Repr) -> UResult<()> { + fn skip_line( + &mut self, + writer: &mut impl Write, + input: &Input, + repr: &Repr<'a, Sep>, + ) -> UResult<()> { if self.print_unpaired { self.print_first_line(writer, repr)?; } @@ -362,7 +452,7 @@ impl<'a> State<'a> { /// Keep reading line sequence until the key does not change, return /// the first line whose key differs. - fn extend(&mut self, input: &Input) -> UResult> { + fn extend(&mut self, input: &Input) -> UResult> { while let Some(line) = self.next_line(input)? { let diff = input.compare(self.get_current_key(), line.get_field(self.key)); @@ -377,11 +467,11 @@ impl<'a> State<'a> { } /// Print lines in the buffers as headers. - fn print_headers( + fn print_headers( &self, writer: &mut impl Write, other: &State, - repr: &Repr, + repr: &Repr<'a, Sep>, ) -> Result<(), std::io::Error> { if self.has_line() { if other.has_line() { @@ -397,11 +487,11 @@ impl<'a> State<'a> { } /// Combine two line sequences. - fn combine( + fn combine( &self, writer: &mut impl Write, other: &State, - repr: &Repr, + repr: &Repr<'a, Sep>, ) -> Result<(), std::io::Error> { let key = self.get_current_key(); @@ -444,13 +534,16 @@ impl<'a> State<'a> { } } - fn reset_read_line(&mut self, input: &Input) -> Result<(), std::io::Error> { - let line = self.read_line(input.separator)?; + fn reset_read_line( + &mut self, + input: &Input, + ) -> Result<(), std::io::Error> { + let line = self.read_line(&input.separator)?; self.reset(line); Ok(()) } - fn reset_next_line(&mut self, input: &Input) -> Result<(), JoinError> { + fn reset_next_line(&mut self, input: &Input) -> Result<(), JoinError> { let line = self.next_line(input)?; self.reset(line); Ok(()) @@ -460,7 +553,7 @@ impl<'a> State<'a> { !self.seq.is_empty() } - fn initialize(&mut self, read_sep: Sep, autoformat: bool) -> usize { + fn initialize(&mut self, read_sep: &Sep, autoformat: bool) -> usize { if let Some(line) = crash_if_err!(1, self.read_line(read_sep)) { self.seq.push(line); @@ -471,7 +564,12 @@ impl<'a> State<'a> { 0 } - fn finalize(&mut self, writer: &mut impl Write, input: &Input, repr: &Repr) -> UResult<()> { + fn finalize( + &mut self, + writer: &mut impl Write, + input: &Input, + repr: &Repr<'a, Sep>, + ) -> UResult<()> { if self.has_line() { if self.print_unpaired { self.print_first_line(writer, repr)?; @@ -491,7 +589,7 @@ impl<'a> State<'a> { } /// Get the next line without the order check. - fn read_line(&mut self, sep: Sep) -> Result, std::io::Error> { + fn read_line(&mut self, sep: &Sep) -> Result, std::io::Error> { match self.lines.next() { Some(value) => { self.line_num += 1; @@ -506,8 +604,8 @@ impl<'a> State<'a> { } /// Get the next line with the order check. - fn next_line(&mut self, input: &Input) -> Result, JoinError> { - if let Some(line) = self.read_line(input.separator)? { + fn next_line(&mut self, input: &Input) -> Result, JoinError> { + if let Some(line) = self.read_line(&input.separator)? { if input.check_order == CheckOrder::Disabled { return Ok(Some(line)); } @@ -543,11 +641,11 @@ impl<'a> State<'a> { self.seq[0].get_field(self.key) } - fn print_line( + fn print_line( &self, writer: &mut impl Write, line: &Line, - repr: &Repr, + repr: &Repr<'a, Sep>, ) -> Result<(), std::io::Error> { if repr.uses_format() { repr.print_format(writer, |spec| match *spec { @@ -568,31 +666,53 @@ impl<'a> State<'a> { repr.print_line_ending(writer) } - fn print_first_line(&self, writer: &mut impl Write, repr: &Repr) -> Result<(), std::io::Error> { + fn print_first_line( + &self, + writer: &mut impl Write, + repr: &Repr<'a, Sep>, + ) -> Result<(), std::io::Error> { self.print_line(writer, &self.seq[0], repr) } } -fn parse_separator(value_os: &OsString) -> UResult { +fn parse_separator(value_os: &OsString) -> UResult { + // Five possible separator values: + // No argument supplied, separate on whitespace; handled implicitly as the default elsewhere + // An empty string arg, whole line separation + // On unix-likes only, a single arbitrary byte + // The two-character "\0" string, interpreted as a single 0 byte + // A single scalar valid in the locale encoding (currently only UTF-8) + + if value_os.is_empty() { + return Ok(SepSetting::Line); + } + #[cfg(unix)] - let value = value_os.as_bytes(); - #[cfg(not(unix))] - let value = match value_os.to_str() { - Some(value) => value.as_bytes(), - None => { - return Err(USimpleError::new( - 1, - "unprintable field separators are only supported on unix-like platforms", - )); + { + let value = value_os.as_bytes(); + if value.len() == 1 { + return Ok(SepSetting::Byte(value[0])); } + } + + let Some(value) = value_os.to_str() else { + #[cfg(unix)] + return Err(USimpleError::new(1, "non-UTF-8 multi-byte tab")); + #[cfg(not(unix))] + return Err(USimpleError::new( + 1, + "unprintable field separators are only supported on unix-like platforms", + )); }; - match value.len() { - 0 => Ok(Sep::Line), - 1 => Ok(Sep::Char(value[0])), - 2 if value[0] == b'\\' && value[1] == b'0' => Ok(Sep::Char(0)), + + let mut chars = value.chars(); + let c = chars.next().expect("valid string with at least one byte"); + match chars.next() { + None => Ok(SepSetting::Char(value.into())), + Some('0') if c == '\\' => Ok(SepSetting::Byte(0)), _ => Err(USimpleError::new( 1, - format!("multi-character tab {}", value_os.to_string_lossy()), + format!("multi-character tab {}", value), )), } } @@ -695,7 +815,20 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> { return Err(USimpleError::new(1, "both files cannot be standard input")); } - exec(file1, file2, settings) + let sep = settings.separator.clone(); + match sep { + SepSetting::Byte(byte) => exec(file1, file2, settings, OneByteSep { byte: [byte] }), + SepSetting::Char(c) => exec( + file1, + file2, + settings, + MultiByteSep { + finder: Finder::new(&c), + }, + ), + SepSetting::Whitespaces => exec(file1, file2, settings, WhitespaceSep {}), + SepSetting::Line => exec(file1, file2, settings, LineSep {}), + } } pub fn uu_app() -> Command { @@ -816,7 +949,7 @@ FILENUM is 1 or 2, corresponding to FILE1 or FILE2", ) } -fn exec(file1: &str, file2: &str, settings: Settings) -> UResult<()> { +fn exec(file1: &str, file2: &str, settings: Settings, sep: Sep) -> UResult<()> { let stdin = stdin(); let mut state1 = State::new( @@ -837,16 +970,12 @@ fn exec(file1: &str, file2: &str, settings: Settings) -> UResult<()> { settings.print_unpaired2, )?; - let input = Input::new( - settings.separator, - settings.ignore_case, - settings.check_order, - ); + let input = Input::new(sep.clone(), settings.ignore_case, settings.check_order); let format = if settings.autoformat { let mut format = vec![Spec::Key]; let mut initialize = |state: &mut State| { - let max_fields = state.initialize(settings.separator, settings.autoformat); + let max_fields = state.initialize(&sep, settings.autoformat); for i in 0..max_fields { if i != state.key { format.push(Spec::Field(state.file_num, i)); @@ -857,20 +986,12 @@ fn exec(file1: &str, file2: &str, settings: Settings) -> UResult<()> { initialize(&mut state2); format } else { - state1.initialize(settings.separator, settings.autoformat); - state2.initialize(settings.separator, settings.autoformat); + state1.initialize(&sep, settings.autoformat); + state2.initialize(&sep, settings.autoformat); settings.format }; - let repr = Repr::new( - settings.line_ending, - match settings.separator { - Sep::Char(sep) => sep, - _ => b' ', - }, - &format, - &settings.empty, - ); + let repr = Repr::new(settings.line_ending, sep, format, &settings.empty); let stdout = stdout(); let mut writer = BufWriter::new(stdout.lock()); diff --git a/tests/by-util/test_join.rs b/tests/by-util/test_join.rs index 955f21d68b0..6516f386a79 100644 --- a/tests/by-util/test_join.rs +++ b/tests/by-util/test_join.rs @@ -58,6 +58,25 @@ fn default_arguments() { .stdout_only_fixture("default.expected"); } +#[test] +fn only_whitespace_separators_merge() { + new_ucmd!() + .arg("contiguous_separators.txt") + .arg("-") + .pipe_in(" a ,c ") + .succeeds() + .stdout_only("a ,,,b ,c \n"); + + new_ucmd!() + .arg("contiguous_separators.txt") + .arg("-t") + .arg(",") + .arg("-") + .pipe_in(" a ,c ") + .succeeds() + .stdout_only(" a ,,,b,c \n"); +} + #[test] fn different_fields() { new_ucmd!() @@ -208,9 +227,9 @@ fn tab_multi_character() { .arg("semicolon_fields_1.txt") .arg("semicolon_fields_2.txt") .arg("-t") - .arg("э") + .arg("ab") .fails() - .stderr_is("join: multi-character tab э\n"); + .stderr_is("join: multi-character tab ab\n"); } #[test] @@ -437,14 +456,22 @@ fn non_unicode() { #[cfg(unix)] { - let invalid_utf8: u8 = 167; + let non_utf8_byte: u8 = 167; new_ucmd!() .arg("-t") - .arg(OsStr::from_bytes(&[invalid_utf8])) + .arg(OsStr::from_bytes(&[non_utf8_byte])) .arg("non-unicode_1.bin") .arg("non-unicode_2.bin") .succeeds() .stdout_only_fixture("non-unicode_sep.expected"); + + new_ucmd!() + .arg("-t") + .arg(OsStr::from_bytes(&[non_utf8_byte, non_utf8_byte])) + .arg("non-unicode_1.bin") + .arg("non-unicode_2.bin") + .fails() + .stderr_is("join: non-UTF-8 multi-byte tab\n"); } #[cfg(windows)] @@ -462,6 +489,16 @@ fn non_unicode() { } } +#[test] +fn multibyte_sep() { + new_ucmd!() + .arg("-t§") + .arg("multibyte_sep_1.txt") + .arg("multibyte_sep_2.txt") + .succeeds() + .stdout_only_fixture("multibyte_sep.expected"); +} + #[test] fn null_field_separators() { new_ucmd!() diff --git a/tests/fixtures/join/contiguous_separators.txt b/tests/fixtures/join/contiguous_separators.txt new file mode 100644 index 00000000000..bfc290229c4 --- /dev/null +++ b/tests/fixtures/join/contiguous_separators.txt @@ -0,0 +1 @@ + a ,,,b diff --git a/tests/fixtures/join/multibyte_sep.expected b/tests/fixtures/join/multibyte_sep.expected new file mode 100644 index 00000000000..c8a6aacfd18 --- /dev/null +++ b/tests/fixtures/join/multibyte_sep.expected @@ -0,0 +1 @@ +a§b§c diff --git a/tests/fixtures/join/multibyte_sep_1.txt b/tests/fixtures/join/multibyte_sep_1.txt new file mode 100644 index 00000000000..a42f9eefae1 --- /dev/null +++ b/tests/fixtures/join/multibyte_sep_1.txt @@ -0,0 +1 @@ +a§b diff --git a/tests/fixtures/join/multibyte_sep_2.txt b/tests/fixtures/join/multibyte_sep_2.txt new file mode 100644 index 00000000000..81288ad520b --- /dev/null +++ b/tests/fixtures/join/multibyte_sep_2.txt @@ -0,0 +1 @@ +a§c