Skip to content

Commit

Permalink
[mob][photos] Cache results for Magic section (only visible to intern…
Browse files Browse the repository at this point in the history
…al users) (#2282)

## Description

- Use cached results for magic section so that it does't anymore have to
wait for ML framework to be initialised and doesn't have to re-compute
results every time, which means faster loading of the search tab.
- For internal users, all results in
[here](https://discover.ente.io/v1.json) will show up.
- For non-internal users, once available, results will be limited to 4.
- 4 random prompts are selected from
[here](https://discover.ente.io/v1.json) with non-empty results and are
cached.
- The cache updates when the data updates
[here](https://discover.ente.io/v1.json) (checks size to compare) or in
3 days since the last update.
  • Loading branch information
ashilkn authored Jun 25, 2024
2 parents 2f327b1 + 75dcf18 commit f67dc48
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 65 deletions.
3 changes: 3 additions & 0 deletions mobile/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import "package:photos/services/machine_learning/face_ml/person/person_service.d
import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart';
import "package:photos/services/machine_learning/machine_learning_controller.dart";
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import "package:photos/services/magic_cache_service.dart";
import 'package:photos/services/memories_service.dart';
import 'package:photos/services/push_service.dart';
import 'package:photos/services/remote_sync_service.dart';
Expand Down Expand Up @@ -303,6 +304,8 @@ Future<void> _init(bool isBackground, {String via = ''}) async {
preferences,
);

MagicCacheService.instance.init(preferences);

initComplete = true;
_logger.info("Initialization done");
} catch (e, s) {
Expand Down
3 changes: 0 additions & 3 deletions mobile/lib/models/search/search_types.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import "package:photos/models/collection/collection_items.dart";
import "package:photos/models/search/search_result.dart";
import "package:photos/models/typedefs.dart";
import "package:photos/services/collections_service.dart";
import "package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart";
import "package:photos/services/search_service.dart";
import "package:photos/ui/viewer/gallery/collection_page.dart";
import "package:photos/ui/viewer/location/add_location_sheet.dart";
Expand Down Expand Up @@ -292,8 +291,6 @@ extension SectionTypeExtensions on SectionType {
switch (this) {
case SectionType.location:
return [Bus.instance.on<LocationTagUpdatedEvent>()];
case SectionType.magic:
return [Bus.instance.on<MLFrameworkInitializationUpdateEvent>()];
default:
return [];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,49 @@ class SemanticSearchService {
return results;
}

Future<List<int>> getMatchingFileIDs(String query, double minScore) async {
final textEmbedding = await _getTextEmbedding(query);

final queryResults =
await _getScores(textEmbedding, scoreThreshold: minScore);

final queryResultIds = <int>[];
for (QueryResult result in queryResults) {
queryResultIds.add(result.id);
}

final filesMap = await FilesDB.instance.getFilesFromIDs(
queryResultIds,
);
final results = <EnteFile>[];

final ignoredCollections =
CollectionsService.instance.getHiddenCollectionIds();
final deletedEntries = <int>[];
for (final result in queryResults) {
final file = filesMap[result.id];
if (file != null && !ignoredCollections.contains(file.collectionID)) {
results.add(file);
}
if (file == null) {
deletedEntries.add(result.id);
}
}

_logger.info(results.length.toString() + " results");

if (deletedEntries.isNotEmpty) {
unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries));
}

final matchingFileIDs = <int>[];
for (EnteFile file in results) {
matchingFileIDs.add(file.uploadedFileID!);
}

return matchingFileIDs;
}

void _addToQueue(EnteFile file) {
if (!LocalSettings.instance.hasEnabledMagicSearch()) {
return;
Expand Down
225 changes: 225 additions & 0 deletions mobile/lib/services/magic_cache_service.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import "dart:async";
import "dart:convert";
import "dart:io";

import "package:logging/logging.dart";
import "package:path_provider/path_provider.dart";
import "package:photos/models/file/file.dart";
import "package:photos/models/search/generic_search_result.dart";
import "package:photos/models/search/search_types.dart";
import "package:photos/service_locator.dart";
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:photos/services/search_service.dart";
import "package:shared_preferences/shared_preferences.dart";

class MagicCache {
final String title;
final List<int> fileUploadedIDs;
MagicCache(this.title, this.fileUploadedIDs);

factory MagicCache.fromJson(Map<String, dynamic> json) {
return MagicCache(
json['title'],
List<int>.from(json['fileUploadedIDs']),
);
}

Map<String, dynamic> toJson() {
return {
'title': title,
'fileUploadedIDs': fileUploadedIDs,
};
}

static String encodeListToJson(List<MagicCache> magicCaches) {
final jsonList = magicCaches.map((cache) => cache.toJson()).toList();
return jsonEncode(jsonList);
}

static List<MagicCache> decodeJsonToList(String jsonString) {
final jsonList = jsonDecode(jsonString) as List;
return jsonList.map((json) => MagicCache.fromJson(json)).toList();
}
}

extension MagicCacheServiceExtension on MagicCache {
Future<GenericSearchResult> toGenericSearchResult() async {
final allEnteFiles = await SearchService.instance.getAllFiles();
final enteFilesInMagicCache = <EnteFile>[];
for (EnteFile file in allEnteFiles) {
if (file.uploadedFileID != null &&
fileUploadedIDs.contains(file.uploadedFileID as int)) {
enteFilesInMagicCache.add(file);
}
}
return GenericSearchResult(
ResultType.magic,
title,
enteFilesInMagicCache,
);
}
}

class MagicCacheService {
static const _lastMagicCacheUpdateTime = "last_magic_cache_update_time";
static const _kMagicPromptsDataUrl = "https://discover.ente.io/v1.json";

/// Delay is for cache update to be done not during app init, during which a
/// lot of other things are happening.
static const _kCacheUpdateDelay = Duration(seconds: 10);

late SharedPreferences _prefs;
final Logger _logger = Logger((MagicCacheService).toString());
MagicCacheService._privateConstructor();

static final MagicCacheService instance =
MagicCacheService._privateConstructor();

void init(SharedPreferences preferences) {
_prefs = preferences;
_updateCacheIfTheTimeHasCome();
}

Future<void> resetLastMagicCacheUpdateTime() async {
await _prefs.setInt(
_lastMagicCacheUpdateTime,
DateTime.now().millisecondsSinceEpoch,
);
}

int get lastMagicCacheUpdateTime {
return _prefs.getInt(_lastMagicCacheUpdateTime) ?? 0;
}

Future<void> _updateCacheIfTheTimeHasCome() async {
final jsonFile = await RemoteAssetsService.instance
.getAssetIfUpdated(_kMagicPromptsDataUrl);
if (jsonFile != null) {
Future.delayed(_kCacheUpdateDelay, () {
unawaited(_updateCache());
});
return;
}
if (lastMagicCacheUpdateTime <
DateTime.now()
.subtract(const Duration(days: 3))
.millisecondsSinceEpoch) {
Future.delayed(_kCacheUpdateDelay, () {
unawaited(_updateCache());
});
}
}

Future<String> _getCachePath() async {
return (await getApplicationSupportDirectory()).path + "/cache/magic_cache";
}

Future<List<int>> _getMatchingFileIDsForPromptData(
Map<String, dynamic> promptData,
) async {
final result = await SemanticSearchService.instance.getMatchingFileIDs(
promptData["prompt"] as String,
promptData["minimumScore"] as double,
);

return result;
}

Future<void> _updateCache() async {
try {
_logger.info("updating magic cache");
final magicPromptsData = await _loadMagicPrompts();
final magicCaches = await nonEmptyMagicResults(magicPromptsData);
final file = File(await _getCachePath());
if (!file.existsSync()) {
file.createSync(recursive: true);
}
file.writeAsBytesSync(MagicCache.encodeListToJson(magicCaches).codeUnits);
unawaited(
resetLastMagicCacheUpdateTime().onError((error, stackTrace) {
_logger.warning(
"Error resetting last magic cache update time",
error,
);
}),
);
} catch (e) {
_logger.info("Error updating magic cache", e);
}
}

Future<List<MagicCache>?> _getMagicCache() async {
final file = File(await _getCachePath());
if (!file.existsSync()) {
_logger.info("No magic cache found");
return null;
}
final jsonString = file.readAsStringSync();
return MagicCache.decodeJsonToList(jsonString);
}

Future<void> clearMagicCache() async {
File(await _getCachePath()).deleteSync();
}

Future<List<GenericSearchResult>> getMagicGenericSearchResult() async {
try {
final magicCaches = await _getMagicCache();
if (magicCaches == null) {
_logger.info("No magic cache found");
return [];
}

final List<GenericSearchResult> genericSearchResults = [];
for (MagicCache magicCache in magicCaches) {
final genericSearchResult = await magicCache.toGenericSearchResult();
genericSearchResults.add(genericSearchResult);
}
return genericSearchResults;
} catch (e) {
_logger.info("Error getting magic generic search result", e);
return [];
}
}

Future<List<dynamic>> _loadMagicPrompts() async {
final file =
await RemoteAssetsService.instance.getAsset(_kMagicPromptsDataUrl);

final json = jsonDecode(await file.readAsString());
return json["prompts"];
}

///Returns random non-empty magic results from magicPromptsData
///Length is capped at [limit], can be less than [limit] if there are not enough
///non-empty results
Future<List<MagicCache>> nonEmptyMagicResults(
List<dynamic> magicPromptsData,
) async {
//Show all magic prompts to internal users for feedback on results
final limit = flagService.internalUser ? magicPromptsData.length : 4;
final results = <MagicCache>[];
final randomIndexes = List.generate(
magicPromptsData.length,
(index) => index,
growable: false,
)..shuffle();
for (final index in randomIndexes) {
final files =
await _getMatchingFileIDsForPromptData(magicPromptsData[index]);
if (files.isNotEmpty) {
results.add(
MagicCache(
magicPromptsData[index]["title"] as String,
files,
),
);
}
if (results.length >= limit) {
break;
}
}
return results;
}
}
37 changes: 33 additions & 4 deletions mobile/lib/services/remote_assets_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,46 @@ class RemoteAssetsService {
Future<File> getAsset(String remotePath, {bool refetch = false}) async {
final path = await _getLocalPath(remotePath);
final file = File(path);
if (await file.exists() && !refetch) {
if (file.existsSync() && !refetch) {
_logger.info("Returning cached file for $remotePath");
return file;
} else {
final tempFile = File(path + ".temp");
await _downloadFile(remotePath, tempFile.path);
await tempFile.rename(path);
tempFile.renameSync(path);
return File(path);
}
}

///Returns asset if the remote asset is new compared to the local copy of it
Future<File?> getAssetIfUpdated(String remotePath) async {
try {
final path = await _getLocalPath(remotePath);
final file = File(path);
if (!file.existsSync()) {
final tempFile = File(path + ".temp");
await _downloadFile(remotePath, tempFile.path);
tempFile.renameSync(path);
return File(path);
} else {
final existingFileSize = File(path).lengthSync();
final tempFile = File(path + ".temp");
await _downloadFile(remotePath, tempFile.path);
final newFileSize = tempFile.lengthSync();
if (existingFileSize != newFileSize) {
tempFile.renameSync(path);
return File(path);
} else {
tempFile.deleteSync();
return null;
}
}
} catch (e) {
_logger.warning("Error getting asset if updated", e);
return null;
}
}

Future<bool> hasAsset(String remotePath) async {
final path = await _getLocalPath(remotePath);
return File(path).exists();
Expand Down Expand Up @@ -60,8 +89,8 @@ class RemoteAssetsService {
Future<void> _downloadFile(String url, String savePath) async {
_logger.info("Downloading " + url);
final existingFile = File(savePath);
if (await existingFile.exists()) {
await existingFile.delete();
if (existingFile.existsSync()) {
existingFile.deleteSync();
}

await NetworkClient.instance.getDio().download(
Expand Down
Loading

0 comments on commit f67dc48

Please sign in to comment.