Skip to content

Commit

Permalink
[desktop] Clustering - Incorporate low quality face heuristics (#3123)
Browse files Browse the repository at this point in the history
  • Loading branch information
mnvr authored Sep 4, 2024
2 parents 0a1e062 + 485e844 commit 2bb1670
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 50 deletions.
47 changes: 31 additions & 16 deletions web/apps/photos/src/pages/cluster-debug.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import {
type ClusterDebugPageContents,
} from "@/new/photos/services/ml";
import {
type ClusterFace,
type ClusteringOpts,
type ClusteringProgress,
type FaceF32,
type OnClusteringProgress,
} from "@/new/photos/services/ml/cluster";
import { faceDirection } from "@/new/photos/services/ml/face";
Expand All @@ -22,6 +22,8 @@ import BackButton from "@mui/icons-material/ArrowBackOutlined";
import {
Box,
Button,
Checkbox,
FormControlLabel,
IconButton,
LinearProgress,
Stack,
Expand Down Expand Up @@ -68,10 +70,11 @@ export default function ClusterDebug() {
minBlur: 10,
minScore: 0.8,
minClusterSize: 2,
joinThreshold: 0.6,
joinThreshold: 0.76,
earlyExitThreshold: 0.9,
batchSize: 10000,
offsetIncrement: 7500,
badFaceHeuristics: true,
},
onSubmit: (values) =>
cluster(
Expand All @@ -83,6 +86,7 @@ export default function ClusterDebug() {
earlyExitThreshold: toFloat(values.earlyExitThreshold),
batchSize: toFloat(values.batchSize),
offsetIncrement: toFloat(values.offsetIncrement),
badFaceHeuristics: values.badFaceHeuristics,
},
(progress: ClusteringProgress) =>
onProgressRef.current?.(progress),
Expand Down Expand Up @@ -227,15 +231,30 @@ const MemoizedForm = memo(
onChange={handleChange}
/>
</Stack>
<Box marginInlineStart={"auto"} p={1}>
<Stack direction="row" justifyContent={"space-between"} p={1}>
<FormControlLabel
control={
<Checkbox
name={"badFaceHeuristics"}
checked={values.badFaceHeuristics}
size="small"
onChange={handleChange}
/>
}
label={
<Typography color="text.secondary">
Bad face heuristics
</Typography>
}
/>
<Button
color="secondary"
type="submit"
disabled={isSubmitting}
>
Cluster
</Button>
</Box>
</Stack>
</Stack>
</form>
),
Expand Down Expand Up @@ -325,7 +344,7 @@ const ClusterList: React.FC<React.PropsWithChildren<ClusterListProps>> = ({
index === 0
? 140
: index === 1
? 130
? 110
: Array.isArray(items[index - 2])
? listItemHeight
: 36;
Expand Down Expand Up @@ -447,15 +466,11 @@ const ClusterResHeader: React.FC<ClusterResHeaderProps> = ({ clusterRes }) => {
</Typography>
<Typography variant="small" color="text.muted">
For each cluster showing only up to 50 faces, sorted by cosine
similarity to highest scoring face in the cluster.
</Typography>
<Typography variant="small" color="text.muted">
Below each face is its{" "}
<b>blur - score - cosineSimilarity - direction</b>.
similarity to its highest scoring face.
</Typography>
<Typography variant="small" color="text.muted">
Faces added to the cluster as a result of next batch merging are
outlined.
Below each face is its blur, score, cosineSimilarity, direction.
Bad faces are outlined.
</Typography>
</Stack>
);
Expand Down Expand Up @@ -494,15 +509,15 @@ interface FaceItemProps {
}

interface FaceWithFile {
face: FaceF32;
face: ClusterFace;
enteFile: EnteFile;
cosineSimilarity?: number;
wasMerged?: boolean;
}

const FaceItem: React.FC<FaceItemProps> = ({ faceWithFile }) => {
const { face, enteFile, cosineSimilarity, wasMerged } = faceWithFile;
const { faceID } = face;
const { face, enteFile, cosineSimilarity } = faceWithFile;
const { faceID, isBadFace } = face;

const [objectURL, setObjectURL] = useState<string | undefined>();

Expand All @@ -526,7 +541,7 @@ const FaceItem: React.FC<FaceItemProps> = ({ faceWithFile }) => {
return (
<FaceChip
style={{
outline: wasMerged ? `1px solid gray` : undefined,
outline: isBadFace ? `1px solid rosybrown` : undefined,
outlineOffset: "2px",
}}
>
Expand Down
94 changes: 62 additions & 32 deletions web/packages/new/photos/services/ml/cluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { newNonSecureID } from "@/base/id-worker";
import log from "@/base/log";
import { ensure } from "@/utils/ensure";
import type { EnteFile } from "../../types/file";
import type { Face, FaceIndex } from "./face";
import { faceDirection, type Face, type FaceIndex } from "./face";
import { dotProduct } from "./math";

/**
Expand Down Expand Up @@ -121,6 +121,7 @@ export interface ClusteringOpts {
earlyExitThreshold: number;
batchSize: number;
offsetIncrement: number;
badFaceHeuristics: boolean;
}

export interface ClusteringProgress {
Expand All @@ -130,8 +131,10 @@ export interface ClusteringProgress {

export type OnClusteringProgress = (progress: ClusteringProgress) => void;

export type FaceF32 = Omit<Face, "embedding"> & {
/** A {@link Face} annotated with data needed during clustering. */
export type ClusterFace = Omit<Face, "embedding"> & {
embedding: Float32Array;
isBadFace: boolean;
};

export interface ClusterPreview {
Expand All @@ -140,33 +143,25 @@ export interface ClusterPreview {
}

export interface ClusterPreviewFace {
face: FaceF32;
face: ClusterFace;
cosineSimilarity: number;
wasMerged: boolean;
}

/**
* Cluster faces into groups.
*
* [Note: Face clustering algorithm]
*
* A cgroup (cluster group) consists of clusters, each of which itself is a set
* of faces.
*
* cgroup << cluster << face
*
* The clusters are generated locally by clients using the following algorithm:
*
* 1. clusters = [] initially, or fetched from remote.
*
* 2. For each face, find its nearest neighbour in the embedding space.
*
* 3. If no such neighbour is found within our threshold, create a new cluster.
* This function generates clusters locally using a batched form of linear
* clustering, with a bit of lookback (and a dollop of heuristics) to get the
* clusters to merge across batches.
*
* 4. Otherwise assign this face to the same cluster as its nearest neighbour.
*
* This user can then tweak the output of the algorithm by performing the
* following actions to the list of clusters that they can see:
* The user can later tweak these clusters by performing the following actions
* to the list of clusters that they can see:
*
* - They can provide a name for a cluster ("name a person"). This upgrades a
* cluster into a "cgroup", which is an entity that gets synced via remote
Expand Down Expand Up @@ -200,16 +195,14 @@ export const clusterFaces = (
earlyExitThreshold,
batchSize,
offsetIncrement,
badFaceHeuristics,
} = opts;
const t = Date.now();

const localFileByID = new Map(localFiles.map((f) => [f.id, f]));

// A flattened array of faces.
const allFaces = [...enumerateFaces(faceIndexes)];
const filteredFaces = allFaces
.filter((f) => f.blur > minBlur)
.filter((f) => f.score > minScore);
// A flattened array of filtered and annotated faces.
const filteredFaces = [...enumerateFaces(faceIndexes, minBlur, minScore)];

const fileForFaceID = new Map(
filteredFaces.map(({ faceID }) => [
Expand Down Expand Up @@ -264,6 +257,7 @@ export const clusterFaces = (
oldState,
joinThreshold,
earlyExitThreshold,
badFaceHeuristics,
({ completed }: ClusteringProgress) =>
onProgress({
completed: offset + completed,
Expand Down Expand Up @@ -335,7 +329,9 @@ export const clusterFaces = (
});
}

const totalFaceCount = allFaces.length;
// TODO-Cluster the total face count is only needed during debugging
let totalFaceCount = 0;
for (const fi of faceIndexes) totalFaceCount += fi.faces.length;
const filteredFaceCount = faces.length;
const clusteredFaceCount = clusterIDForFaceID.size;
const unclusteredFaceCount = filteredFaceCount - clusteredFaceCount;
Expand Down Expand Up @@ -364,20 +360,51 @@ export const clusterFaces = (
};

/**
* A generator function that returns a stream of {faceID, embedding} values,
* flattening all the the faces present in the given {@link faceIndices}.
* A generator function that returns a stream of eligible {@link ClusterFace}s
* by flattening all the faces present in the given {@link faceIndices}.
*
* It also converts the embeddings to Float32Arrays to speed up the dot product
* calculations that will happen during clustering.
* During this, it also converts the embeddings to Float32Arrays to speed up the
* dot product calculations that will happen during clustering and attaches
* other information that the clustering algorithm needs.
*/
function* enumerateFaces(faceIndices: FaceIndex[]) {
function* enumerateFaces(
faceIndices: FaceIndex[],
minBlur: number,
minScore: number,
) {
for (const fi of faceIndices) {
for (const f of fi.faces) {
yield { ...f, embedding: new Float32Array(f.embedding) };
if (shouldIncludeFace(f, minBlur, minScore)) {
yield {
...f,
embedding: new Float32Array(f.embedding),
isBadFace: isBadFace(f),
};
}
}
}
}

/**
 * Decide whether a face is eligible to take part in clustering.
 *
 * A face qualifies only when its blur is strictly greater than
 * {@link minBlur} and its score is strictly greater than {@link minScore}.
 */
const shouldIncludeFace = (face: Face, minBlur: number, minScore: number) => {
    // Written as a negated `>` (instead of `<=`) so that a NaN blur is still
    // rejected, same as it would be by a plain `blur > minBlur` check.
    const blurOk = face.blur > minBlur;
    if (!blurOk) return false;

    const scoreOk = face.score > minScore;
    return scoreOk;
};

/**
 * Return true if the given face, despite being above the minimum inclusion
 * thresholds, is heuristically determined to possibly be a spurious face
 * detection.
 *
 * A face is flagged as "bad" if any of the following hold:
 *
 * - Its blur value is below 50.
 * - Its blur value is below 200 and its score is also below 0.85.
 * - It is a sideways face (see {@link isSidewaysFace}).
 *
 * We apply a higher threshold when clustering such faces.
 */
const isBadFace = (face: Face) =>
    face.blur < 50 ||
    // The second operand must test the score: comparing blur against 0.85
    // would make this clause a subset of the `blur < 50` check above.
    (face.blur < 200 && face.score < 0.85) ||
    isSidewaysFace(face);

/**
 * Return true if {@link faceDirection} reports anything other than "straight"
 * for the detection of the given face.
 */
const isSidewaysFace = (face: Face) =>
    faceDirection(face.detection) !== "straight";

/** Generate a new cluster ID. */
const newClusterID = () => newNonSecureID("cluster_");

Expand All @@ -403,10 +430,11 @@ interface ClusteringState {
}

const clusterBatchLinear = (
faces: FaceF32[],
faces: ClusterFace[],
oldState: ClusteringState,
joinThreshold: number,
earlyExitThreshold: number,
badFaceHeuristics: boolean,
onProgress: (progress: ClusteringProgress) => void,
) => {
const state: ClusteringState = {
Expand All @@ -429,7 +457,7 @@ const clusterBatchLinear = (

// Find the nearest neighbour among the previous faces in this batch.
let nnIndex: number | undefined;
let nnCosineSimilarity = joinThreshold;
let nnCosineSimilarity = 0;
for (let j = i - 1; j >= 0; j--) {
// ! This is an O(n^2) loop, be careful when adding more code here.

Expand All @@ -439,13 +467,15 @@ const clusterBatchLinear = (
// The vectors are already normalized, so we can directly use their
// dot product as their cosine similarity.
const csim = dotProduct(fi.embedding, fj.embedding);
if (csim > nnCosineSimilarity) {
const threshold =
badFaceHeuristics && fj.isBadFace ? 0.84 : joinThreshold;
if (csim > nnCosineSimilarity && csim >= threshold) {
nnIndex = j;
nnCosineSimilarity = csim;

// If we've found something "near enough", stop looking for a
// better match (A heuristic to speed up clustering).
if (earlyExitThreshold > 0 && csim > earlyExitThreshold) break;
if (earlyExitThreshold > 0 && csim >= earlyExitThreshold) break;
}
}

Expand Down
4 changes: 2 additions & 2 deletions web/packages/new/photos/services/ml/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import { getRemoteFlag, updateRemoteFlag } from "../remote-store";
import type { SearchPerson } from "../search/types";
import type { UploadItem } from "../upload/types";
import {
type ClusterFace,
type ClusteringOpts,
type ClusterPreviewFace,
type FaceCluster,
type FaceF32,
type OnClusteringProgress,
} from "./cluster";
import { regenerateFaceCrops } from "./crop";
Expand Down Expand Up @@ -366,7 +366,7 @@ export interface ClusterDebugPageContents {
clusters: FaceCluster[];
clusterPreviewsWithFile: ClusterPreviewWithFile[];
unclusteredFacesWithFile: {
face: FaceF32;
face: ClusterFace;
enteFile: EnteFile;
}[];
}
Expand Down

0 comments on commit 2bb1670

Please sign in to comment.