Skip to content

Commit

Permalink
Resend MatrixRTC encryption keys if a membership has changed (#4343)
Browse files Browse the repository at this point in the history
* Resend MatrixRTC encryption keys if a membership has changed

* JSDoc

* Update src/matrixrtc/MatrixRTCSession.ts

Co-authored-by: Andrew Ferrazzutti <[email protected]>

* Add note about using Set. symmetricDifference() when available

* Always store latest fingerprints

Should reduce unnecessary retransmits

* Refactor

---------

Co-authored-by: Andrew Ferrazzutti <[email protected]>
  • Loading branch information
hughns and AndrewFerr authored Aug 14, 2024
1 parent 78cbf7c commit c65ef03
Show file tree
Hide file tree
Showing 2 changed files with 263 additions and 9 deletions.
218 changes: 218 additions & 0 deletions spec/unit/matrixrtc/MatrixRTCSession.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,17 @@ describe("MatrixRTCSession", () => {
});
});

it("does not send key if join called when already joined", () => {
sess!.joinRoomSession([mockFocus], mockFocus, { manageMediaKeys: true });

expect(client.sendStateEvent).toHaveBeenCalledTimes(1);
expect(client.sendEvent).toHaveBeenCalledTimes(1);

sess!.joinRoomSession([mockFocus], mockFocus, { manageMediaKeys: true });
expect(client.sendStateEvent).toHaveBeenCalledTimes(1);
expect(client.sendEvent).toHaveBeenCalledTimes(1);
});

it("retries key sends", async () => {
jest.useFakeTimers();
let firstEventSent = false;
Expand Down Expand Up @@ -685,6 +696,213 @@ describe("MatrixRTCSession", () => {
}
});

it("Does not re-send key if memberships stays same", async () => {
jest.useFakeTimers();
try {
const keysSentPromise1 = new Promise((resolve) => {
sendEventMock.mockImplementation(resolve);
});

const member1 = membershipTemplate;
const member2 = Object.assign({}, membershipTemplate, {
device_id: "BBBBBBB",
});

const mockRoom = makeMockRoom([member1, member2]);
mockRoom.getLiveTimeline().getState = jest
.fn()
.mockReturnValue(makeMockRoomState([member1, member2], mockRoom.roomId, undefined));

sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);
sess.joinRoomSession([mockFocus], mockFocus, { manageMediaKeys: true });

await keysSentPromise1;

// make sure an encryption key was sent
expect(sendEventMock).toHaveBeenCalledWith(
expect.stringMatching(".*"),
"io.element.call.encryption_keys",
{
call_id: "",
device_id: "AAAAAAA",
keys: [
{
index: 0,
key: expect.stringMatching(".*"),
},
],
},
);

sendEventMock.mockClear();

// these should be a no-op:
sess.onMembershipUpdate();
expect(sendEventMock).toHaveBeenCalledTimes(0);
} finally {
jest.useRealTimers();
}
});

it("Re-sends key if a member changes membership ID", async () => {
jest.useFakeTimers();
try {
const keysSentPromise1 = new Promise((resolve) => {
sendEventMock.mockImplementation(resolve);
});

const member1 = membershipTemplate;
const member2 = {
...membershipTemplate,
device_id: "BBBBBBB",
};

const mockRoom = makeMockRoom([member1, member2]);
mockRoom.getLiveTimeline().getState = jest
.fn()
.mockReturnValue(makeMockRoomState([member1, member2], mockRoom.roomId, undefined));

sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);
sess.joinRoomSession([mockFocus], mockFocus, { manageMediaKeys: true });

await keysSentPromise1;

// make sure an encryption key was sent
expect(sendEventMock).toHaveBeenCalledWith(
expect.stringMatching(".*"),
"io.element.call.encryption_keys",
{
call_id: "",
device_id: "AAAAAAA",
keys: [
{
index: 0,
key: expect.stringMatching(".*"),
},
],
},
);

sendEventMock.mockClear();

// this should be a no-op:
sess.onMembershipUpdate();
expect(sendEventMock).toHaveBeenCalledTimes(0);

// advance time to avoid key throttling
jest.advanceTimersByTime(10000);

// update membership ID
member2.membershipID = "newID";

const keysSentPromise2 = new Promise((resolve) => {
sendEventMock.mockImplementation(resolve);
});

// this should re-send the key
sess.onMembershipUpdate();

await keysSentPromise2;

expect(sendEventMock).toHaveBeenCalledWith(
expect.stringMatching(".*"),
"io.element.call.encryption_keys",
{
call_id: "",
device_id: "AAAAAAA",
keys: [
{
index: 0,
key: expect.stringMatching(".*"),
},
],
},
);
} finally {
jest.useRealTimers();
}
});

