Skip to content

Commit

Permalink
custom json coder support (#189)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanner0101 authored Jul 28, 2020
1 parent 3a4b45b commit 290ff55
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 18 deletions.
2 changes: 2 additions & 0 deletions Sources/FluentMySQLDriver/Exports.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

@_exported import struct MySQLKit.MySQLConfiguration
@_exported import struct MySQLKit.MySQLConnectionSource
@_exported import struct MySQLKit.MySQLDataEncoder
@_exported import struct MySQLKit.MySQLDataDecoder

@_exported import class MySQLNIO.MySQLConnection
@_exported import enum MySQLNIO.MySQLError
Expand Down
48 changes: 37 additions & 11 deletions Sources/FluentMySQLDriver/FluentMySQLConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ extension DatabaseConfigurationFactory {
password: String,
database: String? = nil,
maxConnectionsPerEventLoop: Int = 1,
connectionPoolTimeout: NIO.TimeAmount = .seconds(10)
connectionPoolTimeout: NIO.TimeAmount = .seconds(10),
encoder: MySQLDataEncoder = .init(),
decoder: MySQLDataDecoder = .init()
) throws -> Self {
let configuration = MySQLConfiguration(
unixDomainSocketPath: unixDomainSocketPath,
Expand All @@ -19,36 +21,46 @@ extension DatabaseConfigurationFactory {
return .mysql(
configuration: configuration,
maxConnectionsPerEventLoop: maxConnectionsPerEventLoop,
connectionPoolTimeout: connectionPoolTimeout
connectionPoolTimeout: connectionPoolTimeout,
encoder: encoder,
decoder: decoder
)
}
public static func mysql(
url urlString: String,
maxConnectionsPerEventLoop: Int = 1,
connectionPoolTimeout: NIO.TimeAmount = .seconds(10)
connectionPoolTimeout: NIO.TimeAmount = .seconds(10),
encoder: MySQLDataEncoder = .init(),
decoder: MySQLDataDecoder = .init()
) throws -> Self {
guard let url = URL(string: urlString) else {
throw FluentMySQLError.invalidURL(urlString)
}
return try self.mysql(
url: url,
url: url,
maxConnectionsPerEventLoop: maxConnectionsPerEventLoop,
connectionPoolTimeout: connectionPoolTimeout
connectionPoolTimeout: connectionPoolTimeout,
encoder: encoder,
decoder: decoder
)
}

public static func mysql(
url: URL,
maxConnectionsPerEventLoop: Int = 1,
connectionPoolTimeout: NIO.TimeAmount = .seconds(10)
connectionPoolTimeout: NIO.TimeAmount = .seconds(10),
encoder: MySQLDataEncoder = .init(),
decoder: MySQLDataDecoder = .init()
) throws -> Self {
guard let configuration = MySQLConfiguration(url: url) else {
throw FluentMySQLError.invalidURL(url.absoluteString)
}
return .mysql(
configuration: configuration,
maxConnectionsPerEventLoop: maxConnectionsPerEventLoop,
connectionPoolTimeout: connectionPoolTimeout
connectionPoolTimeout: connectionPoolTimeout,
encoder: encoder,
decoder: decoder
)
}

Expand All @@ -60,7 +72,9 @@ extension DatabaseConfigurationFactory {
database: String? = nil,
tlsConfiguration: TLSConfiguration? = .forClient(),
maxConnectionsPerEventLoop: Int = 1,
connectionPoolTimeout: NIO.TimeAmount = .seconds(10)
connectionPoolTimeout: NIO.TimeAmount = .seconds(10),
encoder: MySQLDataEncoder = .init(),
decoder: MySQLDataDecoder = .init()
) -> Self {
return .mysql(
configuration: .init(
Expand All @@ -72,20 +86,26 @@ extension DatabaseConfigurationFactory {
tlsConfiguration: tlsConfiguration
),
maxConnectionsPerEventLoop: maxConnectionsPerEventLoop,
connectionPoolTimeout: connectionPoolTimeout
connectionPoolTimeout: connectionPoolTimeout,
encoder: encoder,
decoder: decoder
)
}

public static func mysql(
configuration: MySQLConfiguration,
maxConnectionsPerEventLoop: Int = 1,
connectionPoolTimeout: NIO.TimeAmount = .seconds(10)
connectionPoolTimeout: NIO.TimeAmount = .seconds(10),
encoder: MySQLDataEncoder = .init(),
decoder: MySQLDataDecoder = .init()
) -> Self {
return Self {
FluentMySQLConfiguration(
configuration: configuration,
maxConnectionsPerEventLoop: maxConnectionsPerEventLoop,
connectionPoolTimeout: connectionPoolTimeout,
encoder: encoder,
decoder: decoder,
middleware: []
)
}
Expand All @@ -96,6 +116,8 @@ struct FluentMySQLConfiguration: DatabaseConfiguration {
let configuration: MySQLConfiguration
let maxConnectionsPerEventLoop: Int
let connectionPoolTimeout: TimeAmount
let encoder: MySQLDataEncoder
let decoder: MySQLDataDecoder
var middleware: [AnyModelMiddleware]

func makeDriver(for databases: Databases) -> DatabaseDriver {
Expand All @@ -108,6 +130,10 @@ struct FluentMySQLConfiguration: DatabaseConfiguration {
requestTimeout: self.connectionPoolTimeout,
on: databases.eventLoopGroup
)
return _FluentMySQLDriver(pool: pool)
return _FluentMySQLDriver(
pool: pool,
encoder: self.encoder,
decoder: self.decoder
)
}
}
22 changes: 18 additions & 4 deletions Sources/FluentMySQLDriver/FluentMySQLDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import AsyncKit

struct _FluentMySQLDatabase {
let database: MySQLDatabase
let encoder: MySQLDataEncoder
let decoder: MySQLDataDecoder
let context: DatabaseContext
let inTransaction: Bool
}
Expand All @@ -18,9 +20,9 @@ extension _FluentMySQLDatabase: Database {
let (sql, binds) = self.serialize(expression)
do {
return try self.query(
sql, binds.map { try MySQLDataEncoder().encode($0) },
sql, binds.map { try self.encoder.encode($0) },
onRow: { row in
onOutput(row.databaseOutput())
onOutput(row.databaseOutput(decoder: self.decoder))
},
onMetadata: { metadata in
switch query.action {
Expand Down Expand Up @@ -62,7 +64,13 @@ extension _FluentMySQLDatabase: Database {
}
return self.database.withConnection { conn in
conn.simpleQuery("START TRANSACTION").flatMap { _ in
let db = _FluentMySQLDatabase(database: conn, context: self.context, inTransaction: true)
let db = _FluentMySQLDatabase(
database: conn,
encoder: self.encoder,
decoder: self.decoder,
context: self.context,
inTransaction: true
)
return closure(db).flatMap { result in
conn.simpleQuery("COMMIT").map { _ in
result
Expand All @@ -78,7 +86,13 @@ extension _FluentMySQLDatabase: Database {

func withConnection<T>(_ closure: @escaping (Database) -> EventLoopFuture<T>) -> EventLoopFuture<T> {
self.database.withConnection {
closure(_FluentMySQLDatabase(database: $0, context: self.context, inTransaction: self.inTransaction))
closure(_FluentMySQLDatabase(
database: $0,
encoder: self.encoder,
decoder: self.decoder,
context: self.context,
inTransaction: self.inTransaction
))
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions Sources/FluentMySQLDriver/FluentMySQLDriver.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import AsyncKit

struct _FluentMySQLDriver: DatabaseDriver {
let pool: EventLoopGroupConnectionPool<MySQLConnectionSource>
let encoder: MySQLDataEncoder
let decoder: MySQLDataDecoder

var eventLoopGroup: EventLoopGroup {
self.pool.eventLoopGroup
Expand All @@ -10,6 +12,8 @@ struct _FluentMySQLDriver: DatabaseDriver {
func makeDatabase(with context: DatabaseContext) -> Database {
_FluentMySQLDatabase(
database: self.pool.pool(for: context.eventLoop).database(logger: context.logger),
encoder: self.encoder,
decoder: self.decoder,
context: context,
inTransaction: false
)
Expand Down
10 changes: 7 additions & 3 deletions Sources/FluentMySQLDriver/MySQLRow+Database.swift
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
extension MySQLRow {
internal func databaseOutput() -> DatabaseOutput {
_MySQLDatabaseOutput(row: self, schema: nil)
internal func databaseOutput(decoder: MySQLDataDecoder) -> DatabaseOutput {
_MySQLDatabaseOutput(row: self, decoder: decoder, schema: nil)
}
}

private struct _MySQLDatabaseOutput: DatabaseOutput {
let row: MySQLRow
let decoder: MySQLDataDecoder
let schema: String?

var description: String {
Expand All @@ -27,14 +28,17 @@ private struct _MySQLDatabaseOutput: DatabaseOutput {
func schema(_ schema: String) -> DatabaseOutput {
_MySQLDatabaseOutput(
row: self.row,
decoder: self.decoder,
schema: schema
)
}

func decode<T>(_ key: FieldKey, as type: T.Type) throws -> T
where T: Decodable
{
try self.row.decode(column: self.columnName(key), as: T.self)
try self.row
.sql(decoder: self.decoder)
.decode(column: self.columnName(key), as: T.self)
}


Expand Down

0 comments on commit 290ff55

Please sign in to comment.