Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

impl Versionize for HashSet/HashMap and Box<[T]>/ABox<[T]> #1441

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 112 additions & 3 deletions utils/tfhe-versionable/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod upgrade;

use aligned_vec::{ABox, AVec};
use num_complex::Complex;
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::error::Error;
use std::fmt::Display;
Expand Down Expand Up @@ -249,6 +250,30 @@ impl<T: Unversionize> Unversionize for Box<T> {
}
}

impl<T: VersionizeSlice + Clone> Versionize for Box<[T]> {
type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers;

fn versionize(&self) -> Self::Versioned<'_> {
T::versionize_slice(self)
}
}

impl<T: VersionizeVec + Clone> VersionizeOwned for Box<[T]> {
type VersionedOwned = T::VersionedVec;

fn versionize_owned(self) -> Self::VersionedOwned {
T::versionize_vec(self.iter().cloned().collect())
}
}

impl<T: UnversionizeVec + Clone> Unversionize for Box<[T]> {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
T::unversionize_vec(versioned).map(|unver| unver.into_boxed_slice())
}
}

impl<T: NotVersioned + Clone + Serialize + DeserializeOwned> NotVersioned for Box<[T]> {}

impl<T: VersionizeSlice> Versionize for Vec<T> {
type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers;

Expand Down Expand Up @@ -446,16 +471,16 @@ impl<T: Versionize> Versionize for ABox<T> {
}
}

impl<T: VersionizeOwned + Copy> VersionizeOwned for ABox<T> {
impl<T: VersionizeOwned + Clone> VersionizeOwned for ABox<T> {
// Alignment doesn't matter for versioned types
type VersionedOwned = Box<T::VersionedOwned>;

fn versionize_owned(self) -> Self::VersionedOwned {
Box::new(T::versionize_owned(*self))
Box::new(T::versionize_owned(T::clone(&self)))
}
}

impl<T: Unversionize + Copy> Unversionize for ABox<T>
impl<T: Unversionize + Clone> Unversionize for ABox<T>
where
T::VersionedOwned: Clone,
{
Expand All @@ -464,6 +489,30 @@ where
}
}

impl<T: VersionizeSlice + Clone> Versionize for ABox<[T]> {
type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers;

fn versionize(&self) -> Self::Versioned<'_> {
T::versionize_slice(self)
}
}

impl<T: VersionizeVec + Clone> VersionizeOwned for ABox<[T]> {
type VersionedOwned = T::VersionedVec;

fn versionize_owned(self) -> Self::VersionedOwned {
T::versionize_vec(self.iter().cloned().collect())
}
}

impl<T: UnversionizeVec + Clone> Unversionize for ABox<[T]> {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
T::unversionize_vec(versioned).map(|unver| AVec::from_iter(0, unver).into_boxed_slice())
}
}

impl<T: NotVersioned + Clone + Serialize + DeserializeOwned> NotVersioned for ABox<[T]> {}

impl<T: VersionizeSlice> Versionize for AVec<T> {
type Versioned<'vers> = T::VersionedSlice<'vers> where T: 'vers;

Expand Down Expand Up @@ -568,3 +617,63 @@ impl<T: Unversionize, U: Unversionize, V: Unversionize> Unversionize for (T, U,
}

impl<T: NotVersioned, U: NotVersioned, V: NotVersioned> NotVersioned for (T, U, V) {}

// converts to `Vec<T::Versioned>` for the versioned type, so we don't have to derive
// Eq/Hash on it.
impl<T: Versionize> Versionize for HashSet<T> {
type Versioned<'vers> = Vec<T::Versioned<'vers>>
where
T: 'vers;

fn versionize(&self) -> Self::Versioned<'_> {
self.iter().map(|val| val.versionize()).collect()
}
}

impl<T: VersionizeOwned> VersionizeOwned for HashSet<T> {
type VersionedOwned = Vec<T::VersionedOwned>;

fn versionize_owned(self) -> Self::VersionedOwned {
self.into_iter().map(|val| val.versionize_owned()).collect()
}
}

impl<T: Unversionize + std::hash::Hash + Eq> Unversionize for HashSet<T> {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
versioned
.into_iter()
.map(|val| T::unversionize(val))
.collect()
}
}

// converts to `Vec<(K::Versioned, V::Versioned)>` for the versioned type, so we don't have to
// derive Eq/Hash on it.
impl<K: Versionize, V: Versionize> Versionize for HashMap<K, V> {
type Versioned<'vers> = Vec<(K::Versioned<'vers>, V::Versioned<'vers>)> where K: 'vers, V: 'vers;

fn versionize(&self) -> Self::Versioned<'_> {
self.iter()
.map(|(key, val)| (key.versionize(), val.versionize()))
.collect()
}
}

impl<K: VersionizeOwned, V: VersionizeOwned> VersionizeOwned for HashMap<K, V> {
type VersionedOwned = Vec<(K::VersionedOwned, V::VersionedOwned)>;

fn versionize_owned(self) -> Self::VersionedOwned {
self.into_iter()
.map(|(key, val)| (key.versionize_owned(), val.versionize_owned()))
.collect()
}
}

impl<K: Unversionize + std::hash::Hash + Eq, V: Unversionize> Unversionize for HashMap<K, V> {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, UnversionizeError> {
versioned
.into_iter()
.map(|(key, val)| Ok((K::unversionize(key)?, V::unversionize(val)?)))
.collect()
}
}
Loading