it("Re-sends key if a member changes created_ts", async () => {
jest.useFakeTimers();
try {
const keysSentPromise1 = new Promise((resolve) => {
sendEventMock.mockImplementation(resolve);
});

const member1 = { ...membershipTemplate, created_ts: 1000 };
const member2 = {
...membershipTemplate,
created_ts: 1000,
device_id: "BBBBBBB",
};

const mockRoom = makeMockRoom([member1, member2]);
mockRoom.getLiveTimeline().getState = jest
.fn()
.mockReturnValue(makeMockRoomState([member1, member2], mockRoom.roomId, undefined));

sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);
sess.joinRoomSession([mockFocus], mockFocus, { manageMediaKeys: true });

await keysSentPromise1;

// make sure an encryption key was sent
expect(sendEventMock).toHaveBeenCalledWith(
expect.stringMatching(".*"),
"io.element.call.encryption_keys",
{
call_id: "",
device_id: "AAAAAAA",
keys: [
{
index: 0,
key: expect.stringMatching(".*"),
},
],
},
);

sendEventMock.mockClear();

// this should be a no-op:
sess.onMembershipUpdate();
expect(sendEventMock).toHaveBeenCalledTimes(0);

// advance time to avoid key throttling
jest.advanceTimersByTime(10000);

// update created_ts
member2.created_ts = 5000;

const keysSentPromise2 = new Promise((resolve) => {
sendEventMock.mockImplementation(resolve);
});

// this should re-send the key
sess.onMembershipUpdate();

await keysSentPromise2;

expect(sendEventMock).toHaveBeenCalledWith(
expect.stringMatching(".*"),
"io.element.call.encryption_keys",
{
call_id: "",
device_id: "AAAAAAA",
keys: [
{
index: 0,
key: expect.stringMatching(".*"),
},
],
},
);
} finally {
jest.useRealTimers();
}
});

it("Rotates key if a member leaves", async () => {
jest.useFakeTimers();
try {
Expand Down
54 changes: 45 additions & 9 deletions src/matrixrtc/MatrixRTCSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
private encryptionKeys = new Map<string, Array<Uint8Array>>();
private lastEncryptionKeyUpdateRequest?: number;

// We use this to store the last membership fingerprints we saw, so we can proactively re-send encryption keys
// if it looks like a membership has been updated.
private lastMembershipFingerprints: Set<string> | undefined;

/**
* The callId (sessionId) of the call.
*
Expand Down Expand Up @@ -636,6 +640,14 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
}
};

private isMyMembership = (m: CallMembership): boolean =>
m.sender === this.client.getUserId() && m.deviceId === this.client.getDeviceId();

/**
* Examines the latest call memberships and handles any encryption key sending or rotation that is needed.
*
* This function should be called when the room members or call memberships might have changed.
*/
public onMembershipUpdate = (): void => {
const oldMemberships = this.memberships;
this.memberships = MatrixRTCSession.callMembershipsForRoom(this.room);
Expand All @@ -651,32 +663,56 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
this.emit(MatrixRTCSessionEvent.MembershipsChanged, oldMemberships, this.memberships);
}

const isMyMembership = (m: CallMembership): boolean =>
m.sender === this.client.getUserId() && m.deviceId === this.client.getDeviceId();

if (this.manageMediaKeys && this.isJoined() && this.makeNewKeyTimeout === undefined) {
const oldMebershipIds = new Set(
oldMemberships.filter((m) => !isMyMembership(m)).map(getParticipantIdFromMembership),
const oldMembershipIds = new Set(
oldMemberships.filter((m) => !this.isMyMembership(m)).map(getParticipantIdFromMembership),
);
const newMebershipIds = new Set(
this.memberships.filter((m) => !isMyMembership(m)).map(getParticipantIdFromMembership),
const newMembershipIds = new Set(
this.memberships.filter((m) => !this.isMyMembership(m)).map(getParticipantIdFromMembership),
);

const anyLeft = Array.from(oldMebershipIds).some((x) => !newMebershipIds.has(x));
const anyJoined = Array.from(newMebershipIds).some((x) => !oldMebershipIds.has(x));
// We can use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set/symmetricDifference
// for this once available
const anyLeft = Array.from(oldMembershipIds).some((x) => !newMembershipIds.has(x));
const anyJoined = Array.from(newMembershipIds).some((x) => !oldMembershipIds.has(x));

const oldFingerprints = this.lastMembershipFingerprints;
// always store the fingerprints of these latest memberships
this.storeLastMembershipFingerprints();

if (anyLeft) {
logger.debug(`Member(s) have left: queueing sender key rotation`);
this.makeNewKeyTimeout = setTimeout(this.onRotateKeyTimeout, MAKE_KEY_DELAY);
} else if (anyJoined) {
logger.debug(`New member(s) have joined: re-sending keys`);
this.requestKeyEventSend();
} else if (oldFingerprints) {
// does it look like any of the members have updated their memberships?
const newFingerprints = this.lastMembershipFingerprints!;

// We can use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set/symmetricDifference
// for this once available
const candidateUpdates =
Array.from(oldFingerprints).some((x) => !newFingerprints.has(x)) ||
Array.from(newFingerprints).some((x) => !oldFingerprints.has(x));
if (candidateUpdates) {
logger.debug(`Member(s) have updated/reconnected: re-sending keys`);
this.requestKeyEventSend();
}
}
}

this.setExpiryTimer();
};

private storeLastMembershipFingerprints(): void {
this.lastMembershipFingerprints = new Set(
this.memberships
.filter((m) => !this.isMyMembership(m))
.map((m) => `${getParticipantIdFromMembership(m)}:${m.membershipID}:${m.createdTs()}`),
);
}

/**
* Constructs our own membership
* @param prevMembership - The previous value of our call membership, if any
Expand Down

0 comments on commit c65ef03

Please sign in to comment.