Skip to content

Commit

Permalink
enforce covariance and remove as_mut until further investigation
Browse files Browse the repository at this point in the history
partial fix for #1

By implementing a function that rebind the lifetime of values we're
forcing the compiler to check that the types are covariant

Signed-off-by: Petros Angelatos <[email protected]>
  • Loading branch information
petrosagg committed Apr 13, 2021
1 parent 016f4c1 commit d0e66f0
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 67 deletions.
8 changes: 8 additions & 0 deletions escher-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,13 @@ pub fn derive_rebindable(input: TokenStream) -> TokenStream {
unsafe impl<#(#impl_params),*> escher::RebindTo<'a> for #name<#(#type_params),*> {
type Out = #name<#(#out_params),*>;
}

impl escher::Rebindable for #name<#(#type_params),*> {
fn rebind<'short, 'long: 'short>(&'long self) -> &'short escher::Rebind<'short, Self>
where Self: 'long
{
self
}
}
})
}
20 changes: 12 additions & 8 deletions escher/src/escher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ unsafe impl<'a, T: ?Sized + 'static> RebindTo<'a> for &'_ mut T {

/// Marker trait for any type that implements [RebindTo] for any lifetime. All types can derive
/// this trait using the [Rebindable](escher_derive::Rebindable) derive macro.
pub trait Rebindable: for<'a> RebindTo<'a> {}
impl<T: for<'a> RebindTo<'a>> Rebindable for T {}
pub trait Rebindable: for<'a> RebindTo<'a> {
fn rebind<'short, 'long: 'short>(&'long self) -> &'short Rebind<'short, Self>
where Self: 'long;
}

/// Type-level function that takes a lifetime `'a` and a type `T` computes a new type `U` that is
/// identical to `T` except for its lifetimes that are now bound to `'a`.
Expand All @@ -70,16 +72,19 @@ impl<'fut, T: Rebindable> Escher<'fut, T> {
/// desired state when ready.
///
/// ```rust
/// use escher::Escher;
/// use escher::{Escher, Rebindable};
///
/// #[derive(Rebindable)]
/// struct MyStr<'a>(&'a str);
///
/// let escher_heart = Escher::new(|r| async move {
/// let data: Vec<u8> = vec![240, 159, 146, 150];
/// let sparkle_heart = std::str::from_utf8(&data).unwrap();
///
/// r.capture(sparkle_heart).await;
/// r.capture(MyStr(sparkle_heart)).await;
/// });
///
/// assert_eq!("💖", *escher_heart.as_ref());
/// assert_eq!("💖", escher_heart.as_ref().0);
/// ```
pub fn new<B, F>(builder: B) -> Self
where
Expand Down Expand Up @@ -135,13 +140,12 @@ impl<'fut, T: Rebindable> Escher<'fut, T> {
// The resulting reference is has all its lifetimes bound to the lifetime of self that
// contains _fut that contains all the data that ptr could be referring to because it's
// a 'static Future
unsafe { &*(self.ptr.as_ptr() as *mut _) }
unsafe { (&*self.ptr.as_ptr()).rebind() }
}

/// Get a mut reference to the inner `T` with its lifetime bound to `&mut self`
pub fn as_mut<'a>(&'a mut self) -> &mut Rebind<'a, T> {
// SAFETY: see safety argument of Self::as_ref
unsafe { &mut *(self.ptr.as_ptr() as *mut _) }
unimplemented!()
}
}

Expand Down
55 changes: 27 additions & 28 deletions escher/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,19 @@
//! The simplest way to use Escher is to create a reference of some data and then capture it:
//!
//! ```rust
//! use escher::Escher;
//! use escher::{Escher, Rebindable};
//!
//! #[derive(Rebindable)]
//! struct MyStr<'a>(&'a str);
//!
//! let escher_heart = Escher::new(|r| async move {
//! let data: Vec<u8> = vec![240, 159, 146, 150];
//! let sparkle_heart = std::str::from_utf8(&data).unwrap();
//!
//! r.capture(sparkle_heart).await;
//! r.capture(MyStr(sparkle_heart)).await;
//! });
//!
//! assert_eq!("💖", *escher_heart.as_ref());
//! assert_eq!("💖", escher_heart.as_ref().0);
//! ```
//!
//! ## Capturing both a `Vec<u8>` and a `&str` view into it
Expand Down Expand Up @@ -76,27 +79,27 @@
//! assert_eq!(240, escher_heart.as_ref().data[0]);
//! assert_eq!("💖", escher_heart.as_ref().s);
//! ```
//!
//! ## Capturing a mutable `&mut str` view into a `Vec<u8>`
//!
//! If you capture a mutable reference to some piece of data then you cannot capture the data
//! itself like the previous example. This is mandatory as doing otherwise would create two mutable
//! references into the same piece of data which is not allowed.
//!
//! ```rust
//! use escher::Escher;
//!
//! let mut name = Escher::new(|r| async move {
//! let mut data: Vec<u8> = vec![101, 115, 99, 104, 101, 114];
//! let name = std::str::from_utf8_mut(&mut data).unwrap();
//!
//! r.capture(name).await;
//! });
//!
//! assert_eq!("escher", *name.as_ref());
//! name.as_mut().make_ascii_uppercase();
//! assert_eq!("ESCHER", *name.as_ref());
//! ```
//
// ## Capturing a mutable `&mut str` view into a `Vec<u8>`
//
// If you capture a mutable reference to some piece of data then you cannot capture the data
// itself like the previous example. This is mandatory as doing otherwise would create two mutable
// references into the same piece of data which is not allowed.
//
// ```rust
// use escher::Escher;
//
// let mut name = Escher::new(|r| async move {
// let mut data: Vec<u8> = vec![101, 115, 99, 104, 101, 114];
// let name = std::str::from_utf8_mut(&mut data).unwrap();
//
// r.capture(name).await;
// });
//
// assert_eq!("escher", *name.as_ref());
// name.as_mut().make_ascii_uppercase();
// assert_eq!("ESCHER", *name.as_ref());
// ```
//!
//! ## Capturing multiple mixed references
//!
Expand All @@ -123,10 +126,6 @@
//!
//! assert_eq!(Box::new(42), *my_value.as_ref().int_data);
//! assert_eq!(3.14, *my_value.as_ref().float_ref);
//!
//! *my_value.as_mut().float_ref = (*my_value.as_ref().int_ref as f32) * 2.0;
//!
//! assert_eq!(84.0, *my_value.as_ref().float_ref);
//! ```
mod escher;
Expand Down
63 changes: 32 additions & 31 deletions escher/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,40 @@
use super::*;
use crate as escher;

