Skip to content

Commit

Permalink
MFA support (#169)
Browse files Browse the repository at this point in the history
* add MFA data types
* MFA client API
* implement MFA enrollment, verification, and auth
* provide MFA interaction sample
* add retrieval/generation of recovery codes
  • Loading branch information
ekoby authored Apr 20, 2021
1 parent 28cf8fd commit acb6986
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 27 deletions.
104 changes: 102 additions & 2 deletions samples/ziti-enroller/src/main/kotlin/org/openziti/ZitiEnroller.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020 NetFoundry, Inc.
* Copyright (c) 2018-2021 NetFoundry, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,18 +17,31 @@
package org.openziti

import com.github.ajalt.clikt.core.CliktCommand
import com.github.ajalt.clikt.core.subcommands
import com.github.ajalt.clikt.parameters.options.flag
import com.github.ajalt.clikt.parameters.options.option
import com.github.ajalt.clikt.parameters.options.required
import com.github.ajalt.clikt.parameters.options.validate
import com.github.ajalt.clikt.parameters.types.file
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.collectLatest
import kotlinx.coroutines.flow.takeWhile
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.openziti.api.MFAType
import org.openziti.identity.Enroller
import java.io.File
import java.io.FileNotFoundException
import java.net.InetAddress
import java.security.KeyStore
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage

object ZitiEnroller {

private class Cli : CliktCommand(name = "ziti-enroller") {
private class enroll: CliktCommand() {
val jwt by option(help = "Enrollment token (JWT file). Required").file().required().validate {
it.exists() || throw FileNotFoundException("jwt[${it.path}] not found")
}
Expand All @@ -48,6 +61,93 @@ object ZitiEnroller {
}
}

private class verify: CliktCommand(help = "verify identity file. OTP will be requested if identity is enrolled in MFA") {
val idFile by option(help = "identity configuration file.").file().required().validate {
it.exists()
}

val showCodes by option(help = "display recovery codes").flag(default = false)
val newCodes by option(help = "generate new recovery codes").flag(default = false)

override fun run() {
val ztx = Ziti.newContext(idFile, charArrayOf(), object : Ziti.AuthHandler{
override fun getCode(ztx: ZitiContext, mfaType: MFAType, provider: String) =
CompletableFuture.supplyAsync {
print("Enter MFA code for $mfaType/$provider[${ztx.getId()?.name}]: ")
val code = readLine()
code!!
}
})

val j = GlobalScope.launch {
ztx.statusUpdates().collect {
println("status: $it")
when(it) {
ZitiContext.Status.Loading, ZitiContext.Status.Authenticating -> {}
ZitiContext.Status.Active -> {
println("verification success!")
cancel()
}
is ZitiContext.Status.NotAuthorized -> {
cancel("verification failed!", it.ex)
}
else -> cancel("unexpected status")
}
}
}

runBlocking {
j.join()

if (showCodes || newCodes) {
print("""enter OTP to ${if (newCodes) "generate" else "show"} recovery codes: """)
val code = readLine()
val recCodes = ztx.getMFARecoveryCodes(code!!, newCodes)
for (rc in recCodes) {
println(rc)
}
}

ztx.destroy()
}

}
}

private class mfa: CliktCommand(help = "Enroll identity in MFA") {
val idFile by option(help = "identity configuration file.").file().required().validate {
it.exists()
}

override fun run() {
val ztx = Ziti.newContext(idFile, charArrayOf())

val j = GlobalScope.launch {
ztx.statusUpdates().takeWhile { it != ZitiContext.Status.Active }.collectLatest { println(it) }
val mfa = ztx.enrollMFA()
println(mfa)

print("Enter OTP code: ")
val code = readLine()
ztx.verifyMFA(code!!.trim())
}
runBlocking { j.join() }

ztx.destroy()
}

}

private class Cli : CliktCommand(name = "ziti-enroller") {

init {
subcommands(enroll(), verify(), mfa())
}

override fun run() {
}
}

@JvmStatic
fun main(args: Array<String>) = Cli().main(args)
}
5 changes: 3 additions & 2 deletions ziti/src/main/kotlin/org/openziti/Exceptions.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020 NetFoundry, Inc.
* Copyright (c) 2018-2021 NetFoundry, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,7 +35,8 @@ private val errorMap = mapOf(
"REQUIRES_CERT_AUTH" to Errors.NotAuthorized,
"UNAUTHORIZED" to Errors.NotAuthorized,
"INVALID_AUTH" to Errors.NotAuthorized,
"INVALID_POSTURE" to Errors.InsufficientSecurity
"INVALID_POSTURE" to Errors.InsufficientSecurity,
"MFA_INVALID_TOKEN" to Errors.InsufficientSecurity
)

fun getZitiError(err: String): Errors = errorMap.getOrElse(err) { Errors.WTF(err) }
Expand Down
19 changes: 15 additions & 4 deletions ziti/src/main/kotlin/org/openziti/Ziti.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.openziti

import kotlinx.coroutines.flow.Flow
import org.openziti.api.MFAType
import org.openziti.api.Service
import org.openziti.impl.ZitiImpl
import org.openziti.net.ZitiSocketFactory
Expand All @@ -26,6 +27,7 @@ import org.openziti.net.nio.AsyncTLSSocketFactory
import java.io.File
import java.net.SocketAddress
import java.security.KeyStore
import java.util.concurrent.CompletionStage
import javax.net.SocketFactory
import javax.net.ssl.SSLSocketFactory

Expand All @@ -34,6 +36,11 @@ import javax.net.ssl.SSLSocketFactory
*/
object Ziti {

@FunctionalInterface
interface AuthHandler {
fun getCode(ztx: ZitiContext, mfaType: MFAType, provider: String): CompletionStage<String>
}

/**
* Load Ziti identity from the file.
* The following formats of ziti identity files are supported:
Expand All @@ -44,7 +51,8 @@ object Ziti {
* @param pwd password to access the file (only needed for .jks or .pfx/.p12 if they are protected by password)
*/
@JvmStatic
fun newContext(idFile: File, pwd: CharArray): ZitiContext = ZitiImpl.loadContext(idFile, pwd, null)
@JvmOverloads
fun newContext(idFile: File, pwd: CharArray, auth: AuthHandler? = null): ZitiContext = ZitiImpl.loadContext(idFile, pwd, null, auth)

/**
* Load Ziti identity from the file.
Expand All @@ -53,16 +61,19 @@ object Ziti {
* @param pwd password to access the file (only needed for .jks or .pfx/.p12 if they are protected by password)
*/
@JvmStatic
fun newContext(fname: String, pwd: CharArray): ZitiContext = newContext(File(fname), pwd)
@JvmOverloads
fun newContext(fname: String, pwd: CharArray, auth: AuthHandler? = null): ZitiContext = newContext(File(fname), pwd, auth)

@JvmStatic
fun removeContext(ctx: ZitiContext) = ZitiImpl.removeContext(ctx)

@JvmStatic
fun init(fname: String, pwd: CharArray, seamless: Boolean) = ZitiImpl.init(File(fname), pwd, seamless)
@JvmOverloads
fun init(fname: String, pwd: CharArray, seamless: Boolean, auth: AuthHandler? = null) = ZitiImpl.init(File(fname), pwd, seamless, auth)

@JvmStatic
fun init(ks: KeyStore, seamless: Boolean) = ZitiImpl.init(ks, seamless)
@JvmOverloads
fun init(ks: KeyStore, seamless: Boolean, auth: AuthHandler? = null) = ZitiImpl.init(ks, seamless, auth)

@JvmStatic
fun enroll(ks: KeyStore, jwt: ByteArray, name: String): ZitiContext = ZitiImpl.enroll(ks, jwt, name)
Expand Down
21 changes: 21 additions & 0 deletions ziti/src/main/kotlin/org/openziti/ZitiContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@

package org.openziti

import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.async
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.future.asCompletableFuture
import org.openziti.api.MFAEnrollment
import org.openziti.api.Service
import org.openziti.api.ServiceTerminator
import org.openziti.identity.Identity
import java.net.InetSocketAddress
import java.net.Socket
import java.nio.channels.AsynchronousServerSocketChannel
import java.nio.channels.AsynchronousSocketChannel
import java.util.concurrent.Future
import org.openziti.api.Identity as ApiIdentity

/**
Expand All @@ -44,6 +49,7 @@ interface ZitiContext: Identity {

sealed class Status {
object Loading: Status()
object Authenticating: Status()
object Active: Status()
object Disabled: Status()
class NotAuthorized(val ex: Throwable): Status()
Expand Down Expand Up @@ -112,4 +118,19 @@ interface ZitiContext: Identity {

fun destroy()

suspend fun enrollMFA(): MFAEnrollment
fun enrollMFAAsync() = GlobalScope.async {
enrollMFA()
}.asCompletableFuture()

suspend fun verifyMFA(code: String)
fun verifyMFAAsync(code: String) = GlobalScope.async { verifyMFA(code) }.asCompletableFuture()

suspend fun removeMFA(code: String)
fun removeMFAAsync(code: String) =
GlobalScope.async { removeMFA(code) }.asCompletableFuture()

suspend fun getMFARecoveryCodes(code: String, newCodes: Boolean): Array<String>
fun getMFARecoveryCodesAsync(code: String, newCodes: Boolean) =
GlobalScope.async { getMFARecoveryCodes(code, newCodes) }.asCompletableFuture()
}
58 changes: 55 additions & 3 deletions ziti/src/main/kotlin/org/openziti/api/Controller.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import okhttp3.OkHttpClient
import okhttp3.ResponseBody
import okhttp3.logging.HttpLoggingInterceptor
import org.openziti.Errors
import org.openziti.ZitiContext
import org.openziti.ZitiException
import org.openziti.getZitiError
import org.openziti.impl.ZitiImpl
Expand All @@ -46,7 +45,6 @@ import retrofit2.converter.gson.GsonConverterFactory
import retrofit2.http.*
import java.io.IOException
import java.net.URL
import java.time.Instant
import java.util.*
import javax.net.ssl.SSLContext
import javax.net.ssl.X509TrustManager
Expand All @@ -71,6 +69,27 @@ internal class Controller(endpoint: URL, sslContext: SSLContext, trustManager: X
@DELETE("current-api-session")
fun logout(): Deferred<Unit>

@GET("/current-identity/mfa")
fun getMFA(): Deferred<Response<MFAEnrollment>>

@POST("/current-identity/mfa")
fun postMFA(): Deferred<Response<Unit>>

@DELETE("/current-identity/mfa")
fun removeMFA(@Header("mfa-validation-code") code: String): Deferred<Response<Unit>>

@POST("/authenticate/mfa")
fun authMFA(@Body code: MFACode): Deferred<Response<Unit>>

@POST("/current-identity/mfa/verify")
fun verifyMFA(@Body code: MFACode): Deferred<Response<Unit>>

@GET("/current-identity/mfa/recovery-codes")
fun getMFACodes(@Header("mfa-validation-code") code: String): Deferred<Response<MFARecoveryCodes>>

@POST("/current-identity/mfa/recovery-codes")
fun newMFACodes(@Body code: MFACode): Deferred<Response<Unit>>

@GET("/current-api-session/service-updates")
fun getServiceUpdates(): Deferred<Response<ServiceUpdates>>

Expand Down Expand Up @@ -216,6 +235,40 @@ internal class Controller(endpoint: URL, sslContext: SSLContext, trustManager: X
offset -> api.getServiceTerminators(s.id, offset = offset)
}

internal suspend fun postMFA() = api.postMFA().await().data

internal suspend fun getMFAEnrollment(): MFAEnrollment? =
runCatching { api.getMFA().await().data }
.onFailure {
if (it !is HttpException) throw it
if (it.code() != 404) throw it
}.getOrNull()

internal suspend fun verifyMFA(code: String) {
runCatching {
api.verifyMFA(MFACode(code)).await()
}.getOrElse { convertError(it) }
}

internal suspend fun authMFA(code: String) {
runCatching { api.authMFA(MFACode(code)).await() }
.getOrElse { convertError(it) }
}

internal suspend fun removeMFA(code: String) {
runCatching { api.removeMFA(code).await() }
.getOrElse { convertError(it) }
}

internal suspend fun getMFARecoveryCodes(code: String, newCodes: Boolean): Array<String> {
if (newCodes) {
api.newMFACodes(MFACode(code)).await()
}

val codes = api.getMFACodes(code)
return codes.await().data?.recoveryCodes ?: emptyArray()
}

private fun <T> pagingRequest(req: (offset: Int) -> Deferred<Response<Collection<T>>>) = flow {
var offset = 0

Expand All @@ -242,7 +295,6 @@ internal class Controller(endpoint: URL, sslContext: SSLContext, trustManager: X
}

private fun convertError(t: Throwable): Nothing {
e("error: ${t.localizedMessage}")
val errCode = when (t) {
is HttpException -> getZitiError(getError(t.response()))
is IOException -> Errors.ControllerUnavailable
Expand Down
Loading

0 comments on commit acb6986

Please sign in to comment.