From 5e3b0b52d5427dba404b9cbe0f4cdf22729ea973 Mon Sep 17 00:00:00 2001 From: Dmitriy Zharov Date: Tue, 30 Apr 2024 16:09:38 +0700 Subject: [PATCH] Return value after storing SecDataConvertible --- Sources/SwiftSecurity/Keychain/Keychain.swift | 60 ++++++++++++++++++- .../Keychain/SecItemStore/SecItemStore.swift | 52 ++++++++-------- .../AccessPolicyTests.swift | 12 ++-- 3 files changed, 91 insertions(+), 33 deletions(-) diff --git a/Sources/SwiftSecurity/Keychain/Keychain.swift b/Sources/SwiftSecurity/Keychain/Keychain.swift index 374baaf..1d44d33 100644 --- a/Sources/SwiftSecurity/Keychain/Keychain.swift +++ b/Sources/SwiftSecurity/Keychain/Keychain.swift @@ -184,6 +184,15 @@ extension Keychain: SecDataStore { public func store(_ data: T, query: SecItemQuery, accessPolicy: AccessPolicy = .default) throws { try store(.data(data.rawRepresentation), query: query, accessPolicy: accessPolicy) } + + public func store( + _ data: T, + returning returnType: SecReturnType, + query: SecItemQuery, + accessPolicy: AccessPolicy = .default + ) throws -> SecValue? { + try store(.data(data.rawRepresentation), returning: returnType, query: query, accessPolicy: accessPolicy) + } public func retrieve(_ query: SecItemQuery, authenticationContext: LAContext? = nil) throws -> T? { if let value = try retrieve(.data, query: query, authenticationContext: authenticationContext), case let .data(data) = value { @@ -334,12 +343,31 @@ extension Keychain: SecIdentityStore { // MARK: - Private private extension Keychain { - func store(_ value: SecValue, query: SecItemQuery, accessPolicy: AccessPolicy) throws { + @discardableResult + func store( + _ value: SecValue, + returning returnType: SecReturnType = [], + query: SecItemQuery, + accessPolicy: AccessPolicy = .default + ) throws -> SecValue? { var query = query query[.accessGroup] = accessGroup.rawValue query[.accessControl] = try accessPolicy.accessControl query[.accessible] = accessPolicy.accessibility + if returnType.contains(.data) { + query[kSecReturnData as String] = true + } + if returnType.contains(.info) { + query[kSecReturnAttributes as String] = true + } + if returnType.contains(.reference) { + query[kSecReturnRef as String] = true + } + if returnType.contains(.persistentReference) { + query[kSecReturnPersistentRef as String] = true + } + switch value { case .data(let data): query[kSecValueData as String] = data @@ -349,9 +377,35 @@ private extension Keychain { throw SwiftSecurityError.invalidParameter } - switch SecItemAdd(query.rawValue as CFDictionary, nil) { + var result: AnyObject? + switch SecItemAdd(query.rawValue as CFDictionary, &result) { case errSecSuccess: - return + switch returnType { + case .data: + if let data = result as? Data { + return .data(data) + } else { + return nil + } + case .reference: + if let result { + return .reference(result) + } else { + return nil + } + case .persistentReference: + if let data = result as? Data { + return .persistentReference(data) + } else { + return nil + } + default: + if let attributes = result as? [String: Any] { + return .dictionary(SecItemInfo(rawValue: attributes)) + } else { + return nil + } + } case let status: throw SwiftSecurityError(rawValue: status) } diff --git a/Sources/SwiftSecurity/Keychain/SecItemStore/SecItemStore.swift b/Sources/SwiftSecurity/Keychain/SecItemStore/SecItemStore.swift index afd49b9..373df23 100644 --- a/Sources/SwiftSecurity/Keychain/SecItemStore/SecItemStore.swift +++ b/Sources/SwiftSecurity/Keychain/SecItemStore/SecItemStore.swift @@ -16,22 +16,46 @@ public protocol SecItemStore { func removeAll() throws } -// MARK: - SecData +public extension SecItemStore { + func info(for query: SecItemQuery, authenticationContext: LAContext? = nil) throws -> SecItemInfo? { + if let value = try retrieve(.info, query: query, authenticationContext: authenticationContext), case let .dictionary(info) = value { + return info + } else { + return nil + } + } +} + +// MARK: - Data public protocol SecDataStore: SecItemStore { - // MARK: - Generic + // MARK: - Generic Password - func store(_ data: T, query: SecItemQuery, accessPolicy: AccessPolicy) throws + func store(_ data: T, returning returnType: SecReturnType, query: SecItemQuery, accessPolicy: AccessPolicy) throws -> SecValue? func retrieve(_ query: SecItemQuery, authenticationContext: LAContext?) throws -> T? func remove(_ query: SecItemQuery) throws -> Bool - // MARK: - Internet + // MARK: - Internet Password func store(_ data: T, query: SecItemQuery, accessPolicy: AccessPolicy) throws func retrieve(_ query: SecItemQuery, authenticationContext: LAContext?) throws -> T? func remove(_ query: SecItemQuery) throws -> Bool } +public extension SecDataStore { + func store(_ data: T, query: SecItemQuery, accessPolicy: AccessPolicy) throws { + try self.store(data, returning: [], query: query, accessPolicy: accessPolicy) + } + + func retrieve(_ query: SecItemQuery) throws -> Data? { + try self.retrieve(query, authenticationContext: nil) + } + + func retrieve(_ query: SecItemQuery) throws -> Data? { + try self.retrieve(query, authenticationContext: nil) + } +} + // MARK: - SecKey public protocol SecKeyStore: SecItemStore { @@ -56,23 +80,3 @@ public protocol SecIdentityStore: SecItemStore { func retrieve(_ query: SecItemQuery, authenticationContext: LAContext?) throws -> SecIdentity? func remove(_ query: SecItemQuery) throws -> Bool } - -// MARK: - Convenient - -public extension SecDataStore { - func info(for query: SecItemQuery, authenticationContext: LAContext? = nil) throws -> SecItemInfo? { - if let value = try retrieve(.info, query: query, authenticationContext: authenticationContext), case let .dictionary(info) = value { - return info - } else { - return nil - } - } - - func retrieve(_ query: SecItemQuery) throws -> Data? { - try self.retrieve(query, authenticationContext: nil) - } - - func retrieve(_ query: SecItemQuery) throws -> Data? { - try self.retrieve(query, authenticationContext: nil) - } -} diff --git a/Tests/SwiftSecurityTests/AccessPolicyTests.swift b/Tests/SwiftSecurityTests/AccessPolicyTests.swift index e5d1ebb..a9d61bd 100644 --- a/Tests/SwiftSecurityTests/AccessPolicyTests.swift +++ b/Tests/SwiftSecurityTests/AccessPolicyTests.swift @@ -13,27 +13,27 @@ import Security final class AccessPolicyTests: XCTestCase { func testAccessibility() throws { do { - let accessPolicy = SecAccessPolicy(.whenPasscodeSetThisDeviceOnly) + let accessPolicy = AccessPolicy(.whenPasscodeSetThisDeviceOnly) XCTAssertEqual(accessPolicy.accessibility, String(kSecAttrAccessibleWhenPasscodeSetThisDeviceOnly)) XCTAssertNil(try accessPolicy.accessControl) } do { - let accessPolicy = SecAccessPolicy(.whenUnlocked) + let accessPolicy = AccessPolicy(.whenUnlocked) XCTAssertEqual(accessPolicy.accessibility, String(kSecAttrAccessibleWhenUnlocked)) XCTAssertNil(try accessPolicy.accessControl) } do { - let accessPolicy = SecAccessPolicy(.whenUnlockedThisDeviceOnly) + let accessPolicy = AccessPolicy(.whenUnlockedThisDeviceOnly) XCTAssertEqual(accessPolicy.accessibility, String(kSecAttrAccessibleWhenUnlockedThisDeviceOnly)) XCTAssertNil(try accessPolicy.accessControl) } do { - let accessPolicy = SecAccessPolicy(.afterFirstUnlock) + let accessPolicy = AccessPolicy(.afterFirstUnlock) XCTAssertEqual(accessPolicy.accessibility, String(kSecAttrAccessibleAfterFirstUnlock)) XCTAssertNil(try accessPolicy.accessControl) } do { - let accessPolicy = SecAccessPolicy(.afterFirstUnlockThisDeviceOnly) + let accessPolicy = AccessPolicy(.afterFirstUnlockThisDeviceOnly) XCTAssertEqual(accessPolicy.accessibility, String(kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly)) XCTAssertNil(try accessPolicy.accessControl) } @@ -41,7 +41,7 @@ final class AccessPolicyTests: XCTestCase { func testAccessControl() { do { - let accessPolicy = SecAccessPolicy(.afterFirstUnlock, options: .biometryAny) + let accessPolicy = AccessPolicy(.afterFirstUnlock, options: .biometryAny) XCTAssertNotNil(try accessPolicy.accessControl) } }