diff --git a/Sources/Knit/Module/AbstractAssembly.swift b/Sources/Knit/Module/AbstractAssembly.swift new file mode 100644 index 0000000..eca3f18 --- /dev/null +++ b/Sources/Knit/Module/AbstractAssembly.swift @@ -0,0 +1,8 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +import Foundation + +/// An AbstractAssembly can only contain abstract registrations and should not be initialised. +protocol AbstractAssembly: ModuleAssembly { } diff --git a/Sources/KnitCodeGen/AssemblyParsing.swift b/Sources/KnitCodeGen/AssemblyParsing.swift index d99e999..f7658d8 100644 --- a/Sources/KnitCodeGen/AssemblyParsing.swift +++ b/Sources/KnitCodeGen/AssemblyParsing.swift @@ -51,6 +51,10 @@ func parseSyntaxTree( throw AssemblyParsingError.missingModuleName } + guard let assemblyType = assemblyFileVisitor.assemblyType else { + throw AssemblyParsingError.missingAssemblyType + } + errorsToPrint.append(contentsOf: assemblyFileVisitor.assemblyErrors) errorsToPrint.append(contentsOf: assemblyFileVisitor.registrationErrors) @@ -63,6 +67,7 @@ func parseSyntaxTree( return Configuration( name: name, + assemblyType: assemblyType, registrations: assemblyFileVisitor.registrations, registrationsIntoCollections: assemblyFileVisitor.registrationsIntoCollections, imports: assemblyFileVisitor.imports, @@ -77,6 +82,8 @@ private class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor { private(set) var moduleName: String? + private(set) var assemblyType: String? + private var classDeclVisitor: ClassDeclVisitor? private(set) var assemblyErrors: [Error] = [] @@ -105,11 +112,11 @@ private class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor { } override func visit(_ node: StructDeclSyntax) -> SyntaxVisitorContinueKind { - return visitAssemblyType(node) + return visitAssemblyType(node, node.inheritanceClause) } override func visit(_ node: ClassDeclSyntax) -> SyntaxVisitorContinueKind { - return visitAssemblyType(node) + return visitAssemblyType(node, node.inheritanceClause) } override func visit(_ node: ImportDeclSyntax) -> SyntaxVisitorContinueKind { @@ -130,7 +137,7 @@ private class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor { return self.visitIfNode(node) } - private func visitAssemblyType(_ node: NamedDeclSyntax) -> SyntaxVisitorContinueKind { + private func visitAssemblyType(_ node: NamedDeclSyntax, _ inheritance: InheritanceClauseSyntax?) -> SyntaxVisitorContinueKind { guard classDeclVisitor == nil else { // Only the first class declaration should be visited return .skipChildren @@ -143,6 +150,10 @@ private class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor { } moduleName = node.moduleNameForAssembly + let inheritedTypes = inheritance?.inheritedTypes.map { + $0.type.description.trimmingCharacters(in: .whitespaces) + } + self.assemblyType = inheritedTypes?.first(where: { $0.hasSuffix("Assembly")}) classDeclVisitor = ClassDeclVisitor(viewMode: .fixedUp, directives: directives) classDeclVisitor?.walk(node) return .skipChildren @@ -233,6 +244,7 @@ extension NamedDeclSyntax { enum AssemblyParsingError: Error { case fileReadError(Error, path: String) case missingModuleName + case missingAssemblyType case parsingError } @@ -251,6 +263,8 @@ extension AssemblyParsingError: LocalizedError { "Is your Assembly file setup correctly?" case .parsingError: return "There were one or more errors parsing the assembly file" + case .missingAssemblyType: + return "Assembly files must inherit from an *Assembly type" } } diff --git a/Sources/KnitCodeGen/Configuration.swift b/Sources/KnitCodeGen/Configuration.swift index 5596a8c..962b4f5 100644 --- a/Sources/KnitCodeGen/Configuration.swift +++ b/Sources/KnitCodeGen/Configuration.swift @@ -9,6 +9,7 @@ public struct Configuration: Encodable { /// Name of the module for this configuration. public var name: String + public var assemblyType: String public var registrations: [Registration] public var registrationsIntoCollections: [RegistrationIntoCollection] @@ -18,12 +19,14 @@ public struct Configuration: Encodable { public init( name: String, + assemblyType: String = "Assembly", registrations: [Registration], registrationsIntoCollections: [RegistrationIntoCollection], imports: [ModuleImport] = [], targetResolver: String ) { self.name = name + self.assemblyType = assemblyType self.registrations = registrations self.registrationsIntoCollections = registrationsIntoCollections self.imports = imports @@ -32,6 +35,7 @@ public struct Configuration: Encodable { public enum CodingKeys: CodingKey { case name + case assemblyType case registrations } @@ -48,10 +52,6 @@ public extension Configuration { } func makeUnitTestSourceFile() throws -> SourceFileSyntax { - var allImports = imports - allImports.append(try .testable(name: name)) - allImports.append(try .named("XCTest")) - return try UnitTestSourceFile.make( configuration: self ) diff --git a/Tests/KnitCodeGenTests/AssemblyParsingTests.swift b/Tests/KnitCodeGenTests/AssemblyParsingTests.swift index 4037972..5cba08d 100644 --- a/Tests/KnitCodeGenTests/AssemblyParsingTests.swift +++ b/Tests/KnitCodeGenTests/AssemblyParsingTests.swift @@ -13,7 +13,7 @@ final class AssemblyParsingTests: XCTestCase { let sourceFile: SourceFileSyntax = """ import A import B // Comment after import should be stripped - class FooTestAssembly: Assembly { } + class FooTestAssembly: ModuleAssembly { } """ let config = try assertParsesSyntaxTree(sourceFile) @@ -25,6 +25,7 @@ final class AssemblyParsingTests: XCTestCase { ] ) XCTAssertEqual(config.registrations.count, 0, "No registrations") + XCTAssertEqual(config.assemblyType, "ModuleAssembly") } func testDebugWrappedAssemblyImports() throws { @@ -108,6 +109,7 @@ final class AssemblyParsingTests: XCTestCase { let config = try assertParsesSyntaxTree(sourceFile) XCTAssertEqual(config.name, "FooTest") + XCTAssertEqual(config.assemblyType, "Assembly") } func testAssemblyStructModuleName() throws {