diff --git a/Sources/Knit/ConcurrencyAttribute.swift b/Sources/Knit/ConcurrencyAttribute.swift new file mode 100644 index 0000000..bf7ebe2 --- /dev/null +++ b/Sources/Knit/ConcurrencyAttribute.swift @@ -0,0 +1,16 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +/// The possible concurrency isolation for a registration. +public enum ConcurrencyAttribute { + /// We do not currently have a way to forward this information through Swinject Behavior hooks + /// so registrations that come from behaviors will be unknown. + case unknown + + /// The default. + case nonisolated + + /// Corresponds to the `@MainActor` attribute. + case MainActor +} diff --git a/Sources/Knit/Module/Container+AbstractRegistration.swift b/Sources/Knit/Module/Container+AbstractRegistration.swift index a8238e7..0545cee 100644 --- a/Sources/Knit/Module/Container+AbstractRegistration.swift +++ b/Sources/Knit/Module/Container+AbstractRegistration.swift @@ -12,9 +12,10 @@ extension Container { public func registerAbstract( _ serviceType: Service.Type, name: String? = nil, + concurrency: ConcurrencyAttribute = .nonisolated, file: String = #fileID ) { - let registration = RealAbstractRegistration(name: name, file: file) + let registration = RealAbstractRegistration(name: name, file: file, concurrency: concurrency) abstractRegistrations().abstractRegistrations.append(registration) } @@ -26,9 +27,10 @@ extension Container { public func registerAbstract( _ serviceType: Optional.Type, name: String? = nil, + concurrency: ConcurrencyAttribute = .nonisolated, file: String = #fileID ) { - let registration = OptionalAbstractRegistration(name: name, file: file) + let registration = OptionalAbstractRegistration(name: name, file: file, concurrency: concurrency) abstractRegistrations().abstractRegistrations.append(registration) } @@ -50,6 +52,7 @@ extension Container { internal struct RegistrationKey: Hashable, Equatable { let typeIdentifier: ObjectIdentifier let name: String? + let concurrency: ConcurrencyAttribute } /// Protocol version to allow storing generic types an array @@ -60,6 +63,7 @@ internal protocol AbstractRegistration { var file: String { get } var name: String? { get } var key: RegistrationKey { get } + var concurrency: ConcurrencyAttribute { get } /// Register a placeholder registration to fill the unfulfilled abstract registration /// This placeholder cannot be resolved @@ -89,8 +93,14 @@ fileprivate struct RealAbstractRegistration: AbstractRegistration { var serviceType: ServiceType.Type { ServiceType.self } + let concurrency: ConcurrencyAttribute + var key: RegistrationKey { - return .init(typeIdentifier: ObjectIdentifier(ServiceType.self), name: name) + return .init( + typeIdentifier: ObjectIdentifier(ServiceType.self), + name: name, + concurrency: concurrency + ) } func registerPlaceholder( @@ -113,8 +123,10 @@ fileprivate struct OptionalAbstractRegistration: AbstractRegistrati var serviceType: ServiceType.Type { ServiceType.self } + let concurrency: ConcurrencyAttribute + var key: RegistrationKey { - return .init(typeIdentifier: ObjectIdentifier(ServiceType.self), name: name) + return .init(typeIdentifier: ObjectIdentifier(ServiceType.self), name: name, concurrency: concurrency) } func registerPlaceholder( @@ -171,12 +183,23 @@ extension Container { toService entry: ServiceEntry, withName name: String? ) { - let id = RegistrationKey(typeIdentifier: ObjectIdentifier(Type.self), name: name) + let id = RegistrationKey( + typeIdentifier: ObjectIdentifier(Type.self), + name: name, + concurrency: .unknown + ) concreteRegistrations.insert(id) } var unfulfilledRegistrations: [any AbstractRegistration] { - abstractRegistrations.filter { !concreteRegistrations.contains($0.key) } + abstractRegistrations.filter { abstractRegistration in + let abstractKey = abstractRegistration.key + return !concreteRegistrations.contains { concreteKey in + // We need to ignore the concurrency attribute currently due to Swinject limitations + concreteKey.typeIdentifier == abstractKey.typeIdentifier && + concreteKey.name == abstractKey.name + } + } } // Throws an error if any abstract registrations have not been implemented diff --git a/Sources/KnitCodeGen/Configuration.swift b/Sources/KnitCodeGen/Configuration.swift index c44ddcb..e969fa5 100644 --- a/Sources/KnitCodeGen/Configuration.swift +++ b/Sources/KnitCodeGen/Configuration.swift @@ -13,7 +13,6 @@ public struct Configuration: Encodable, Sendable { public var directives: KnitDirectives public enum AssemblyType: String, Encodable, Sendable { - /// `Swinject.Assembly` case moduleAssembly = "ModuleAssembly" case autoInitAssembly = "AutoInitModuleAssembly" case abstractAssembly = "AbstractAssembly" diff --git a/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift b/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift index eeb97d9..633ddb3 100644 --- a/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift +++ b/Sources/KnitCodeGen/FunctionCallRegistrationParsing.swift @@ -285,9 +285,19 @@ private func getConcurrencyModifier( arguments: LabeledExprListSyntax, trailingClosure: ClosureExprSyntax? ) -> String? { + // Detects concrete registrations that use the explicitly named closure argument if arguments.contains(where: {$0.label?.text == "mainActorFactory" }) { return "@MainActor" } + // Detects abstract registrations + for arg in arguments { + guard arg.label?.text == "concurrency" else { continue } + // Corresponds to `(concurrency: .MainActor)` + // declName is what follows the period + if arg.expression.as(MemberAccessExprSyntax.self)?.declName.baseName.text == "MainActor" { + return "@MainActor" + } + } guard let signature = trailingClosure?.signature else { return nil } for att in signature.attributes { guard case let .attribute(attributeSyntax) = att else { diff --git a/Tests/KnitCodeGenTests/RegistrationParsingTests.swift b/Tests/KnitCodeGenTests/RegistrationParsingTests.swift index f7f136f..d7ac2a5 100644 --- a/Tests/KnitCodeGenTests/RegistrationParsingTests.swift +++ b/Tests/KnitCodeGenTests/RegistrationParsingTests.swift @@ -134,6 +134,16 @@ final class RegistrationParsingTests: XCTestCase { serviceName: "AType", name: "service" ) + + try assertRegistrationString( + """ + container.registerAbstract(AType.self, name: "service", concurrency: .MainActor) + """, + serviceName: "AType", + name: "service", + concurrencyModifier: "@MainActor" + ) + } func testForwardedRegistration() throws { @@ -637,6 +647,7 @@ private func assertRegistrationString( accessLevel: AccessLevel = .internal, name: String? = nil, isForwarded: Bool = false, + concurrencyModifier: String? = nil, file: StaticString = #filePath, line: UInt = #line ) throws { let functionCall = try XCTUnwrap(FunctionCallExprSyntax("\(raw: string)" as ExprSyntax)) @@ -651,6 +662,7 @@ private func assertRegistrationString( XCTAssertEqual(registration?.accessLevel, accessLevel, file: file, line: line) XCTAssertEqual(registration?.name, name, file: file, line: line) XCTAssertEqual(registration?.isForwarded, isForwarded, file: file, line: line) + XCTAssertEqual(registration?.concurrencyModifier, concurrencyModifier, file: file, line: line) } /// Assert that multiple registrations exist within the string.