diff --git a/src/mqtt_sessions.erl b/src/mqtt_sessions.erl
index b425005..2016b52 100644
--- a/src/mqtt_sessions.erl
+++ b/src/mqtt_sessions.erl
@@ -76,7 +76,7 @@
 
 -type session_ref() :: pid() | binary().
 -type msg_options() :: #{
-        transport => pid() | function(),
+        transport => transport(),
         peer_ip => tuple() | undefined,
         context_prefs => map(),
         connection_pid => pid()
@@ -94,6 +94,8 @@
 
 -type callback() :: pid() | {module(), atom(), list()}.
 
+-type transport() :: function() | callback().
+
 -export_type([
     session_ref/0,
     msg_options/0,
@@ -102,7 +104,8 @@
     subscriber/0,
     subscriber_options/0,
     topic/0,
-    callback/0
+    callback/0,
+    transport/0
 ]).
 
 -define(SIDEJOBS_PER_SESSION, 20).
@@ -188,7 +191,7 @@ update_user_context(Pool, ClientId, Fun) ->
         {error, _} = Error -> Error
     end.
 
--spec get_transport( pid() ) -> {ok, pid()} | {error, notransport | noproc}.
+-spec get_transport( pid() ) -> {ok, transport()} | {error, notransport | noproc}.
 get_transport(SessionPid) ->
     mqtt_sessions_process:get_transport(SessionPid).
 
diff --git a/src/mqtt_sessions_process.erl b/src/mqtt_sessions_process.erl
index 9a20718..2a3effa 100644
--- a/src/mqtt_sessions_process.erl
+++ b/src/mqtt_sessions_process.erl
@@ -68,6 +68,7 @@
 
 -type packet_id() :: 0..65535.          % ?MAX_PACKET_ID
 
+
 -record(state, {
     protocol_version :: mqtt_packet_map:mqtt_version(),
     pool :: atom(),
@@ -75,7 +76,7 @@
     client_id :: binary(),
     routing_id :: binary(),
     user_context :: term(),
-    transport = undefined :: pid() | function() | undefined,
+    transport = undefined :: mqtt_sessions:transport() | undefined,
     connection_pid = undefined :: pid() | undefined,
     is_session_present = false :: boolean(),
     pending_connack = undefined :: term(),
@@ -145,7 +146,7 @@ update_user_context(Pid, Fun) ->
             {error, noproc}
     end.
 
--spec get_transport( pid() ) -> {ok, pid()} | {error, notransport | noproc}.
+-spec get_transport( pid() ) -> {ok, mqtt_sessions:transport()} | {error, notransport | noproc}.
 get_transport(Pid) ->
     try
         gen_server:call(Pid, get_transport, infinity)
@@ -230,8 +231,8 @@ handle_call({update_user_context, Fun}, _From, #state{ user_context = UserContex
 
 handle_call(get_transport, _From, #state{ transport = undefined } = State) ->
     {reply, {error, notransport}, State};
-handle_call(get_transport, _From, #state{ transport = TransportPid } = State) ->
-    {reply, {ok, TransportPid}, State};
+handle_call(get_transport, _From, #state{ transport = Transport } = State) ->
+    {reply, {ok, Transport}, State};
 
 handle_call({incoming_data, NewData, ConnectionPid}, _From, #state{ incoming_data = Data, connection_pid = ConnectionPid } = State) ->
     Data1 = << Data/binary, NewData/binary >>,
@@ -1030,7 +1031,9 @@ send_transport(Msg, #state{ transport = Pid }) when is_pid(Pid) ->
             ok
     end;
 send_transport(Msg, #state{ transport = Fun }) when is_function(Fun) ->
-    Fun(Msg).
+    Fun(Msg);
+send_transport(Msg, #state{ transport = {M, F, A} }) ->
+    erlang:apply(M, F, [Msg | A]).
 
 
 %% @doc Queue a message, extract, type, message expiry, and QoS