diff --git a/src/mqtt_sessions.erl b/src/mqtt_sessions.erl index b425005..2016b52 100644 --- a/src/mqtt_sessions.erl +++ b/src/mqtt_sessions.erl @@ -76,7 +76,7 @@ -type session_ref() :: pid() | binary(). -type msg_options() :: #{ - transport => pid() | function(), + transport => transport(), peer_ip => tuple() | undefined, context_prefs => map(), connection_pid => pid() @@ -94,6 +94,8 @@ -type callback() :: pid() | {module(), atom(), list()}. +-type transport() :: function() | callback(). + -export_type([ session_ref/0, msg_options/0, @@ -102,7 +104,8 @@ subscriber/0, subscriber_options/0, topic/0, - callback/0 + callback/0, + transport/0 ]). -define(SIDEJOBS_PER_SESSION, 20). @@ -188,7 +191,7 @@ update_user_context(Pool, ClientId, Fun) -> {error, _} = Error -> Error end. --spec get_transport( pid() ) -> {ok, pid()} | {error, notransport | noproc}. +-spec get_transport( pid() ) -> {ok, transport()} | {error, notransport | noproc}. get_transport(SessionPid) -> mqtt_sessions_process:get_transport(SessionPid). diff --git a/src/mqtt_sessions_process.erl b/src/mqtt_sessions_process.erl index 9a20718..2a3effa 100644 --- a/src/mqtt_sessions_process.erl +++ b/src/mqtt_sessions_process.erl @@ -68,6 +68,7 @@ -type packet_id() :: 0..65535. % ?MAX_PACKET_ID + -record(state, { protocol_version :: mqtt_packet_map:mqtt_version(), pool :: atom(), @@ -75,7 +76,7 @@ client_id :: binary(), routing_id :: binary(), user_context :: term(), - transport = undefined :: pid() | function() | undefined, + transport = undefined :: mqtt_sessions:transport() | undefined, connection_pid = undefined :: pid() | undefined, is_session_present = false :: boolean(), pending_connack = undefined :: term(), @@ -145,7 +146,7 @@ update_user_context(Pid, Fun) -> {error, noproc} end. --spec get_transport( pid() ) -> {ok, pid()} | {error, notransport | noproc}. +-spec get_transport( pid() ) -> {ok, mqtt_sessions:transport()} | {error, notransport | noproc}. get_transport(Pid) -> try gen_server:call(Pid, get_transport, infinity) @@ -230,8 +231,8 @@ handle_call({update_user_context, Fun}, _From, #state{ user_context = UserContex handle_call(get_transport, _From, #state{ transport = undefined } = State) -> {reply, {error, notransport}, State}; -handle_call(get_transport, _From, #state{ transport = TransportPid } = State) -> - {reply, {ok, TransportPid}, State}; +handle_call(get_transport, _From, #state{ transport = Transport } = State) -> + {reply, {ok, Transport}, State}; handle_call({incoming_data, NewData, ConnectionPid}, _From, #state{ incoming_data = Data, connection_pid = ConnectionPid } = State) -> Data1 = << Data/binary, NewData/binary >>, @@ -1030,7 +1031,9 @@ send_transport(Msg, #state{ transport = Pid }) when is_pid(Pid) -> ok end; send_transport(Msg, #state{ transport = Fun }) when is_function(Fun) -> - Fun(Msg). + Fun(Msg); +send_transport(Msg, #state{ transport = {M, F, A} }) -> + erlang:apply(M, F, [Msg | A]). %% @doc Queue a message, extract, type, message expiry, and QoS