Skip to content

Commit

Permalink
Merge pull request #523 from madsmtm/restrict-this-declare-class
Browse files Browse the repository at this point in the history
Restrict `this` parameters to `Self`-like types in `declare_class!`
  • Loading branch information
madsmtm authored Oct 3, 2023
2 parents 53cbfc6 + 8c6b293 commit 2ad20c3
Show file tree
Hide file tree
Showing 14 changed files with 1,309 additions and 1,461 deletions.
2 changes: 2 additions & 0 deletions crates/objc2/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
* Allow cloning `Id<AnyObject>`.
* **BREAKING**: Restrict message sending to `&mut` references to things that
implement `IsAllowedMutable`.
* Disallow the ability to use non-`Self`-like types as the receiver in
`declare_class!`.

### Removed
* **BREAKING**: Removed `ProtocolType` implementation for `NSObject`.
Expand Down
249 changes: 247 additions & 2 deletions crates/objc2/src/__macro_helpers/declare_class.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
#[cfg(all(debug_assertions, feature = "verify"))]
use alloc::vec::Vec;
use core::marker::PhantomData;
#[cfg(all(debug_assertions, feature = "verify"))]
use std::collections::HashSet;

#[cfg(all(debug_assertions, feature = "verify"))]
use crate::runtime::{AnyProtocol, MethodDescription};

use objc2_encode::Encoding;

use crate::declare::{ClassBuilder, IvarType};
use crate::encode::Encode;
use crate::rc::{Allocated, Id};
use crate::runtime::{AnyClass, MethodImplementation, Sel};
use crate::runtime::{AnyObject, MessageReceiver};
use crate::{ClassType, Message};
use crate::{ClassType, Message, ProtocolType};

use super::{CopyOrMutCopy, Init, MaybeUnwrap, New, Other};
use crate::mutability;
Expand Down Expand Up @@ -52,7 +63,7 @@ where
// restrict it here to only be when the selector is `init`.
//
// Additionally, the receiver and return type must have the same generic
// generic parameter `T`.
// parameter `T`.
impl<Ret, T> MessageRecieveId<Allocated<T>, Ret> for Init
where
T: Message,
Expand Down Expand Up @@ -190,3 +201,237 @@ where
{
// Noop
}

#[derive(Debug)]
pub struct ClassBuilderHelper<T: ?Sized> {
builder: ClassBuilder,
p: PhantomData<T>,
}

#[track_caller]
fn failed_declaring_class(name: &str) -> ! {
panic!("could not create new class {name}. Perhaps a class with that name already exists?")
}

impl<T: ?Sized + ClassType> ClassBuilderHelper<T> {
#[inline]
#[track_caller]
#[allow(clippy::new_without_default)]
pub fn new() -> Self
where
T::Super: ClassType,
{
let builder = match ClassBuilder::new(T::NAME, <T::Super as ClassType>::class()) {
Some(builder) => builder,
None => failed_declaring_class(T::NAME),
};

Self {
builder,
p: PhantomData,
}
}

#[inline]
pub fn add_protocol_methods<P>(&mut self) -> ClassProtocolMethodsBuilder<'_, T>
where
P: ?Sized + ProtocolType,
{
let protocol = P::protocol();

if let Some(protocol) = protocol {
self.builder.add_protocol(protocol);
}

#[cfg(all(debug_assertions, feature = "verify"))]
{
ClassProtocolMethodsBuilder {
builder: self,
protocol,
required_instance_methods: protocol
.map(|p| p.method_descriptions(true))
.unwrap_or_default(),
optional_instance_methods: protocol
.map(|p| p.method_descriptions(false))
.unwrap_or_default(),
registered_instance_methods: HashSet::new(),
required_class_methods: protocol
.map(|p| p.class_method_descriptions(true))
.unwrap_or_default(),
optional_class_methods: protocol
.map(|p| p.class_method_descriptions(false))
.unwrap_or_default(),
registered_class_methods: HashSet::new(),
}
}

#[cfg(not(all(debug_assertions, feature = "verify")))]
{
ClassProtocolMethodsBuilder { builder: self }
}
}

// Addition: This restricts to callee `T`
#[inline]
pub unsafe fn add_method<F>(&mut self, sel: Sel, func: F)
where
F: MethodImplementation<Callee = T>,
{
// SAFETY: Checked by caller
unsafe { self.builder.add_method(sel, func) }
}

#[inline]
pub unsafe fn add_class_method<F>(&mut self, sel: Sel, func: F)
where
F: MethodImplementation<Callee = AnyClass>,
{
// SAFETY: Checked by caller
unsafe { self.builder.add_class_method(sel, func) }
}

#[inline]
pub fn add_static_ivar<I: IvarType>(&mut self) {
self.builder.add_static_ivar::<I>()
}

#[inline]
pub fn register(self) -> &'static AnyClass {
self.builder.register()
}
}

