Skip to content

Commit

Permalink
Fix chicken/egg problem with ServerRoom and ServerRoomImpl
Browse files Browse the repository at this point in the history
To make the impl you need a room. To make a room you need an impl.
Instead, pass in the room to the impl so you do not need a room to
make an impl. This makes the way you call the impl a bit less elegant
as you refer to the room twice e.g `room.EventCreator(room, ...)` but
allows custom impls to use functional options.
  • Loading branch information
kegsay committed Mar 4, 2025
1 parent e0e4f42 commit 3569375
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 48 deletions.
6 changes: 3 additions & 3 deletions federation/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func MakeJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request
// or dealing with HTTP responses itself.
func MakeRespMakeJoin(s *Server, room *ServerRoom, userID string) (resp fclient.RespMakeJoin, err error) {
// Generate a join event
proto, err := room.ProtoEventCreator(Event{
proto, err := room.ProtoEventCreator(room, Event{
Type: "m.room.member",
StateKey: &userID,
Content: map[string]interface{}{
Expand All @@ -84,7 +84,7 @@ func MakeRespMakeJoin(s *Server, room *ServerRoom, userID string) (resp fclient.
// or dealing with HTTP responses itself.
func MakeRespMakeKnock(s *Server, room *ServerRoom, userID string) (resp fclient.RespMakeKnock, err error) {
// Generate a knock event
proto, err := room.ProtoEventCreator(Event{
proto, err := room.ProtoEventCreator(room, Event{
Type: "m.room.member",
StateKey: &userID,
Content: map[string]interface{}{
Expand Down Expand Up @@ -159,7 +159,7 @@ func SendJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request
return
}

resp := room.GenerateSendJoinResponse(s, event, expectPartialState, omitServersInRoom)
resp := room.GenerateSendJoinResponse(room, s, event, expectPartialState, omitServersInRoom)
b, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(500)
Expand Down
6 changes: 3 additions & 3 deletions federation/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,11 @@ func (s *Server) DoFederationRequest(
// It does not insert this event into the room however. See ServerRoom.AddEvent for that.
func (s *Server) MustCreateEvent(t ct.TestLike, room *ServerRoom, ev Event) gomatrixserverlib.PDU {
t.Helper()
proto, err := room.ProtoEventCreator(ev)
proto, err := room.ProtoEventCreator(room, ev)
if err != nil {
ct.Fatalf(t, "MustCreateEvent: failed to create proto event: %v", err)
}
pdu, err := room.EventCreator(s, proto)
pdu, err := room.EventCreator(room, s, proto)
if err != nil {
ct.Fatalf(t, "MustCreateEvent: failed to create PDU: %v", err)
}
Expand Down Expand Up @@ -392,7 +392,7 @@ func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, re
for _, opt := range jr.roomOpts {
opt(room)
}
room.PopulateFromSendJoinResponse(joinEvent, sendJoinResp)
room.PopulateFromSendJoinResponse(room, joinEvent, sendJoinResp)
s.rooms[roomID] = room

t.Logf("Server.MustJoinRoom joined room ID %s", roomID)
Expand Down
82 changes: 40 additions & 42 deletions federation/server_room.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func NewServerRoom(roomVer gomatrixserverlib.RoomVersion, roomId string) *Server
waiters: make(map[string][]*helpers.Waiter),
waitersMu: &sync.Mutex{},
}
room.ServerRoomImpl = &ServerRoomImplDefault{Room: room}
room.ServerRoomImpl = &ServerRoomImplDefault{}
return room
}

Expand Down Expand Up @@ -372,58 +372,56 @@ type ServerRoomImpl interface {
// ProtoEventCreator converts a Complement Event into a gomatrixserverlib proto event, ready to be signed.
// This function is used in /make_x endpoints to create proto events to return to other servers.
// This function is one of two used when creating events, the other being EventCreator.
ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error)
ProtoEventCreator(room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error)
// EventCreator converts a proto event into a signed PDU.
EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
EventCreator(room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
// PopulateFromSendJoinResponse should replace the state of this ServerRoom with the information contained
// in RespSendJoin and the join event.
PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
PopulateFromSendJoinResponse(room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
// GenerateSendJoinResponse generates a /send_join response to send back to a server.
GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
GenerateSendJoinResponse(room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
}

type ServerRoomImplCustom struct {
ServerRoomImplDefault
ProtoEventCreatorFn func(def ServerRoomImpl, ev Event) (*gomatrixserverlib.ProtoEvent, error)
EventCreatorFn func(def ServerRoomImpl, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
PopulateFromSendJoinResponseFn func(def ServerRoomImpl, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
GenerateSendJoinResponseFn func(def ServerRoomImpl, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
ProtoEventCreatorFn func(def ServerRoomImpl, room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error)
EventCreatorFn func(def ServerRoomImpl, room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
PopulateFromSendJoinResponseFn func(def ServerRoomImpl, room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
GenerateSendJoinResponseFn func(def ServerRoomImpl, room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
}

func (i *ServerRoomImplCustom) ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error) {
func (i *ServerRoomImplCustom) ProtoEventCreator(room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error) {
if i.ProtoEventCreatorFn != nil {
return i.ProtoEventCreatorFn(&i.ServerRoomImplDefault, ev)
return i.ProtoEventCreatorFn(&i.ServerRoomImplDefault, room, ev)
}
return i.ServerRoomImplDefault.ProtoEventCreator(ev)
return i.ServerRoomImplDefault.ProtoEventCreator(room, ev)
}

func (i *ServerRoomImplCustom) EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
func (i *ServerRoomImplCustom) EventCreator(room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
if i.EventCreatorFn != nil {
return i.EventCreatorFn(&i.ServerRoomImplDefault, s, proto)
return i.EventCreatorFn(&i.ServerRoomImplDefault, room, s, proto)
}
return i.ServerRoomImplDefault.EventCreator(s, proto)
return i.ServerRoomImplDefault.EventCreator(room, s, proto)
}

func (i *ServerRoomImplCustom) PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
func (i *ServerRoomImplCustom) PopulateFromSendJoinResponse(room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
if i.PopulateFromSendJoinResponseFn != nil {
i.PopulateFromSendJoinResponseFn(&i.ServerRoomImplDefault, joinEvent, resp)
i.PopulateFromSendJoinResponseFn(&i.ServerRoomImplDefault, room, joinEvent, resp)
return
}
i.ServerRoomImplDefault.PopulateFromSendJoinResponse(joinEvent, resp)
i.ServerRoomImplDefault.PopulateFromSendJoinResponse(room, joinEvent, resp)
}

func (i *ServerRoomImplCustom) GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
func (i *ServerRoomImplCustom) GenerateSendJoinResponse(room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
if i.GenerateSendJoinResponseFn != nil {
return i.GenerateSendJoinResponseFn(&i.ServerRoomImplDefault, s, joinEvent, expectPartialState, omitServersInRoom)
return i.GenerateSendJoinResponseFn(&i.ServerRoomImplDefault, room, s, joinEvent, expectPartialState, omitServersInRoom)
}
return i.ServerRoomImplDefault.GenerateSendJoinResponse(s, joinEvent, expectPartialState, omitServersInRoom)
return i.ServerRoomImplDefault.GenerateSendJoinResponse(room, s, joinEvent, expectPartialState, omitServersInRoom)
}

type ServerRoomImplDefault struct {
Room *ServerRoom
}
type ServerRoomImplDefault struct{}

func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error) {
func (i *ServerRoomImplDefault) ProtoEventCreator(room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error) {
var prevEvents interface{}
if ev.PrevEvents != nil {
// We deliberately want to set the prev events.
Expand All @@ -432,14 +430,14 @@ func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.
// No other prev events were supplied so we'll just
// use the forward extremities of the room, which is
// the usual behaviour.
prevEvents = i.Room.ForwardExtremities
prevEvents = room.ForwardExtremities
}
proto := gomatrixserverlib.ProtoEvent{
SenderID: ev.Sender,
Depth: int64(i.Room.Depth + 1), // depth starts at 1
Depth: int64(room.Depth + 1), // depth starts at 1
Type: ev.Type,
StateKey: ev.StateKey,
RoomID: i.Room.RoomID,
RoomID: room.RoomID,
PrevEvents: prevEvents,
AuthEvents: ev.AuthEvents,
Redacts: ev.Redacts,
Expand All @@ -456,13 +454,13 @@ func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.
if err != nil {
return nil, fmt.Errorf("EventCreator: failed to work out auth_events : %s", err)
}
proto.AuthEvents = i.Room.AuthEvents(stateNeeded)
proto.AuthEvents = room.AuthEvents(stateNeeded)
}
return &proto, nil
}

func (i *ServerRoomImplDefault) EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
verImpl, err := gomatrixserverlib.GetRoomVersion(i.Room.Version)
func (i *ServerRoomImplDefault) EventCreator(room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
verImpl, err := gomatrixserverlib.GetRoomVersion(room.Version)
if err != nil {
return nil, fmt.Errorf("EventCreator: invalid room version: %s", err)
}
Expand All @@ -474,19 +472,19 @@ func (i *ServerRoomImplDefault) EventCreator(s *Server, proto *gomatrixserverlib
return signedEvent, nil
}

func (i *ServerRoomImplDefault) PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
stateEvents := resp.StateEvents.UntrustedEvents(i.Room.Version)
func (i *ServerRoomImplDefault) PopulateFromSendJoinResponse(room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
stateEvents := resp.StateEvents.UntrustedEvents(room.Version)
for _, ev := range stateEvents {
i.Room.ReplaceCurrentState(ev)
room.ReplaceCurrentState(ev)
}
i.Room.AddEvent(joinEvent)
room.AddEvent(joinEvent)
}

func (i *ServerRoomImplDefault) GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
func (i *ServerRoomImplDefault) GenerateSendJoinResponse(room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
// build the state list *before* we insert the new event
var stateEvents []gomatrixserverlib.PDU
i.Room.StateMutex.RLock()
for _, ev := range i.Room.State {
room.StateMutex.RLock()
for _, ev := range room.State {
// filter out non-critical memberships if this is a partial-state join
if expectPartialState {
if ev.Type() == "m.room.member" && ev.StateKey() != joinEvent.StateKey() {
Expand All @@ -495,18 +493,18 @@ func (i *ServerRoomImplDefault) GenerateSendJoinResponse(s *Server, joinEvent go
}
stateEvents = append(stateEvents, ev)
}
i.Room.StateMutex.RUnlock()
room.StateMutex.RUnlock()

authEvents := i.Room.AuthChainForEvents(stateEvents)
authEvents := room.AuthChainForEvents(stateEvents)

// get servers in room *before* the join event
serversInRoom := []string{s.serverName}
if !omitServersInRoom {
serversInRoom = i.Room.ServersInRoom()
serversInRoom = room.ServersInRoom()
}

// insert the join event into the room state
i.Room.AddEvent(joinEvent)
room.AddEvent(joinEvent)
log.Printf("Received send-join of event %s", joinEvent.EventID())

// return state and auth chain
Expand Down

0 comments on commit 3569375

Please sign in to comment.