Skip to content

Commit

Permalink
feat(core): Use operationCustomizer in all operation scanners (#840)
Browse files Browse the repository at this point in the history
  • Loading branch information
timonback authored Jul 13, 2024
1 parent b57bfd2 commit ef4c28c
Show file tree
Hide file tree
Showing 11 changed files with 210 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,24 @@ public class AsyncAnnotationOperationsScanner<A extends Annotation> extends Asyn
implements OperationsScanner {

private final ClassScanner classScanner;
private final List<OperationCustomizer> customizers;

public AsyncAnnotationOperationsScanner(
AsyncAnnotationProvider<A> asyncAnnotationProvider,
ClassScanner classScanner,
ComponentsService componentsService,
PayloadAsyncOperationService payloadAsyncOperationService,
List<OperationBindingProcessor> operationBindingProcessors,
List<MessageBindingProcessor> messageBindingProcessors) {
List<MessageBindingProcessor> messageBindingProcessors,
List<OperationCustomizer> customizers) {
super(
asyncAnnotationProvider,
payloadAsyncOperationService,
componentsService,
operationBindingProcessors,
messageBindingProcessors);
this.classScanner = classScanner;
this.customizers = customizers;
}

@Override
Expand All @@ -64,7 +67,7 @@ private Map.Entry<String, Operation> buildOperation(MethodAndAnnotation<A> metho

Operation operation = buildOperation(operationAnnotation, methodAndAnnotation.method(), channelId);
operation.setAction(this.asyncAnnotationProvider.getOperationType());

customizers.forEach(customizer -> customizer.customize(operation, methodAndAnnotation.method()));
return Map.entry(operationId, operation);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
Expand All @@ -30,14 +31,17 @@ public class SpringAnnotationClassLevelOperationsScanner<
extends ClassLevelAnnotationScanner<ClassAnnotation, MethodAnnotation>
implements SpringAnnotationOperationsScannerDelegator {

private final List<OperationCustomizer> customizers;

public SpringAnnotationClassLevelOperationsScanner(
Class<ClassAnnotation> classAnnotationClass,
Class<MethodAnnotation> methodAnnotationClass,
BindingFactory<ClassAnnotation> bindingFactory,
AsyncHeadersBuilder asyncHeadersBuilder,
PayloadMethodService payloadMethodService,
HeaderClassExtractor headerClassExtractor,
ComponentsService componentsService) {
ComponentsService componentsService,
List<OperationCustomizer> customizers) {
super(
classAnnotationClass,
methodAnnotationClass,
Expand All @@ -46,6 +50,7 @@ public SpringAnnotationClassLevelOperationsScanner(
payloadMethodService,
headerClassExtractor,
componentsService);
this.customizers = customizers;
}

@Override
Expand All @@ -71,6 +76,7 @@ private Stream<Map.Entry<String, Operation>> mapClassToOperation(Class<?> compon
"_", ReferenceUtil.toValidId(channelName), OperationAction.RECEIVE, component.getSimpleName());

Operation operation = buildOperation(classAnnotation, annotatedMethods);
annotatedMethods.forEach(method -> customizers.forEach(customizer -> customizer.customize(operation, method)));

return Stream.of(Map.entry(operationId, operation));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.github.springwolf.core.asyncapi.scanners.common.AsyncAnnotationScanner;
import io.github.springwolf.core.asyncapi.scanners.common.payload.PayloadAsyncOperationService;
import io.github.springwolf.core.asyncapi.scanners.operations.annotations.AsyncAnnotationOperationsScanner;
import io.github.springwolf.core.asyncapi.scanners.operations.annotations.OperationCustomizer;
import io.github.springwolf.core.configuration.docket.AsyncApiDocketService;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
Expand Down Expand Up @@ -97,14 +98,16 @@ public AsyncAnnotationOperationsScanner<AsyncListener> asyncListenerAnnotationOp
ComponentsService componentsService,
PayloadAsyncOperationService payloadService,
List<OperationBindingProcessor> operationBindingProcessors,
List<MessageBindingProcessor> messageBindingProcessors) {
List<MessageBindingProcessor> messageBindingProcessors,
List<OperationCustomizer> operationCustomizers) {
return new AsyncAnnotationOperationsScanner<>(
buildAsyncListenerAnnotationProvider(),
springwolfClassScanner,
componentsService,
payloadService,
operationBindingProcessors,
messageBindingProcessors);
messageBindingProcessors,
operationCustomizers);
}

@Bean
Expand Down Expand Up @@ -141,14 +144,16 @@ public AsyncAnnotationOperationsScanner<AsyncPublisher> asyncPublisherOperationA
ComponentsService componentsService,
PayloadAsyncOperationService payloadService,
List<OperationBindingProcessor> operationBindingProcessors,
List<MessageBindingProcessor> messageBindingProcessors) {
List<MessageBindingProcessor> messageBindingProcessors,
List<OperationCustomizer> customizers) {
return new AsyncAnnotationOperationsScanner<>(
buildAsyncPublisherAnnotationProvider(),
springwolfClassScanner,
componentsService,
payloadService,
operationBindingProcessors,
messageBindingProcessors);
messageBindingProcessors,
customizers);
}

private static AsyncAnnotationScanner.AsyncAnnotationProvider<AsyncListener>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class AsyncAnnotationOperationsScannerTest {
Expand Down Expand Up @@ -95,6 +96,7 @@ public OperationAction getOperationType() {
private final List<OperationBindingProcessor> operationBindingProcessors =
List.of(new TestOperationBindingProcessor());
private final List<MessageBindingProcessor> messageBindingProcessors = emptyList();
private final OperationCustomizer operationCustomizer = mock(OperationCustomizer.class);

private final StringValueResolver stringValueResolver = mock(StringValueResolver.class);

Expand All @@ -105,7 +107,8 @@ public OperationAction getOperationType() {
componentsService,
payloadAsyncOperationService,
operationBindingProcessors,
messageBindingProcessors);
messageBindingProcessors,
List.of(operationCustomizer));

@BeforeEach
public void setup() {
Expand Down Expand Up @@ -336,6 +339,18 @@ void scan_componentHasAsyncMethodAnnotationInAbstractClass() {
.containsExactly(Map.entry("abstract-test-channel_send_methodWithAnnotation", expectedOperation));
}

@Test
void operationCustomizerIsCalled() {
// given
setClassToScan(ClassWithListenerAnnotation.class);

// when
operationsScanner.scan();

// then
verify(operationCustomizer).customize(any(), any());
}

private static class ClassWithoutListenerAnnotation {

private void methodWithoutAnnotation() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import io.github.springwolf.core.asyncapi.scanners.common.headers.HeaderClassExtractor;
import io.github.springwolf.core.asyncapi.scanners.common.payload.NamedSchemaObject;
import io.github.springwolf.core.asyncapi.scanners.common.payload.PayloadMethodService;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

Expand All @@ -38,6 +36,7 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class SpringAnnotationClassLevelOperationsScannerTest {
Expand All @@ -46,6 +45,7 @@ class SpringAnnotationClassLevelOperationsScannerTest {
private final HeaderClassExtractor headerClassExtractor = mock(HeaderClassExtractor.class);
private final BindingFactory<TestClassListener> bindingFactory = mock(BindingFactory.class);
private final ComponentsService componentsService = mock(ComponentsService.class);
private final OperationCustomizer operationCustomizer = mock(OperationCustomizer.class);
SpringAnnotationClassLevelOperationsScanner<TestClassListener, TestMethodListener> scanner =
new SpringAnnotationClassLevelOperationsScanner<>(
TestClassListener.class,
Expand All @@ -54,7 +54,8 @@ class SpringAnnotationClassLevelOperationsScannerTest {
new AsyncHeadersNotDocumented(),
payloadMethodService,
headerClassExtractor,
componentsService);
componentsService,
List.of(operationCustomizer));

private static final String CHANNEL_ID = "test-channel-id";
private static final Map<String, OperationBinding> defaultOperationBinding =
Expand Down Expand Up @@ -111,6 +112,15 @@ void scan_componentHasTestListenerMethods() {
assertThat(operations).containsExactly(Map.entry(operationName, expectedOperation));
}

@Test
void operationCustomizerIsCalled() {
// when
scanner.scan(ClassWithTestListenerAnnotation.class).toList();

// then
verify(operationCustomizer).customize(any(), any());
}

@TestClassListener
private static class ClassWithTestListenerAnnotation {
@TestMethodListener
Expand All @@ -119,13 +129,6 @@ private void methodWithAnnotation(String payload) {}
private void methodWithoutAnnotation() {}
}

@Data
@NoArgsConstructor
private static class SimpleFoo {
private String s;
private boolean b;
}

@Retention(RetentionPolicy.RUNTIME)
@interface TestClassListener {}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// SPDX-License-Identifier: Apache-2.0
package io.github.springwolf.core.asyncapi.scanners.operations.annotations;

import io.github.springwolf.asyncapi.v3.bindings.ChannelBinding;
import io.github.springwolf.asyncapi.v3.bindings.MessageBinding;
import io.github.springwolf.asyncapi.v3.bindings.OperationBinding;
import io.github.springwolf.asyncapi.v3.bindings.amqp.AMQPChannelBinding;
import io.github.springwolf.asyncapi.v3.bindings.amqp.AMQPMessageBinding;
import io.github.springwolf.asyncapi.v3.bindings.amqp.AMQPOperationBinding;
import io.github.springwolf.asyncapi.v3.model.channel.ChannelReference;
import io.github.springwolf.asyncapi.v3.model.channel.message.MessageHeaders;
import io.github.springwolf.asyncapi.v3.model.channel.message.MessageObject;
import io.github.springwolf.asyncapi.v3.model.channel.message.MessagePayload;
import io.github.springwolf.asyncapi.v3.model.channel.message.MessageReference;
import io.github.springwolf.asyncapi.v3.model.operation.Operation;
import io.github.springwolf.asyncapi.v3.model.operation.OperationAction;
import io.github.springwolf.asyncapi.v3.model.schema.MultiFormatSchema;
import io.github.springwolf.asyncapi.v3.model.schema.SchemaObject;
import io.github.springwolf.asyncapi.v3.model.schema.SchemaReference;
import io.github.springwolf.core.asyncapi.components.ComponentsService;
import io.github.springwolf.core.asyncapi.scanners.bindings.BindingFactory;
import io.github.springwolf.core.asyncapi.scanners.common.headers.AsyncHeadersNotDocumented;
import io.github.springwolf.core.asyncapi.scanners.common.headers.HeaderClassExtractor;
import io.github.springwolf.core.asyncapi.scanners.common.payload.NamedSchemaObject;
import io.github.springwolf.core.asyncapi.scanners.common.payload.PayloadMethodParameterService;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class SpringAnnotationMethodLevelOperationsScannerTest {

private final PayloadMethodParameterService payloadMethodParameterService = mock();
private final HeaderClassExtractor headerClassExtractor = mock(HeaderClassExtractor.class);
private final BindingFactory<TestMethodListener> bindingFactory = mock(BindingFactory.class);
private final ComponentsService componentsService = mock(ComponentsService.class);
private final OperationCustomizer operationCustomizer = mock(OperationCustomizer.class);
SpringAnnotationMethodLevelOperationsScanner<TestMethodListener> scanner =
new SpringAnnotationMethodLevelOperationsScanner<>(
TestMethodListener.class,
bindingFactory,
new AsyncHeadersNotDocumented(),
List.of(operationCustomizer),
payloadMethodParameterService,
headerClassExtractor,
componentsService);

private static final String CHANNEL_ID = "test-channel-id";
private static final Map<String, OperationBinding> defaultOperationBinding =
Map.of("protocol", new AMQPOperationBinding());
private static final Map<String, MessageBinding> defaultMessageBinding =
Map.of("protocol", new AMQPMessageBinding());
private static final Map<String, ChannelBinding> defaultChannelBinding =
Map.of("protocol", new AMQPChannelBinding());

@BeforeEach
void setUp() {
// when
when(bindingFactory.getChannelName(any())).thenReturn(CHANNEL_ID);

doReturn(defaultOperationBinding).when(bindingFactory).buildOperationBinding(any());
doReturn(defaultChannelBinding).when(bindingFactory).buildChannelBinding(any());
doReturn(defaultMessageBinding).when(bindingFactory).buildMessageBinding(any(), any());

when(payloadMethodParameterService.extractSchema(any()))
.thenReturn(new NamedSchemaObject(String.class.getName(), new SchemaObject()));
doAnswer(invocation -> AsyncHeadersNotDocumented.NOT_DOCUMENTED.getTitle())
.when(componentsService)
.registerSchema(any(SchemaObject.class));
}

@Test
void scan_componentHasTestListenerMethods() {
// when
List<Map.Entry<String, Operation>> operations =
scanner.scan(ClassWithTestListenerAnnotation.class).toList();

// then
MessagePayload payload = MessagePayload.of(MultiFormatSchema.builder()
.schema(SchemaReference.fromSchema(String.class.getSimpleName()))
.build());

MessageObject message = MessageObject.builder()
.messageId(String.class.getName())
.name(String.class.getName())
.title(String.class.getSimpleName())
.payload(payload)
.headers(MessageHeaders.of(
MessageReference.toSchema(AsyncHeadersNotDocumented.NOT_DOCUMENTED.getTitle())))
.bindings(defaultMessageBinding)
.build();

Operation expectedOperation = Operation.builder()
.action(OperationAction.RECEIVE)
.channel(ChannelReference.fromChannel(CHANNEL_ID))
.messages(List.of(MessageReference.toChannelMessage(CHANNEL_ID, message)))
.bindings(Map.of("protocol", AMQPOperationBinding.builder().build()))
.build();
String operationName = CHANNEL_ID + "_receive_methodWithAnnotation";
assertThat(operations).containsExactly(Map.entry(operationName, expectedOperation));
}

@Test
void operationCustomizerIsCalled() {
// when
scanner.scan(ClassWithTestListenerAnnotation.class).toList();

// then
verify(operationCustomizer).customize(any(), any());
}

private static class ClassWithTestListenerAnnotation {
@TestMethodListener
private void methodWithAnnotation(String payload) {}

private void methodWithoutAnnotation() {}
}

@Retention(RetentionPolicy.RUNTIME)
@interface TestMethodListener {}
}
Loading

0 comments on commit ef4c28c

Please sign in to comment.