Skip to content

Commit

Permalink
feat(versionable): impl Versionize for Vec<Vec<T>>
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarlin-zama committed Oct 1, 2024
1 parent 75d2457 commit 04c6f18
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 7 deletions.
116 changes: 110 additions & 6 deletions utils/tfhe-versionable/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,6 @@ macro_rules! impl_scalar_versionize {
}

impl NotVersioned for $t {}

impl NotVersioned for Vec<$t> {}
};
}

Expand Down Expand Up @@ -315,7 +313,35 @@ impl<T: UnversionizeVec + Clone> Unversionize for Box<[T]> {
}
}

impl<T: NotVersioned + Clone + Serialize + DeserializeOwned> NotVersioned for Box<[T]> {}
impl<T: VersionizeVec + Clone> VersionizeVec for Box<[T]> {
type VersionedVec = Vec<T::VersionedVec>;

fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec {
vec.into_iter()
.map(|inner| inner.versionize_owned())
.collect()
}
}

impl<T: VersionizeSlice> VersionizeSlice for Box<[T]> {
type VersionedSlice<'vers> = Vec<T::VersionedSlice<'vers>> where T: 'vers;

fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
slice
.iter()
.map(|inner| T::versionize_slice(inner))
.collect()
}
}

impl<T: UnversionizeVec + Clone> UnversionizeVec for Box<[T]> {
fn unversionize_vec(versioned: Self::VersionedVec) -> Result<Vec<Self>, UnversionizeError> {
versioned
.into_iter()
.map(Box::<[T]>::unversionize)
.collect()
}
}

impl<T: VersionizeSlice> Versionize for Vec<T> {
type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers;
Expand All @@ -333,6 +359,42 @@ impl<T: VersionizeVec> VersionizeOwned for Vec<T> {
}
}

impl<T: UnversionizeVec> Unversionize for Vec<T> {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
T::unversionize_vec(versioned)
}
}

impl<T: VersionizeVec> VersionizeVec for Vec<T> {
type VersionedVec = Vec<T::VersionedVec>;

fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec {
vec.into_iter()
.map(|inner| T::versionize_vec(inner))
.collect()
}
}

impl<T: VersionizeSlice> VersionizeSlice for Vec<T> {
type VersionedSlice<'vers> = Vec<T::VersionedSlice<'vers>> where T: 'vers;

fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
slice
.iter()
.map(|inner| T::versionize_slice(inner))
.collect()
}
}

impl<T: UnversionizeVec> UnversionizeVec for Vec<T> {
fn unversionize_vec(versioned: Self::VersionedVec) -> Result<Vec<Self>, UnversionizeError> {
versioned
.into_iter()
.map(|inner| T::unversionize_vec(inner))
.collect()
}
}

impl<T: VersionizeSlice + Clone> Versionize for [T] {
type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers;

Expand All @@ -349,9 +411,24 @@ impl<T: VersionizeVec + Clone> VersionizeOwned for &[T] {
}
}

impl<T: UnversionizeVec> Unversionize for Vec<T> {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
T::unversionize_vec(versioned)
impl<T: VersionizeVec + Clone> VersionizeVec for &[T] {
type VersionedVec = Vec<T::VersionedVec>;

fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec {
vec.into_iter()
.map(|inner| T::versionize_vec(inner.to_vec()))
.collect()
}
}

impl<'a, T: VersionizeSlice> VersionizeSlice for &'a [T] {
type VersionedSlice<'vers> = Vec<T::VersionedSlice<'vers>> where T: 'vers, 'a: 'vers;

fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
slice
.iter()
.map(|inner| T::versionize_slice(inner))
.collect()
}
}

Expand Down Expand Up @@ -386,6 +463,33 @@ impl<const N: usize, T: UnversionizeVec + Clone> Unversionize for [T; N] {
}
}

impl<const N: usize, T: VersionizeVec + Clone> VersionizeVec for [T; N] {
type VersionedVec = Vec<T::VersionedVec>;

fn versionize_vec(vec: Vec<Self>) -> Self::VersionedVec {
vec.into_iter()
.map(|inner| inner.versionize_owned())
.collect()
}
}

impl<const N: usize, T: VersionizeSlice> VersionizeSlice for [T; N] {
type VersionedSlice<'vers> = Vec<T::VersionedSlice<'vers>> where T: 'vers;

fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> {
slice
.iter()
.map(|inner| T::versionize_slice(inner))
.collect()
}
}

impl<const N: usize, T: UnversionizeVec + Clone> UnversionizeVec for [T; N] {
fn unversionize_vec(versioned: Self::VersionedVec) -> Result<Vec<Self>, UnversionizeError> {
versioned.into_iter().map(<[T; N]>::unversionize).collect()
}
}

impl Versionize for String {
type Versioned<'vers> = &'vers str;

Expand Down
39 changes: 38 additions & 1 deletion utils/tfhe-versionable/tests/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ use aligned_vec::{ABox, AVec};
use num_complex::Complex;
use tfhe_versionable::{Unversionize, Versionize};

use backward_compat::MyStructVersions;
use backward_compat::{CustomVersions, MyStructVersions};

#[derive(PartialEq, Clone, Debug, Versionize)]
#[versionize(CustomVersions)]
struct Custom(u32);

#[derive(PartialEq, Clone, Debug, Versionize)]
#[versionize(MyStructVersions)]
Expand All @@ -19,6 +23,9 @@ pub struct MyStruct {
base_box: Box<u8>,
sliced_box: Box<[u16; 50]>,
base_vec: Vec<u32>,
base_vec_vec: Vec<Vec<u32>>,
custom_vec_vec: Vec<Vec<Custom>>,
custom_vec_vec_vec: Vec<Vec<Vec<Custom>>>,
s: String,
opt: Option<u64>,
phantom: PhantomData<u128>,
Expand All @@ -35,13 +42,21 @@ pub struct MyStruct {
mod backward_compat {
use tfhe_versionable::VersionsDispatch;

use crate::Custom;

use super::MyStruct;

#[derive(VersionsDispatch)]
#[allow(unused)]
pub enum MyStructVersions {
V0(MyStruct),
}

#[derive(VersionsDispatch)]
#[allow(unused)]
pub enum CustomVersions {
V0(Custom),
}
}

#[test]
Expand All @@ -51,6 +66,28 @@ fn test_types() {
base_box: Box::new(42),
sliced_box: vec![11; 50].into_boxed_slice().try_into().unwrap(),
base_vec: vec![1234, 5678],
base_vec_vec: vec![vec![1234, 5678], vec![9012, 3456]],
custom_vec_vec: vec![
vec![9876, 5432, 1987, 6543]
.into_iter()
.map(Custom)
.collect(),
vec![1098, 7654, 3210, 9876]
.into_iter()
.map(Custom)
.collect(),
],
custom_vec_vec_vec: vec![
vec![
vec![9876, 5432].into_iter().map(Custom).collect(),
vec![1987, 6543].into_iter().map(Custom).collect(),
],
vec![
vec![1098, 7654].into_iter().map(Custom).collect(),
vec![3210, 9876].into_iter().map(Custom).collect(),
],
],

s: String::from("test"),
opt: Some(0xdeadbeef),
phantom: PhantomData,
Expand Down

0 comments on commit 04c6f18

Please sign in to comment.