Skip to content

Commit

Permalink
Translation feature #20 (also a hint of #16)
Browse files Browse the repository at this point in the history
  • Loading branch information
MrCsabaToth committed Aug 10, 2024
1 parent 2433a77 commit 39bef0e
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 64 deletions.
40 changes: 29 additions & 11 deletions lib/ai/cubit/ai_cubit.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import 'package:flutter/foundation.dart';
import 'package:google_generative_ai/google_generative_ai.dart';
import 'package:inspector_gadget/ai/prompts/resolver_few_shot.dart';
import 'package:inspector_gadget/ai/prompts/stuffed_user_utterance.dart';
import 'package:inspector_gadget/ai/prompts/system.dart';
import 'package:inspector_gadget/ai/prompts/system_instruction.dart';
import 'package:inspector_gadget/ai/prompts/translate_instruction.dart';
import 'package:inspector_gadget/ai/tools/tools_mixin.dart';
import 'package:inspector_gadget/database/cubit/database_cubit.dart';
import 'package:inspector_gadget/database/models/history.dart';
Expand All @@ -18,6 +19,20 @@ class AiCubit extends Cubit<int> with ToolsMixin {

ChatSession? chat;

/// Builds the Gemini [GenerativeModel] configured from [preferencesState].
///
/// The "flash" variant is selected when the fast LLM mode preference is on
/// (falling back to [PreferencesState.fastLlmModeDefault] when no preferences
/// are available); otherwise the "pro" variant is used. The API key comes
/// from the preferences when set, else from the bundled `geminiApiKey`.
///
/// When [withTools] is `false` no function declarations are attached.
GenerativeModel getModel(
  PreferencesState? preferencesState, {
  bool withTools = true,
}) {
  final useFlash =
      preferencesState?.fastLlmMode ?? PreferencesState.fastLlmModeDefault;
  final modelName = 'gemini-1.5-${useFlash ? 'flash' : 'pro'}';

  return GenerativeModel(
    model: modelName,
    apiKey: preferencesState?.geminiApiKey ?? geminiApiKey,
    // Function declarations are only built when tools are requested.
    tools: withTools ? [getFunctionDeclarations(preferencesState)] : null,
  );
}

Future<GenerateContentResponse?> chatStep(
String prompt,
DatabaseCubit? database,
Expand All @@ -26,16 +41,7 @@ class AiCubit extends Cubit<int> with ToolsMixin {
Location? gpsLocation,
) async {
if (chat != null) {
final fastMode =
preferencesState?.fastLlmMode ?? PreferencesState.fastLlmModeDefault;
final tools = [getFunctionDeclarations(preferencesState)];
final modelType = fastMode ? 'flash' : 'pro';
final model = GenerativeModel(
model: 'gemini-1.5-$modelType',
apiKey: preferencesState?.geminiApiKey ?? geminiApiKey,
tools: tools,
);

final model = getModel(preferencesState);
final stuffedInstruction = systemInstruction.replaceAll(
'%%%',
getFunctionCallPromptStuffing(preferencesState),
Expand Down Expand Up @@ -175,4 +181,16 @@ class AiCubit extends Cubit<int> with ToolsMixin {

return response.text ?? '';
}

/// Asks a tool-less model to translate [transcript] into [targetLocale].
///
/// The `%%%` placeholder of [translateInstruction] is replaced with
/// [targetLocale], and [transcript] is appended right after the instruction
/// so the model actually receives the text it is supposed to translate.
///
/// Returns the response produced by `generateContent`.
Future<GenerateContentResponse?> translate(
  String transcript,
  String targetLocale,
  PreferencesState? preferencesState,
) async {
  final model = getModel(preferencesState, withTools: false);
  final instruction = translateInstruction.replaceAll('%%%', targetLocale);
  // The instruction ends with ": ", so the transcript follows it directly.
  // Previously the transcript was never added to the prompt.
  final content = Content.text('$instruction$transcript');
  return model.generateContent([content]);
}
}
File renamed without changes.
1 change: 1 addition & 0 deletions lib/ai/prompts/translate_instruction.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/// Instruction prefix for translation prompts.
///
/// Contains a `%%%` placeholder that stands for the target locale and ends
/// with ": ", so the text to translate is expected to follow it directly.
const translateInstruction = 'Translate all the following to %%% locale: ';
26 changes: 15 additions & 11 deletions lib/database/view/personalization_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class _PersonalizationViewState extends State<PersonalizationView>
late AnimationController _animationController;
int _editCount = 0;
PreferencesState? preferencesState;
String systemLocale = PreferencesState.inputLocaleDefault;
String inputLocaleId = PreferencesState.inputLocaleDefault;
AiCubit? aiCubit;
DatabaseCubit? database;
PersonalizationCubit? personalizationCubit;
Expand Down Expand Up @@ -94,11 +94,9 @@ class _PersonalizationViewState extends State<PersonalizationView>
final recorded = result.recognizedWords.trim();
if (recorded.isNotEmpty && database != null) {
personalizationCubit?.setState(PersonalizationCubit.processingStateLabel);

final embedding =
await aiCubit?.obtainEmbedding(recorded, preferencesState) ?? [];

final personalization = Personalization(recorded, systemLocale)
final personalization = Personalization(recorded, inputLocaleId)
..embedding = embedding;
database!.addUpdatePersonalization(personalization);

Expand All @@ -120,6 +118,8 @@ class _PersonalizationViewState extends State<PersonalizationView>
final l10n = context.l10n;
aiCubit = context.select((AiCubit cubit) => cubit);
preferencesState = context.select((PreferencesCubit cubit) => cubit.state);
inputLocaleId =
preferencesState?.inputLocale ?? PreferencesState.inputLocaleDefault;
database = context.select((DatabaseCubit cubit) => cubit);
personalizationCubit =
context.select((PersonalizationCubit cubit) => cubit);
Expand Down Expand Up @@ -172,11 +172,14 @@ class _PersonalizationViewState extends State<PersonalizationView>
?.setState(PersonalizationCubit.playingStateLabel);
final ttsState =
context.select((TtsCubit cubit) => cubit.state);
await ttsState.speak(
p13n.content,
preferencesState?.volume ??
PreferencesState.volumeDefault,
);
if (await ttsState.setLanguage(p13n.locale)) {
await ttsState.speak(
p13n.content,
preferencesState?.volume ??
PreferencesState.volumeDefault,
);
}

personalizationCubit
?.setState(PersonalizationCubit.browsingStateLabel);
},
Expand Down Expand Up @@ -238,7 +241,8 @@ class _PersonalizationViewState extends State<PersonalizationView>
);

final sttState = context.select((SttCubit cubit) => cubit.state);
systemLocale = sttState.systemLocale;
inputLocaleId =
preferencesState?.inputLocale ?? sttState.systemLocale;
await sttState.speech.listen(
onResult: _resultListener,
listenFor: const Duration(
Expand All @@ -247,7 +251,7 @@ class _PersonalizationViewState extends State<PersonalizationView>
pauseFor: const Duration(
seconds: PreferencesState.pauseForDefault,
),
localeId: systemLocale,
localeId: inputLocaleId,
onSoundLevelChange: _soundLevelListener,
listenOptions: options,
);
Expand Down
106 changes: 76 additions & 30 deletions lib/interaction/view/interaction_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import 'package:flutter/foundation.dart';
import 'package:flutter/material.dart';
import 'package:flutter_bloc/flutter_bloc.dart';
import 'package:flutter_easy_animations/flutter_easy_animations.dart';
import 'package:google_generative_ai/google_generative_ai.dart';
import 'package:http/http.dart' as http;
import 'package:inspector_gadget/ai/cubit/ai_cubit.dart';
import 'package:inspector_gadget/database/cubit/database_cubit.dart';
Expand All @@ -25,6 +26,7 @@ import 'package:inspector_gadget/preferences/cubit/preferences_cubit.dart';
import 'package:inspector_gadget/preferences/cubit/preferences_state.dart';
import 'package:inspector_gadget/preferences/preferences.dart';
import 'package:inspector_gadget/secrets.dart';
import 'package:inspector_gadget/string_ex.dart';
import 'package:inspector_gadget/stt/cubit/stt_cubit.dart';
import 'package:inspector_gadget/tts/cubit/tts_cubit.dart';
import 'package:inspector_gadget/tts/cubit/tts_state.dart';
Expand Down Expand Up @@ -214,7 +216,11 @@ class _InteractionViewState extends State<InteractionView>
json.decode(transcriptionResponse.body) as List<dynamic>;
final transcripts = Transcriptions.fromJson(transcriptJson);
if (context.mounted) {
await _llmPhase(context, transcripts.merged);
await _llmPhase(
context,
transcripts.merged,
transcripts.localeMode(),
);
}
} else {
log('${transcriptionResponse.statusCode} '
Expand Down Expand Up @@ -247,27 +253,49 @@ class _InteractionViewState extends State<InteractionView>
}
/* END Android native STT utilities */

Future<void> _llmPhase(BuildContext context, String prompt) async {
Future<void> _llmPhase(
BuildContext context,
String prompt,
String locale,
) async {
mainCubit?.setState(MainCubit.llmStateLabel);

final newHeartRate = heartRateCubit?.state ?? 0;
if (newHeartRate > 0) {
heartRate = newHeartRate;
}
GenerateContentResponse? response;
var targetLocale = '';
final inputLocale =
(preferencesState?.inputLocale ?? PreferencesState.inputLocaleDefault)
.replaceAll('_', '-');
final outputLocale =
preferencesState?.outputLocale ?? PreferencesState.outputLocaleDefault;
if (widget.interactionMode == InteractionMode.translateMode) {
final matchedLocale = ttsState?.matchLanguage(locale) ?? outputLocale;
if (matchedLocale.localeMatch(inputLocale)) {
targetLocale = outputLocale;
} else {
// Also covers matchedLocale == outputLocale
targetLocale = inputLocale;
}
} else {
targetLocale = inputLocale;
final newHeartRate = heartRateCubit?.state ?? 0;
if (newHeartRate > 0) {
heartRate = newHeartRate;
}

final loc = await locationCubit?.obtain();
if (loc != null &&
(loc.latitude.abs() > 10e-6 || loc.longitude.abs() > 10e-6)) {
gpsLocation = loc;
}
final loc = await locationCubit?.obtain();
if (loc != null &&
(loc.latitude.abs() > 10e-6 || loc.longitude.abs() > 10e-6)) {
gpsLocation = loc;
}

final response = await aiCubit?.chatStep(
prompt,
databaseCubit,
preferencesState,
heartRate,
gpsLocation,
);
response = await aiCubit?.chatStep(
prompt,
databaseCubit,
preferencesState,
heartRate,
gpsLocation,
);
}

debugPrint('Final: ${response?.text}');
if (response == null ||
Expand All @@ -280,14 +308,18 @@ class _InteractionViewState extends State<InteractionView>
mainCubit?.setState(MainCubit.doneStateLabel);
} else if (context.mounted) {
if (areSpeechServicesNative) {
await _playbackPhase(context, response?.text ?? '', null);
await _playbackPhase(context, response?.text ?? '', null, targetLocale);
} else {
await _ttsPhase(context, response?.text ?? '');
await _ttsPhase(context, response?.text ?? '', targetLocale);
}
}
}

Future<void> _ttsPhase(BuildContext context, String responseText) async {
Future<void> _ttsPhase(
BuildContext context,
String responseText,
String locale,
) async {
mainCubit?.setState(MainCubit.ttsStateLabel);
try {
final ttsFullUrl = Uri.https(functionUrl, ttsEndpoint, {
Expand All @@ -299,7 +331,12 @@ class _InteractionViewState extends State<InteractionView>

if (synthetizationResponse.statusCode == 200) {
if (context.mounted) {
await _playbackPhase(context, '', synthetizationResponse.bodyBytes);
await _playbackPhase(
context,
'',
synthetizationResponse.bodyBytes,
locale,
);
}
} else {
log('${synthetizationResponse.statusCode} '
Expand All @@ -316,13 +353,16 @@ class _InteractionViewState extends State<InteractionView>
BuildContext context,
String responseText,
Uint8List? audioTrack,
String locale,
) async {
mainCubit?.setState(MainCubit.playingStateLabel);
if (responseText.isNotEmpty) {
await ttsState?.speak(
responseText,
preferencesState?.volume ?? PreferencesState.volumeDefault,
);
if (await ttsState?.setLanguage(locale) ?? false) {
await ttsState?.speak(
responseText,
preferencesState?.volume ?? PreferencesState.volumeDefault,
);
}
} else if (audioTrack.isNotEmptyOrNull) {
_player ??= Player();
final memoryMedia = await Media.memory(audioTrack!);
Expand All @@ -344,7 +384,9 @@ class _InteractionViewState extends State<InteractionView>
case ActionKind.initialize:
final sttState = context.select((SttCubit cubit) => cubit.state);
areSpeechServicesNative =
preferencesState!.areSpeechServicesNative && sttState.hasSpeech;
preferencesState!.areSpeechServicesNative &&
sttState.hasSpeech &&
widget.interactionMode != InteractionMode.translateMode;
if (areSpeechServicesNative) {
ttsState = context.select((TtsCubit cubit) => cubit.state);
}
Expand Down Expand Up @@ -396,7 +438,12 @@ class _InteractionViewState extends State<InteractionView>
?.set(PreferencesState.volumeTag, deferredAction.integer);

case ActionKind.speechTranscripted:
await _llmPhase(context, deferredAction.text);
final sttState = context.select((SttCubit cubit) => cubit.state);
await _llmPhase(
context,
deferredAction.text,
sttState.systemLocale,
);
}
}
}
Expand All @@ -418,8 +465,7 @@ class _InteractionViewState extends State<InteractionView>
deferredActionQueue.add(
DeferredAction(
ActionKind.speechTranscripted,
// text: "What is part 121G on O'Reilly Auto Parts?",
text: 'SpaceX Falcon 9 rocket',
text: "What is part 121G on O'Reilly Auto Parts?",
),
);
}
Expand Down
9 changes: 3 additions & 6 deletions lib/preferences/cubit/preferences_state.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import 'dart:io';

import 'package:flutter/foundation.dart';
import 'package:pref/pref.dart';
import 'package:strings/strings.dart';

class PreferencesState {
static BasePrefService? prefService;
Expand All @@ -24,7 +25,7 @@ class PreferencesState {
static const String inputLocaleTag = 'input_locale';
static const String inputLocaleDefault = 'en_US';
static const String outputLocaleTag = 'output_locale';
static const String outputLocaleDefault = 'en';
static const String outputLocaleDefault = 'en-US';
static const String llmDebugModeTag = 'llm_debug_mode';
static const bool llmDebugModeDefault = false;
static const int pauseForDefault = 3;
Expand Down Expand Up @@ -79,11 +80,7 @@ class PreferencesState {

/// Whether the device locale suggests a non-imperial unit system by default.
///
/// Extracts the two-letter country code from [Platform.localeName] and
/// returns `true` unless it is listed in `imperialCountries`. Falls back to
/// `unitSystemDefault` when no two-letter country code can be found (for
/// example a bare language such as `en`).
static bool getUnitSystemDefault() {
  final localeName = Platform.localeName;
  // Drop an encoding or modifier suffix such as ".UTF-8" or "@latin".
  final baseName = localeName.split(RegExp('[.@]')).first;
  // Accept both "en_US" and "en-US" style separators.
  final parts = baseName.split(RegExp('[_-]'));
  if (parts.length < 2 || parts.last.length != 2) {
    return unitSystemDefault;
  }

  final deviceCountry = parts.last.toUpperCase();
  return !imperialCountries.contains(deviceCountry);
}
}
9 changes: 9 additions & 0 deletions lib/string_ex.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import 'package:strings/strings.dart';

/// Locale-comparison helpers for [String].
extension StringEx on String {
  /// Whether this locale identifier matches [other].
  ///
  /// An empty receiver never matches. Otherwise the `left(2)` parts are
  /// compared after lower-casing and the `right(2)` parts after upper-casing,
  /// so the separator between them (e.g. `_` vs `-`) does not matter for
  /// five-character identifiers such as `en_US` and `en-US`.
  bool localeMatch(String other) {
    if (isEmpty) {
      return false;
    }

    final sameLanguage = left(2).toLowerCase() == other.left(2).toLowerCase();
    if (!sameLanguage) {
      return false;
    }

    return right(2).toUpperCase() == other.right(2).toUpperCase();
  }
}
Loading

0 comments on commit 39bef0e

Please sign in to comment.