Skip to content

Commit

Permalink
Add FP16 support (#257)
Browse files Browse the repository at this point in the history
  • Loading branch information
micahcc authored Feb 5, 2025
1 parent d42e6be commit 3bfb43e
Show file tree
Hide file tree
Showing 14 changed files with 239 additions and 6 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
rust: ["1.61.0", stable, beta, nightly]
rust: ["1.70.0", stable, beta, nightly]
steps:
- uses: actions/checkout@v2

- uses: dtolnay/rust-toolchain@nightly
if: ${{ matrix.rust == '1.61.0' }}
if: ${{ matrix.rust == '1.70.0' }}
- name: Generate Cargo.lock with minimal-version dependencies
if: ${{ matrix.rust == '1.61.0' }}
if: ${{ matrix.rust == '1.70.0' }}
run: cargo -Zminimal-versions generate-lockfile

- uses: dtolnay/rust-toolchain@v1
Expand All @@ -29,7 +29,7 @@ jobs:
- name: build
run: cargo build -v
- name: test
if: ${{ matrix.rust != '1.61.0' }}
if: ${{ matrix.rust != '1.70.0' }}
run: cargo test -v && cargo doc -v

rustfmt:
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ categories = ["multimedia::images", "multimedia::encoding"]
exclude = ["tests/images/*", "tests/fuzz_images/*"]

[dependencies]
half = { version = "2.4.1" }
weezl = "0.1.0"
jpeg = { package = "jpeg-decoder", version = "0.3.0", default-features = false }
flate2 = "1.0.20"
Expand Down
3 changes: 3 additions & 0 deletions src/bytecast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
//! TODO: Would like to use std-lib here.
use std::{mem, slice};

use half::f16;

macro_rules! integral_slice_as_bytes{($int:ty, $const:ident $(,$mut:ident)*) => {
pub(crate) fn $const(slice: &[$int]) -> &[u8] {
assert!(mem::align_of::<$int>() <= mem::size_of::<$int>());
Expand All @@ -31,4 +33,5 @@ integral_slice_as_bytes!(i32, i32_as_ne_bytes, i32_as_ne_mut_bytes);
integral_slice_as_bytes!(u64, u64_as_ne_bytes, u64_as_ne_mut_bytes);
integral_slice_as_bytes!(i64, i64_as_ne_bytes, i64_as_ne_mut_bytes);
integral_slice_as_bytes!(f32, f32_as_ne_bytes, f32_as_ne_mut_bytes);
integral_slice_as_bytes!(f16, f16_as_ne_bytes, f16_as_ne_mut_bytes);
integral_slice_as_bytes!(f64, f64_as_ne_bytes, f64_as_ne_mut_bytes);
8 changes: 6 additions & 2 deletions src/decoder/image.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::ifd::{Directory, Value};
use super::stream::{ByteOrder, DeflateReader, LZWReader, PackBitsReader};
use super::tag_reader::TagReader;
use super::{predict_f32, predict_f64, Limits};
use super::{predict_f16, predict_f32, predict_f64, Limits};
use super::{stream::SmartReader, ChunkType};
use crate::tags::{
CompressionMethod, PhotometricInterpretation, PlanarConfiguration, Predictor, SampleFormat, Tag,
Expand Down Expand Up @@ -592,7 +592,10 @@ impl Image {

// Validate that the predictor is supported for the sample type.
match (self.predictor, self.sample_format) {
(Predictor::Horizontal, SampleFormat::Int | SampleFormat::Uint) => {}
(
Predictor::Horizontal,
SampleFormat::Int | SampleFormat::Uint | SampleFormat::IEEEFP,
) => {}
(Predictor::Horizontal, _) => {
return Err(TiffError::UnsupportedError(
TiffUnsupportedError::HorizontalPredictor(color_type),
Expand Down Expand Up @@ -672,6 +675,7 @@ impl Image {

let row = &mut row[..data_row_bytes];
match color_type.bit_depth() {
16 => predict_f16(&mut encoded, row, samples),
32 => predict_f32(&mut encoded, row, samples),
64 => predict_f64(&mut encoded, row, samples),
_ => unreachable!(),
Expand Down
30 changes: 30 additions & 0 deletions src/decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::tags::{
use crate::{
bytecast, ColorType, TiffError, TiffFormatError, TiffResult, TiffUnsupportedError, UsageError,
};
use half::f16;

use self::ifd::Directory;
use self::image::Image;
Expand All @@ -29,6 +30,8 @@ pub enum DecodingResult {
U32(Vec<u32>),
/// A vector of 64 bit unsigned ints
U64(Vec<u64>),
/// A vector of 16 bit IEEE floats (held in u16)
F16(Vec<f16>),
/// A vector of 32 bit IEEE floats
F32(Vec<f32>),
/// A vector of 64 bit IEEE floats
Expand Down Expand Up @@ -92,6 +95,14 @@ impl DecodingResult {
}
}

fn new_f16(size: usize, limits: &Limits) -> TiffResult<DecodingResult> {
if size > limits.decoding_buffer_size / std::mem::size_of::<u16>() {
Err(TiffError::LimitsExceeded)
} else {
Ok(DecodingResult::F16(vec![f16::ZERO; size]))
}
}

fn new_i8(size: usize, limits: &Limits) -> TiffResult<DecodingResult> {
if size > limits.decoding_buffer_size / std::mem::size_of::<i8>() {
Err(TiffError::LimitsExceeded)
Expand Down Expand Up @@ -130,6 +141,7 @@ impl DecodingResult {
DecodingResult::U16(ref mut buf) => DecodingBuffer::U16(&mut buf[start..]),
DecodingResult::U32(ref mut buf) => DecodingBuffer::U32(&mut buf[start..]),
DecodingResult::U64(ref mut buf) => DecodingBuffer::U64(&mut buf[start..]),
DecodingResult::F16(ref mut buf) => DecodingBuffer::F16(&mut buf[start..]),
DecodingResult::F32(ref mut buf) => DecodingBuffer::F32(&mut buf[start..]),
DecodingResult::F64(ref mut buf) => DecodingBuffer::F64(&mut buf[start..]),
DecodingResult::I8(ref mut buf) => DecodingBuffer::I8(&mut buf[start..]),
Expand All @@ -150,6 +162,8 @@ pub enum DecodingBuffer<'a> {
U32(&'a mut [u32]),
/// A slice of 64 bit unsigned ints
U64(&'a mut [u64]),
/// A slice of 16 bit IEEE floats
F16(&'a mut [f16]),
/// A slice of 32 bit IEEE floats
F32(&'a mut [f32]),
/// A slice of 64 bit IEEE floats
Expand All @@ -175,6 +189,7 @@ impl<'a> DecodingBuffer<'a> {
DecodingBuffer::I32(buf) => bytecast::i32_as_ne_mut_bytes(buf),
DecodingBuffer::U64(buf) => bytecast::u64_as_ne_mut_bytes(buf),
DecodingBuffer::I64(buf) => bytecast::i64_as_ne_mut_bytes(buf),
DecodingBuffer::F16(buf) => bytecast::f16_as_ne_mut_bytes(buf),
DecodingBuffer::F32(buf) => bytecast::f32_as_ne_mut_bytes(buf),
DecodingBuffer::F64(buf) => bytecast::f64_as_ne_mut_bytes(buf),
}
Expand Down Expand Up @@ -303,6 +318,19 @@ fn predict_f32(input: &mut [u8], output: &mut [u8], samples: usize) {
}
}

fn predict_f16(input: &mut [u8], output: &mut [u8], samples: usize) {
for i in samples..input.len() {
input[i] = input[i].wrapping_add(input[i - samples]);
}

for (i, chunk) in output.chunks_mut(2).enumerate() {
chunk.copy_from_slice(&u16::to_ne_bytes(u16::from_be_bytes([
input[i],
input[input.len() / 2 + i],
])));
}
}

fn predict_f64(input: &mut [u8], output: &mut [u8], samples: usize) {
for i in samples..input.len() {
input[i] = input[i].wrapping_add(input[i - samples]);
Expand Down Expand Up @@ -340,6 +368,7 @@ fn fix_endianness_and_predict(
Predictor::FloatingPoint => {
let mut buffer_copy = buf.to_vec();
match bit_depth {
16 => predict_f16(&mut buffer_copy, buf, samples),
32 => predict_f32(&mut buffer_copy, buf, samples),
64 => predict_f64(&mut buffer_copy, buf, samples),
_ => unreachable!("Caller should have validated arguments. Please file a bug."),
Expand Down Expand Up @@ -1004,6 +1033,7 @@ impl<R: Read + Seek> Decoder<R> {
)),
},
SampleFormat::IEEEFP => match max_sample_bits {
16 => DecodingResult::new_f16(buffer_size, &self.limits),
32 => DecodingResult::new_f32(buffer_size, &self.limits),
64 => DecodingResult::new_f64(buffer_size, &self.limits),
n => Err(TiffError::UnsupportedError(
Expand Down
195 changes: 195 additions & 0 deletions tests/decode_fp16_images.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
extern crate tiff;

use tiff::decoder::{Decoder, DecodingResult};
use tiff::ColorType;

use std::fs::File;
use std::path::PathBuf;

const TEST_IMAGE_DIR: &str = "./tests/images/";

/// Test a basic all white image
#[test]
fn test_white_ieee_fp16() {
let filenames = ["white-fp16.tiff"];

for filename in filenames.iter() {
let path = PathBuf::from(TEST_IMAGE_DIR).join(filename);
let img_file = File::open(path).expect("Cannot find test image!");
let mut decoder = Decoder::new(img_file).expect("Cannot create decoder");
assert_eq!(
decoder.dimensions().expect("Cannot get dimensions"),
(256, 256)
);
assert_eq!(
decoder.colortype().expect("Cannot get colortype"),
ColorType::Gray(16)
);
if let DecodingResult::F16(img) = decoder.read_image().unwrap() {
for p in img {
assert!(p == half::f16::from_f32_const(1.0));
}
} else {
panic!("Wrong data type");
}
}
}

/// Test a single black pixel, to make sure scaling is ok
#[test]
fn test_one_black_pixel_ieee_fp16() {
let filenames = ["single-black-fp16.tiff"];

for filename in filenames.iter() {
let path = PathBuf::from(TEST_IMAGE_DIR).join(filename);
let img_file = File::open(path).expect("Cannot find test image!");
let mut decoder = Decoder::new(img_file).expect("Cannot create decoder");
assert_eq!(
decoder.dimensions().expect("Cannot get dimensions"),
(256, 256)
);
assert_eq!(
decoder.colortype().expect("Cannot get colortype"),
ColorType::Gray(16)
);
if let DecodingResult::F16(img) = decoder.read_image().unwrap() {
for (i, p) in img.iter().enumerate() {
if i == 0 {
assert!(p < &half::f16::from_f32_const(0.001));
} else {
assert!(p == &half::f16::from_f32_const(1.0));
}
}
} else {
panic!("Wrong data type");
}
}
}

/// Test white with horizontal differencing predictor
#[test]
fn test_pattern_horizontal_differencing_ieee_fp16() {
let filenames = ["white-fp16-pred2.tiff"];

for filename in filenames.iter() {
let path = PathBuf::from(TEST_IMAGE_DIR).join(filename);
let img_file = File::open(path).expect("Cannot find test image!");
let mut decoder = Decoder::new(img_file).expect("Cannot create decoder");
assert_eq!(
decoder.dimensions().expect("Cannot get dimensions"),
(256, 256)
);
assert_eq!(
decoder.colortype().expect("Cannot get colortype"),
ColorType::Gray(16)
);
if let DecodingResult::F16(img) = decoder.read_image().unwrap() {
// 0, 2, 5, 8, 12, 16, 255 are black
let black = [0, 2, 5, 8, 12, 16, 255];
for (i, p) in img.iter().enumerate() {
if black.contains(&i) {
assert!(p < &half::f16::from_f32_const(0.001));
} else {
assert!(p == &half::f16::from_f32_const(1.0));
}
}
} else {
panic!("Wrong data type");
}
}
}

/// Test white with floating point predictor
#[test]
fn test_pattern_predictor_ieee_fp16() {
let filenames = ["white-fp16-pred3.tiff"];

for filename in filenames.iter() {
let path = PathBuf::from(TEST_IMAGE_DIR).join(filename);
let img_file = File::open(path).expect("Cannot find test image!");
let mut decoder = Decoder::new(img_file).expect("Cannot create decoder");
assert_eq!(
decoder.dimensions().expect("Cannot get dimensions"),
(256, 256)
);
assert_eq!(
decoder.colortype().expect("Cannot get colortype"),
ColorType::Gray(16)
);
if let DecodingResult::F16(img) = decoder.read_image().unwrap() {
// 0, 2, 5, 8, 12, 16, 255 are black
let black = [0, 2, 5, 8, 12, 16, 255];
for (i, p) in img.iter().enumerate() {
if black.contains(&i) {
assert!(p < &half::f16::from_f32_const(0.001));
} else {
assert!(p == &half::f16::from_f32_const(1.0));
}
}
} else {
panic!("Wrong data type");
}
}
}

/// Test several random images
/// we'rell compare against a pnm file, that scales from 0 (for 0.0) to 65767 (for 1.0)
#[test]
fn test_predictor_ieee_fp16() {
// first parse pnm, skip the first 4 \n
let pnm_path = PathBuf::from(TEST_IMAGE_DIR).join("random-fp16.pgm");
let pnm_bytes = std::fs::read(pnm_path).expect("Failed to read expected PNM file");

// PGM looks like this:
// ---
// P5
// #Created with GIMP
// 16 16
// 65535
// ... <big-endian bytes>
// ---
// get index of 4th \n
let byte_start = pnm_bytes
.iter()
.enumerate()
.filter(|(_, &v)| v == b'\n')
.map(|(i, _)| i)
.nth(3)
.expect("Must be 4 \\n's");

let pnm_values: Vec<f32> = pnm_bytes[(byte_start + 1)..]
.chunks(2)
.map(|slice| {
let bts = [slice[0], slice[1]];
(u16::from_be_bytes(bts) as f32) / (u16::MAX as f32)
})
.collect();
assert!(pnm_values.len() == 256);

let filenames = [
"random-fp16-pred2.tiff",
"random-fp16-pred3.tiff",
"random-fp16.tiff",
];

for filename in filenames.iter() {
let path = PathBuf::from(TEST_IMAGE_DIR).join(filename);
let img_file = File::open(path).expect("Cannot find test image!");
let mut decoder = Decoder::new(img_file).expect("Cannot create decoder");
assert_eq!(
decoder.dimensions().expect("Cannot get dimensions"),
(16, 16)
);
assert_eq!(
decoder.colortype().expect("Cannot get colortype"),
ColorType::Gray(16)
);
if let DecodingResult::F16(img) = decoder.read_image().unwrap() {
for (exp, found) in std::iter::zip(pnm_values.iter(), img.iter()) {
assert!((exp - found.to_f32()).abs() < 0.0001);
}
} else {
panic!("Wrong data type");
}
}
}
Binary file added tests/images/random-fp16-pred2.tiff
Binary file not shown.
Binary file added tests/images/random-fp16-pred3.tiff
Binary file not shown.
Binary file added tests/images/random-fp16.pgm
Binary file not shown.
Binary file added tests/images/random-fp16.tiff
Binary file not shown.
Binary file added tests/images/single-black-fp16.tiff
Binary file not shown.
Binary file added tests/images/white-fp16-pred2.tiff
Binary file not shown.
Binary file added tests/images/white-fp16-pred3.tiff
Binary file not shown.
Binary file added tests/images/white-fp16.tiff
Binary file not shown.

0 comments on commit 3bfb43e

Please sign in to comment.