Skip to content

Commit

Permalink
enable websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
terryyin committed Oct 26, 2024
1 parent ffbf460 commit 92c28da
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 44 deletions.
1 change: 1 addition & 0 deletions backend/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ repositories {
dependencies {
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
implementation 'org.springframework.boot:spring-boot-starter-web'
implementation 'org.springframework.boot:spring-boot-starter-websocket'
implementation 'org.flywaydb:flyway-mysql'
implementation 'com.kjetland:mbknor-jackson-jsonschema_2.13:1.0.39'
compileOnly 'org.flywaydb:flyway-core'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.odde.doughnut.configs;

import com.odde.doughnut.factoryServices.ModelFactoryService;
import com.odde.doughnut.handlers.AudioWebSocketHandler;
import com.theokanning.openai.client.OpenAiApi;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

@Autowired private ModelFactoryService modelFactoryService;

@Autowired
@Qualifier("testableOpenAiApi")
private OpenAiApi openAiApi;

@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(audioWebSocketHandler(), "/ws/audio").setAllowedOrigins("*");
}

@Bean
public WebSocketHandler audioWebSocketHandler() {
return new AudioWebSocketHandler(openAiApi, modelFactoryService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@RestController
@SessionScope
@RequestMapping("/api/notes")
@RequestMapping("/api/audio")
class RestAiAudioController {

private final AiAdvisorService aiAdvisorService;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.odde.doughnut.handlers;

import com.odde.doughnut.factoryServices.ModelFactoryService;
import com.odde.doughnut.services.AiAdvisorService;
import com.odde.doughnut.services.GlobalSettingsService;
import com.odde.doughnut.services.ai.TextFromAudio;
import com.theokanning.openai.client.OpenAiApi;
import java.io.IOException;
import java.util.Optional;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.BinaryWebSocketHandler;

public class AudioWebSocketHandler extends BinaryWebSocketHandler {

private final AiAdvisorService aiAdvisorService;
private final ModelFactoryService modelFactoryService;

@Autowired
public AudioWebSocketHandler(
@Qualifier("testableOpenAiApi") OpenAiApi openAiApi,
ModelFactoryService modelFactoryService) {
this.modelFactoryService = modelFactoryService;
this.aiAdvisorService = new AiAdvisorService(openAiApi);
}

@Override
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message)
throws IOException {
byte[] audioData = message.getPayload().array();

Optional<TextFromAudio> result =
aiAdvisorService
.getOtherAiServices()
.getTextFromAudio(
getPreviousNoteDetails(session),
"stream.wav",
audioData,
getGlobalSettingsService().globalSettingOthers().getValue());

if (result.isPresent()) {
session.sendMessage(new TextMessage(result.get().getCompletionMarkdownFromAudio()));
}
}

private GlobalSettingsService getGlobalSettingsService() {
return new GlobalSettingsService(modelFactoryService);
}

private String getPreviousNoteDetails(WebSocketSession session) {
// Implement logic to get previous note details from session
return "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import static org.mockito.Mockito.*;

import com.odde.doughnut.controllers.dto.*;
import com.odde.doughnut.models.UserModel;
import com.odde.doughnut.services.ai.TextFromAudio;
import com.odde.doughnut.services.openAiApis.OpenAiApiExtended;
import com.odde.doughnut.testability.MakeMe;
Expand Down Expand Up @@ -36,15 +35,12 @@
@Transactional
class RestAiAudioControllerTests {
@Autowired MakeMe makeMe;
private UserModel userModel;
RestAiAudioController controller;
@Mock OpenAiApiExtended openAiApi;
OpenAIChatCompletionMock openAIChatCompletionMock;

@BeforeEach
void setup() {
userModel = makeMe.aUser().toModelPlease();

controller = new RestAiAudioController(openAiApi, makeMe.modelFactoryService);
TextFromAudio completionMarkdownFromAudio = new TextFromAudio();
completionMarkdownFromAudio.setCompletionMarkdownFromAudio("test123");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package com.odde.doughnut.handlers;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import com.odde.doughnut.services.ai.TextFromAudio;
import com.odde.doughnut.services.openAiApis.OpenAiApiExtended;
import com.odde.doughnut.testability.MakeMe;
import com.odde.doughnut.testability.OpenAIChatCompletionMock;
import io.reactivex.Single;
import java.io.IOException;
import okhttp3.RequestBody;
import okhttp3.ResponseBody;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

@SpringBootTest
@ActiveProfiles("test")
@Transactional
class AudioWebSocketHandlerTests {

@Autowired MakeMe makeMe;

@Mock OpenAiApiExtended openAiApi;

@Mock WebSocketSession session;

private AudioWebSocketHandler handler;
OpenAIChatCompletionMock openAIChatCompletionMock;

@BeforeEach
void setup() {
handler = new AudioWebSocketHandler(openAiApi, makeMe.modelFactoryService);
when(openAiApi.createTranscriptionSrt(any(RequestBody.class)))
.thenReturn(Single.just(ResponseBody.create("test", null)));
TextFromAudio completionMarkdownFromAudio = new TextFromAudio();
completionMarkdownFromAudio.setCompletionMarkdownFromAudio("test123");
openAIChatCompletionMock = new OpenAIChatCompletionMock(openAiApi);
openAIChatCompletionMock.mockChatCompletionAndReturnToolCall(
completionMarkdownFromAudio, "audio_transcription_to_text");
}

@Test
void handleBinaryMessage_shouldSendTextMessageWhenTextFromAudioIsPresent() throws IOException {
byte[] audioData = "test audio data".getBytes();
BinaryMessage binaryMessage = new BinaryMessage(audioData);

TextFromAudio textFromAudio = new TextFromAudio();
textFromAudio.setCompletionMarkdownFromAudio("Transcribed text");

handler.handleBinaryMessage(session, binaryMessage);

verify(session).sendMessage(any(TextMessage.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export class RestAiAudioControllerService {
): CancelablePromise<TextFromAudio> {
return this.httpRequest.request({
method: 'POST',
url: '/api/notes/audio-to-text',
url: '/api/audio/audio-to-text',
formData: formData,
mediaType: 'multipart/form-data',
errors: {
Expand Down
76 changes: 38 additions & 38 deletions open_api_docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1102,29 +1102,6 @@ paths:
type: array
items:
$ref: "#/components/schemas/NoteRealm"
/api/notes/audio-to-text:
post:
tags:
- rest-ai-audio-controller
operationId: audioToText
requestBody:
content:
multipart/form-data:
schema:
$ref: "#/components/schemas/AudioUploadDTO"
responses:
"500":
description: Internal Server Error
content:
'*/*':
schema:
type: string
"200":
description: OK
content:
'*/*':
schema:
$ref: "#/components/schemas/TextFromAudio"
/api/notebooks/{notebook}:
get:
tags:
Expand Down Expand Up @@ -1665,6 +1642,29 @@ paths:
type: array
items:
$ref: "#/components/schemas/BazaarNotebook"
/api/audio/audio-to-text:
post:
tags:
- rest-ai-audio-controller
operationId: audioToText
requestBody:
content:
multipart/form-data:
schema:
$ref: "#/components/schemas/AudioUploadDTO"
responses:
"500":
description: Internal Server Error
content:
'*/*':
schema:
type: string
"200":
description: OK
content:
'*/*':
schema:
$ref: "#/components/schemas/TextFromAudio"
/api/assessment/{assessmentQuestionInstance}/answer:
post:
tags:
Expand Down Expand Up @@ -3441,21 +3441,6 @@ components:
noteId:
type: integer
format: int32
AudioUploadDTO:
type: object
properties:
uploadAudioFile:
type: string
format: binary
previousNoteDetails:
type: string
TextFromAudio:
required:
- completionMarkdownFromAudio
type: object
properties:
completionMarkdownFromAudio:
type: string
NotebookCertificateApproval:
required:
- id
Expand Down Expand Up @@ -3598,6 +3583,21 @@ components:
format: int32
notebook:
$ref: "#/components/schemas/Notebook"
AudioUploadDTO:
type: object
properties:
uploadAudioFile:
type: string
format: binary
previousNoteDetails:
type: string
TextFromAudio:
required:
- completionMarkdownFromAudio
type: object
properties:
completionMarkdownFromAudio:
type: string
AssessmentAttempt:
required:
- id
Expand Down

0 comments on commit 92c28da

Please sign in to comment.