Skip to content

Commit

Permalink
Addressing feedback and adding test
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-buttner committed Feb 4, 2025
1 parent 7e452a3 commit c5f8ef6
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ public void testRemoveDefaultConfigs_RemovesModelsFromPersistentStorage_AndInMem
}

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
ActionListener<List<Model>> listener = invocation.getArgument(0);
listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());
Expand Down Expand Up @@ -371,8 +370,7 @@ public void testGetAllModels_WithDefaults() throws Exception {
}

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
ActionListener<List<Model>> listener = invocation.getArgument(0);
listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());
Expand Down Expand Up @@ -437,8 +435,7 @@ public void testGetAllModels_OnlyDefaults() throws Exception {
}

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
ActionListener<List<Model>> listener = invocation.getArgument(0);
listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());
Expand Down Expand Up @@ -480,8 +477,7 @@ public void testGetAllModels_withDoNotPersist() throws Exception {
}

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
ActionListener<List<Model>> listener = invocation.getArgument(0);
listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());
Expand Down Expand Up @@ -525,8 +521,7 @@ public void testGet_WithDefaults() throws InterruptedException {
);

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
ActionListener<List<Model>> listener = invocation.getArgument(0);
listener.onResponse(defaultConfigs);
return Void.TYPE;
}).when(service).defaultConfigs(any());
Expand Down Expand Up @@ -579,8 +574,7 @@ public void testGetByTaskType_WithDefaults() throws Exception {
defaultIds.add(new InferenceService.DefaultConfigId("default-chat", MinimalServiceSettings.completion(), service));

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
ActionListener<List<Model>> listener = invocation.getArgument(0);
listener.onResponse(List.of(defaultSparse, defaultChat, defaultText));
return Void.TYPE;
}).when(service).defaultConfigs(any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ private void getDefaultConfig(
}));
}

// TODO should we add a lock on the default model id so we can't attempt to delete it while we're adding it?
private void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
var responseListener = ActionListener.<Boolean>wrap(success -> {
logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.SearchResponseUtils;
Expand Down Expand Up @@ -310,6 +311,25 @@ public void testRemoveDefaultConfigs_DoesNotCallClient_WhenPassedAnEmptySet() {
verify(client, times(0)).execute(any(), any(), any());
}

public void testDeleteModels_Returns_ConflictException_WhenModelIsBeingAdded() {
var client = mockClient();

var registry = new ModelRegistry(client);
var model = TestModel.createRandomInstance();
var newModel = TestModel.createRandomInstance();
registry.updateModelTransaction(newModel, model, new PlainActionFuture<>());

var listener = new PlainActionFuture<Boolean>();

registry.deleteModels(Set.of(newModel.getInferenceEntityId()), listener);
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
exception.getMessage(),
containsString("are currently being updated, please wait until after they are finished updating to delete.")
);
assertThat(exception.status(), is(RestStatus.CONFLICT));
}

public void testIdMatchedDefault() {
var defaultConfigIds = new ArrayList<InferenceService.DefaultConfigId>();
defaultConfigIds.add(
Expand Down

0 comments on commit c5f8ef6

Please sign in to comment.