#[test]
fn simple_ref() {
let escher_heart = Escher::new(|r| async move {
let data: Vec<u8> = vec![240, 159, 146, 150];
let sparkle_heart = std::str::from_utf8(&data).unwrap();

r.capture(sparkle_heart).await;
});

assert_eq!("💖", *escher_heart.as_ref());
}

#[test]
fn simple_mut_ref() {
let escher_heart = Escher::new(|r| async move {
let data: Vec<u8> = vec![240, 159, 146, 150];
let sparkle_heart = std::str::from_utf8(&data).unwrap();

r.capture(sparkle_heart).await;
});

assert_eq!("💖", *escher_heart.as_ref());
}
// #[test]
// fn test_dtonlay() {
// use std::cell::Cell;
//
// #[derive(Rebindable)]
// struct Struct<'a>(fn(&'a String));
//
// fn main() {
// static STRING: String = String::new();
// thread_local!(static CELL: Cell<&'static String> = Cell::new(&STRING));
// let escher = Escher::new(|r| async {
// r.capture(Struct(|x| CELL.with(|cell| cell.set(x)))).await;
// });
// let mut string = Ok(".".repeat(3));
// let f = escher.as_ref().0;
// let s = string.as_ref().unwrap();
// f(s);
// string = Err((s.as_ptr(), 100usize, 100usize));
// CELL.with(|cell| println!("{}", cell.get()));
// string.unwrap_err();
// }
// }

#[test]
#[should_panic(expected = "capture no longer live")]
fn adversarial_sync_fn() {
#[derive(Rebindable)]
struct MyStr<'a>(&'a str);

Escher::new(|r| {
let data: Vec<u8> = vec![240, 159, 146, 150];
let sparkle_heart = std::str::from_utf8(&data).unwrap();

let _ = r.capture(sparkle_heart);
let _ = r.capture(MyStr(sparkle_heart));

// dummy future to satisfy escher
std::future::ready(())
Expand All @@ -43,11 +45,14 @@ fn adversarial_sync_fn() {
#[test]
#[should_panic(expected = "captured value outside of async stack")]
fn adversarial_capture_non_stack() {
#[derive(Rebindable)]
struct MyStr<'a>(&'a str);

Escher::new(|r| {
let data: Vec<u8> = vec![240, 159, 146, 150];
let sparkle_heart = std::str::from_utf8(&data).unwrap();

let fut = r.capture(sparkle_heart);
let fut = r.capture(MyStr(sparkle_heart));
// make it appear as if capture is still alive
std::mem::forget(fut);

Expand All @@ -65,15 +70,13 @@ fn capture_enum() {
None,
}

let mut escher_heart = Escher::new(|r| async move {
let escher_heart = Escher::new(|r| async move {
let data: Vec<u8> = vec![240, 159, 146, 150];
let sparkle_heart = std::str::from_utf8(&data).unwrap();

r.capture(MaybeStr::Some(sparkle_heart)).await;
});
assert_eq!(MaybeStr::Some("💖"), *escher_heart.as_ref());
*escher_heart.as_mut() = MaybeStr::None;
assert_eq!(MaybeStr::None, *escher_heart.as_ref());
}

#[test]
Expand All @@ -85,7 +88,7 @@ fn capture_union() {
none: (),
}

let mut escher_heart = Escher::new(|r| async move {
let escher_heart = Escher::new(|r| async move {
let data: Vec<u8> = vec![240, 159, 146, 150];
let sparkle_heart = std::str::from_utf8(&data).unwrap();

Expand All @@ -97,8 +100,6 @@ fn capture_union() {

unsafe {
assert_eq!("💖", escher_heart.as_ref().some);
escher_heart.as_mut().none = ();
assert_eq!((), escher_heart.as_ref().none);
}
}

Expand Down

0 comments on commit d0e66f0

Please sign in to comment.