/// Helper for ensuring that:
/// - Only methods on the protocol are overriden.
/// - TODO: The methods have the correct signature.
/// - All required methods are overridden.
#[derive(Debug)]
pub struct ClassProtocolMethodsBuilder<'a, T: ?Sized> {
builder: &'a mut ClassBuilderHelper<T>,
#[cfg(all(debug_assertions, feature = "verify"))]
protocol: Option<&'static AnyProtocol>,
#[cfg(all(debug_assertions, feature = "verify"))]
required_instance_methods: Vec<MethodDescription>,
#[cfg(all(debug_assertions, feature = "verify"))]
optional_instance_methods: Vec<MethodDescription>,
#[cfg(all(debug_assertions, feature = "verify"))]
registered_instance_methods: HashSet<Sel>,
#[cfg(all(debug_assertions, feature = "verify"))]
required_class_methods: Vec<MethodDescription>,
#[cfg(all(debug_assertions, feature = "verify"))]
optional_class_methods: Vec<MethodDescription>,
#[cfg(all(debug_assertions, feature = "verify"))]
registered_class_methods: HashSet<Sel>,
}

impl<T: ?Sized + ClassType> ClassProtocolMethodsBuilder<'_, T> {
// Addition: This restricts to callee `T`
#[inline]
pub unsafe fn add_method<F>(&mut self, sel: Sel, func: F)
where
F: MethodImplementation<Callee = T>,
{
#[cfg(all(debug_assertions, feature = "verify"))]
if let Some(protocol) = self.protocol {
let _types = self
.required_instance_methods
.iter()
.chain(&self.optional_instance_methods)
.find(|desc| desc.sel == sel)
.map(|desc| desc.types)
.unwrap_or_else(|| {
panic!(
"failed overriding protocol method -[{protocol} {sel}]: method not found"
)
});
}

// SAFETY: Checked by caller
unsafe { self.builder.add_method(sel, func) };

#[cfg(all(debug_assertions, feature = "verify"))]
if !self.registered_instance_methods.insert(sel) {
unreachable!("already added")
}
}

#[inline]
pub unsafe fn add_class_method<F>(&mut self, sel: Sel, func: F)
where
F: MethodImplementation<Callee = AnyClass>,
{
#[cfg(all(debug_assertions, feature = "verify"))]
if let Some(protocol) = self.protocol {
let _types = self
.required_class_methods
.iter()
.chain(&self.optional_class_methods)
.find(|desc| desc.sel == sel)
.map(|desc| desc.types)
.unwrap_or_else(|| {
panic!(
"failed overriding protocol method +[{protocol} {sel}]: method not found"
)
});
}

// SAFETY: Checked by caller
unsafe { self.builder.add_class_method(sel, func) };

#[cfg(all(debug_assertions, feature = "verify"))]
if !self.registered_class_methods.insert(sel) {
unreachable!("already added")
}
}

#[cfg(all(debug_assertions, feature = "verify"))]
pub fn finish(self) {
let superclass = self.builder.builder.superclass();

if let Some(protocol) = self.protocol {
for desc in &self.required_instance_methods {
if self.registered_instance_methods.contains(&desc.sel) {
continue;
}

// TODO: Don't do this when `NS_PROTOCOL_REQUIRES_EXPLICIT_IMPLEMENTATION`
if superclass
.and_then(|superclass| superclass.instance_method(desc.sel))
.is_some()
{
continue;
}

panic!(
"must implement required protocol method -[{protocol} {}]",
desc.sel
)
}
}

if let Some(protocol) = self.protocol {
for desc in &self.required_class_methods {
if self.registered_class_methods.contains(&desc.sel) {
continue;
}

// TODO: Don't do this when `NS_PROTOCOL_REQUIRES_EXPLICIT_IMPLEMENTATION`
if superclass
.and_then(|superclass| superclass.class_method(desc.sel))
.is_some()
{
continue;
}

panic!(
"must implement required protocol method +[{protocol} {}]",
desc.sel
);
}
}
}

#[inline]
#[cfg(not(all(debug_assertions, feature = "verify")))]
pub fn finish(self) {}
}
Loading

0 comments on commit 2ad20c3

Please sign in to comment.