Skip to content

Commit

Permalink
Refactoring, code cleanup, razor UI progress, reorganize unit tests (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
BeepBeepBopBop authored Jan 11, 2025
1 parent 370d2f5 commit 29099ba
Show file tree
Hide file tree
Showing 35 changed files with 606 additions and 562 deletions.
20 changes: 4 additions & 16 deletions LM-Kit-Maestro/AppShell.xaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,9 @@
>
<!-- WinUI TitleBar issue: https://stackoverflow.com/questions/78200704/net-maui-flyout-menu-is-overlapping-the-windows-title-bar-->

<!--<FlyoutItem Route="HomePage" Title="Home">
<ShellContent ContentTemplate="{DataTemplate ui:HomePage}"/>
</FlyoutItem>-->

<!--<Tab BindingContext="{Binding HomeTab}" Title="{Binding HomeTab.Title}" Route="HomePage">
<ShellContent Title="Home" ContentTemplate="{DataTemplate ui:HomePage}"/>
</Tab>-->

<!--<Tab BindingContext="{Binding AssistantsTab}" Title="{Binding AssistantsTab.Title}" Route="AssistantsPage" x:DataType="vm:MaestroTabViewModel">
<ShellContent ContentTemplate="{DataTemplate ui:AssistantsPage}"/>
</Tab>-->
<!-- <Tab BindingContext="{Binding AssistantsTab}" Title="{Binding AssistantsTab.Title}" Route="AssistantsPage" x:DataType="vm:MaestroTabViewModel"> -->
<!-- <ShellContent ContentTemplate="{DataTemplate ui:AssistantsPage}"/> -->
<!-- </Tab> -->

<Tab BindingContext="{Binding ChatTab}" Title="{Binding ChatTab.Title}" Route="ChatPage" x:DataType="vm:MaestroTabViewModel">
<ShellContent Title="AI Chat" ContentTemplate="{DataTemplate ui:ChatPage}"/>
Expand All @@ -37,12 +29,8 @@
<ShellContent Title="Models" ContentTemplate="{DataTemplate ui:ModelsPage}"/>
</Tab>

<!--<FlyoutItem Route="ModelsPage" Title="Models">
<ShellContent ContentTemplate="{DataTemplate ui:ModelsPage}"/>
</FlyoutItem>-->

<shell:SimpleShell.RootPageContainer>
<Grid x:Name="rootPageContainer">
<Grid>
<shell:SimpleNavigationHost/>
</Grid>
</shell:SimpleShell.RootPageContainer>
Expand Down
78 changes: 39 additions & 39 deletions LM-Kit-Maestro/Services/LMKitService.Chat.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
using LMKit.TextGeneration;
using LMKit.TextGeneration.Chat;
using LMKit.Translation;
using System.ComponentModel;
using System.Diagnostics;

namespace LMKit.Maestro.Services;

public partial class LMKitService : INotifyPropertyChanged
public partial class LMKitService
{
public partial class LMKitChat : INotifyPropertyChanged
{
Expand All @@ -16,27 +14,27 @@ public partial class LMKitChat : INotifyPropertyChanged

public event PropertyChangedEventHandler? PropertyChanged;

private readonly LMKitConfig _config;
private readonly LMKitServiceState _state;

public LMKitChat(LMKitConfig config, SemaphoreSlim lmKitServiceSemaphore)
private MultiTurnConversation? _multiTurnConversation;
private Conversation? _lastConversationUsed;

public LMKitChat(LMKitServiceState state)
{
_config = config;
_lmKitServiceSemaphore = lmKitServiceSemaphore;
_state = state;
}

public async Task<LMKitResult> SubmitPrompt(Conversation conversation, string prompt)
{
var promptRequest = new LMKitRequest(conversation, LMKitRequest.LMKitRequestType.Prompt, prompt, _config.RequestTimeout)
{
Conversation = conversation
};
var promptRequest = new ChatRequest(conversation, ChatRequest.ChatRequestType.Prompt,
prompt, _state.Config.RequestTimeout);

ScheduleRequest(promptRequest);

return await HandlePrompt(promptRequest);
}

public void CancelAllPrompts()
public void TerminateChatService()
{
if (_requestSchedule.Count > 0)
{
Expand All @@ -50,7 +48,6 @@ public void CancelAllPrompts()
_requestSchedule.Next!.CancelAndAwaitTermination();
}
}


if (_titleGenerationSchedule.RunningPromptRequest != null && !_titleGenerationSchedule.RunningPromptRequest.CancellationTokenSource.IsCancellationRequested)
{
Expand All @@ -60,6 +57,12 @@ public void CancelAllPrompts()
{
_titleGenerationSchedule.Next!.CancelAndAwaitTermination();
}

if (_multiTurnConversation != null)
{
_multiTurnConversation.Dispose();
_multiTurnConversation = null;
}
}

public async Task CancelPrompt(Conversation conversation, bool shouldAwaitTermination = false)
Expand All @@ -83,19 +86,16 @@ public async Task CancelPrompt(Conversation conversation, bool shouldAwaitTermin

public async Task<LMKitResult> RegenerateResponse(Conversation conversation, ChatHistory.Message message)
{
var regenerateResponseRequest = new LMKitRequest(conversation,
LMKitRequest.LMKitRequestType.RegenerateResponse,
message, _config.RequestTimeout)
{
Conversation = conversation
};
var regenerateResponseRequest = new ChatRequest(conversation,
ChatRequest.ChatRequestType.RegenerateResponse,
message, _state.Config.RequestTimeout);

ScheduleRequest(regenerateResponseRequest);

return await HandlePrompt(regenerateResponseRequest);
}

private void ScheduleRequest(LMKitRequest request)
private void ScheduleRequest(ChatRequest request)
{
_requestSchedule.Schedule(request);

Expand All @@ -105,7 +105,7 @@ private void ScheduleRequest(LMKitRequest request)
}
}

private async Task<LMKitResult> HandlePrompt(LMKitRequest request)
private async Task<LMKitResult> HandlePrompt(ChatRequest request)
{
LMKitResult result;

Expand Down Expand Up @@ -157,7 +157,7 @@ private async Task<LMKitResult> HandlePrompt(LMKitRequest request)
return result;
}

private async Task<LMKitResult> SubmitPrompt(LMKitRequest request)
private async Task<LMKitResult> SubmitPrompt(ChatRequest request)
{
try
{
Expand All @@ -168,11 +168,11 @@ private async Task<LMKitResult> SubmitPrompt(LMKitRequest request)

try
{
if (request.RequestType == LMKitRequest.LMKitRequestType.Prompt)
if (request.RequestType == ChatRequest.ChatRequestType.Prompt)
{
result.Result = await _multiTurnConversation!.SubmitAsync((string)request.Parameters!, request.CancellationTokenSource.Token);
}
else if (request.RequestType == LMKitRequest.LMKitRequestType.RegenerateResponse)
else if (request.RequestType == ChatRequest.ChatRequestType.RegenerateResponse)
{
result.Result = await _multiTurnConversation!.RegenerateResponseAsync(request.CancellationTokenSource.Token);
}
Expand Down Expand Up @@ -220,7 +220,7 @@ private async Task<LMKitResult> SubmitPrompt(LMKitRequest request)
private void GenerateConversationSummaryTitle(Conversation conversation)
{
string firstMessage = conversation.ChatHistory!.Messages.First(message => message.AuthorRole == AuthorRole.User).Content;
LMKitRequest titleGenerationRequest = new LMKitRequest(conversation, LMKitRequest.LMKitRequestType.GenerateTitle, firstMessage, 60);
ChatRequest titleGenerationRequest = new ChatRequest(conversation, ChatRequest.ChatRequestType.GenerateTitle, firstMessage, 60);

_titleGenerationSchedule.Schedule(titleGenerationRequest);

Expand All @@ -233,7 +233,7 @@ private void GenerateConversationSummaryTitle(Conversation conversation)

Task.Run(async () =>
{
Summarizer summarizer = new Summarizer(_model)
Summarizer summarizer = new Summarizer(_state.LoadedModel)
{
MaximumContextLength = 512,
GenerateContent = false,
Expand Down Expand Up @@ -276,42 +276,42 @@ private void BeforeSubmittingPrompt(Conversation conversation)
}

// Latest chat history of this conversation was generated with a different model
bool lastUsedDifferentModel = _config.LoadedModelUri != conversation.LastUsedModelUri;
bool lastUsedDifferentModel = _state.Config.LoadedModelUri != conversation.LastUsedModelUri;
bool shouldUseCurrentChatHistory = !lastUsedDifferentModel && conversation.ChatHistory != null;
bool shouldDeserializeChatHistoryData = (lastUsedDifferentModel && conversation.LatestChatHistoryData != null) || (!lastUsedDifferentModel && conversation.ChatHistory == null);

if (shouldUseCurrentChatHistory || shouldDeserializeChatHistoryData)
{
ChatHistory? chatHistory = shouldUseCurrentChatHistory ? conversation.ChatHistory : ChatHistory.Deserialize(conversation.LatestChatHistoryData, _model);
ChatHistory? chatHistory = shouldUseCurrentChatHistory ? conversation.ChatHistory : ChatHistory.Deserialize(conversation.LatestChatHistoryData, _state.LoadedModel);

_multiTurnConversation = new MultiTurnConversation(_model, chatHistory, _config.ContextSize)
_multiTurnConversation = new MultiTurnConversation(_state.LoadedModel, chatHistory, _state.Config.ContextSize)
{
SamplingMode = GetTokenSampling(_config),
MaximumCompletionTokens = _config.MaximumCompletionTokens,
SamplingMode = GetTokenSampling(_state.Config),
MaximumCompletionTokens = _state.Config.MaximumCompletionTokens,
};
}
else
{
_multiTurnConversation = new MultiTurnConversation(_model, _config.ContextSize)
_multiTurnConversation = new MultiTurnConversation(_state.LoadedModel, _state.Config.ContextSize)
{
SamplingMode = GetTokenSampling(_config),
MaximumCompletionTokens = _config.MaximumCompletionTokens,
SystemPrompt = _config.SystemPrompt
SamplingMode = GetTokenSampling(_state.Config),
MaximumCompletionTokens = _state.Config.MaximumCompletionTokens,
SystemPrompt = _state.Config.SystemPrompt
};
}
_multiTurnConversation.AfterTokenSampling += conversation.AfterTokenSampling;

conversation.ChatHistory = _multiTurnConversation.ChatHistory;
conversation.LastUsedModelUri = _config.LoadedModelUri;
conversation.LastUsedModelUri = _state.Config.LoadedModelUri;
_lastConversationUsed = conversation;
}
else //updating sampling options, if any.
{
//todo: Implement a mechanism to determine whether SamplingMode and MaximumCompletionTokens need to be updated.
_multiTurnConversation!.SamplingMode = GetTokenSampling(_config);
_multiTurnConversation.MaximumCompletionTokens = _config.MaximumCompletionTokens;
_multiTurnConversation!.SamplingMode = GetTokenSampling(_state.Config);
_multiTurnConversation.MaximumCompletionTokens = _state.Config.MaximumCompletionTokens;

if (_config.ContextSize != _multiTurnConversation.ContextSize)
if (_state.Config.ContextSize != _multiTurnConversation.ContextSize)
{
//todo: implement context size update.
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

public partial class LMKitService
{
private sealed partial class LMKitRequest
private sealed partial class ChatRequest
{
public enum LMKitRequestType
public enum ChatRequestType
{
Prompt,
RegenerateResponse,
Expand Down

This file was deleted.

20 changes: 20 additions & 0 deletions LM-Kit-Maestro/Services/LMKitService.LMKitServiceState.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using LMKit.Model;
using System.ComponentModel;

namespace LMKit.Maestro.Services;

public partial class LMKitService
{
    /// <summary>
    /// Shared mutable state for the <see cref="LMKitService"/> partial classes,
    /// grouping the configuration, the loaded model, and the service-wide
    /// semaphore so they can be injected as a single dependency
    /// (e.g. <c>LMKitChat</c> receives this in its constructor).
    /// </summary>
    public partial class LMKitServiceState
    {
        /// <summary>
        /// Generation/service settings (sampling options, context size,
        /// request timeout, system prompt, loaded-model URI).
        /// </summary>
        public LMKitConfig Config { get; } = new LMKitConfig();

        /// <summary>
        /// Single-slot semaphore; presumably serializes model/inference
        /// operations across the service's partial classes — TODO confirm
        /// all call sites acquire it consistently.
        /// </summary>
        public SemaphoreSlim Semaphore { get; } = new SemaphoreSlim(1);

        /// <summary>
        /// The currently loaded model instance (used e.g. for
        /// <c>MultiTurnConversation</c> and <c>Summarizer</c> creation);
        /// null while no model is loaded.
        /// </summary>
        public LM? LoadedModel { get; set; }

        /// <summary>
        /// Source URI of the loaded model; null while no model is loaded.
        /// NOTE(review): <c>Config.LoadedModelUri</c> is also read elsewhere —
        /// verify the two are kept in sync or consolidate.
        /// </summary>
        public Uri? LoadedModelUri { get; set; }

        /// <summary>Current phase of the model-loading lifecycle.</summary>
        public LMKitModelLoadingState ModelLoadingState { get; set; }
    }
}
10 changes: 5 additions & 5 deletions LM-Kit-Maestro/Services/LMKitService.LmKitRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

namespace LMKit.Maestro.Services;

public partial class LMKitService : INotifyPropertyChanged
public partial class LMKitService
{
private sealed partial class LMKitRequest
private sealed partial class ChatRequest
{
public ManualResetEvent CanBeExecutedSignal { get; } = new ManualResetEvent(false);
public CancellationTokenSource CancellationTokenSource { get; }
public TaskCompletionSource<LMKitResult> ResponseTask { get; } = new TaskCompletionSource<LMKitResult>();
public object? Parameters { get; }
public Conversation Conversation { get; set; }
public Conversation Conversation { get; }

public LMKitRequestType RequestType { get; }
public ChatRequestType RequestType { get; }

public LMKitRequest(Conversation conversation, LMKitRequestType requestType, object? parameter, int requestTimeout)
public ChatRequest(Conversation conversation, ChatRequestType requestType, object? parameter, int requestTimeout)
{
Conversation = conversation;
RequestType = requestType;
Expand Down
20 changes: 10 additions & 10 deletions LM-Kit-Maestro/Services/LMKitService.RequestSchedule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

namespace LMKit.Maestro.Services;

public partial class LMKitService : INotifyPropertyChanged
public partial class LMKitService
{
private sealed class RequestSchedule
{
private readonly object _locker = new object();

private List<LMKitRequest> _scheduledPrompts = new List<LMKitRequest>();
private List<ChatRequest> _scheduledPrompts = new List<ChatRequest>();

public int Count
{
Expand All @@ -21,7 +21,7 @@ public int Count
}
}

public LMKitRequest? Next
public ChatRequest? Next
{
get
{
Expand All @@ -39,9 +39,9 @@ public LMKitRequest? Next
}
}

public LMKitRequest? RunningPromptRequest { get; set; }
public ChatRequest? RunningPromptRequest { get; set; }

public void Schedule(LMKitRequest promptRequest)
public void Schedule(ChatRequest promptRequest)
{
lock (_locker)
{
Expand All @@ -54,25 +54,25 @@ public void Schedule(LMKitRequest promptRequest)
}
}

public bool Contains(LMKitRequest scheduledPrompt)
public bool Contains(ChatRequest scheduledPrompt)
{
lock (_locker)
{
return _scheduledPrompts.Contains(scheduledPrompt);
}
}

public void Remove(LMKitRequest scheduledPrompt)
public void Remove(ChatRequest scheduledPrompt)
{
lock (_locker)
{
HandleScheduledPromptRemoval(scheduledPrompt);
}
}

public LMKitRequest? Unschedule(Conversation conversation)
public ChatRequest? Unschedule(Conversation conversation)
{
LMKitRequest? prompt = null;
ChatRequest? prompt = null;

lock (_locker)
{
Expand All @@ -94,7 +94,7 @@ public void Remove(LMKitRequest scheduledPrompt)
return prompt;
}

private void HandleScheduledPromptRemoval(LMKitRequest scheduledPrompt)
private void HandleScheduledPromptRemoval(ChatRequest scheduledPrompt)
{
bool wasFirstInLine = scheduledPrompt == _scheduledPrompts[0];

Expand Down
Loading

0 comments on commit 29099ba

Please sign in to comment.