Skip to content

Commit

Permalink
revise: stomp session
Browse files Browse the repository at this point in the history
  • Loading branch information
geneaky committed Mar 20, 2024
1 parent dc010e2 commit e6c8609
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 148 deletions.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,41 +1,23 @@
package toy.bookchat.bookchat.config.websocket;

import java.util.List;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.security.config.annotation.web.socket.AbstractSecurityWebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketTransportRegistration;
import toy.bookchat.bookchat.domain.participant.repository.ParticipantRepository;

@EnableWebSocketMessageBroker
@Configuration(proxyBeanMethods = false)
public class MessageBrokerSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {

private final ChannelInterceptor webSocketTokenValidationInterceptor;
private final MessageAuthenticationArgumentResolver messageAuthenticationArgumentResolver;
private final ExternalBrokerProperties externalBrokerProperties;
private final ParticipantRepository participantRepository;

public MessageBrokerSecurityConfig(ChannelInterceptor webSocketTokenValidationInterceptor,
MessageAuthenticationArgumentResolver messageAuthenticationArgumentResolver,
ExternalBrokerProperties externalBrokerProperties,
ParticipantRepository participantRepository) {
public MessageBrokerSecurityConfig(ChannelInterceptor webSocketTokenValidationInterceptor, ExternalBrokerProperties externalBrokerProperties) {
this.webSocketTokenValidationInterceptor = webSocketTokenValidationInterceptor;
this.messageAuthenticationArgumentResolver = messageAuthenticationArgumentResolver;
this.externalBrokerProperties = externalBrokerProperties;
this.participantRepository = participantRepository;
}

@Override
public void configureWebSocketTransport(WebSocketTransportRegistration registration) {
registration.addDecoratorFactory(
handler -> new CustomWebSocketHandlerDecorator(participantRepository, handler));
super.configureWebSocketTransport(registration);
}

@Override
Expand All @@ -45,11 +27,6 @@ protected void customizeClientInboundChannel(ChannelRegistration registration) {
registration.taskExecutor().maxPoolSize(10);
}

@Override
public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
argumentResolvers.add(messageAuthenticationArgumentResolver);
}

@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/stomp-connection").setAllowedOrigins("*");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package toy.bookchat.bookchat.config.websocket;

import lombok.extern.slf4j.Slf4j;
import org.springframework.context.event.EventListener;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
import org.springframework.web.socket.messaging.SessionDisconnectEvent;
import org.springframework.web.socket.messaging.SessionSubscribeEvent;
import toy.bookchat.bookchat.domain.participant.repository.ParticipantRepository;
import toy.bookchat.bookchat.security.user.UserPrincipal;

@Slf4j
@Component
public class StompEventListener {

public static final int TOPIC_NAME_LENGTH = 7;

private final ParticipantRepository participantRepository;

public StompEventListener(ParticipantRepository participantRepository) {
this.participantRepository = participantRepository;
}

@EventListener
public void handleStompConnectEvent(SessionConnectedEvent event) {
UsernamePasswordAuthenticationToken user = (UsernamePasswordAuthenticationToken) event.getUser();
UserPrincipal userPrincipal = (UserPrincipal) user.getPrincipal();
log.info("Stomp Connect Event :: {}", userPrincipal.getUsername());
}

@EventListener
public void handleStompDisconnectEvent(SessionDisconnectEvent event) {
UsernamePasswordAuthenticationToken user = (UsernamePasswordAuthenticationToken) event.getUser();
UserPrincipal userPrincipal = (UserPrincipal) user.getPrincipal();
participantRepository.disconnectAllByUserId(userPrincipal.getUserId());
log.info("Stomp Disconnect Event :: {}", userPrincipal.getUsername());
}

@EventListener
public void handleStompSubscribeEvent(SessionSubscribeEvent event) {
UsernamePasswordAuthenticationToken user = (UsernamePasswordAuthenticationToken) event.getUser();
UserPrincipal userPrincipal = (UserPrincipal) user.getPrincipal();
StompHeaderAccessor accessor = StompHeaderAccessor.wrap(event.getMessage());
participantRepository.connect(userPrincipal.getUserId(), accessor.getDestination().substring(TOPIC_NAME_LENGTH));
log.info("Stomp Subscribe Event :: {}", userPrincipal.getUsername());
}

@EventListener
public void handleStompUnsubscribeEvent(SessionSubscribeEvent event) {
UsernamePasswordAuthenticationToken user = (UsernamePasswordAuthenticationToken) event.getUser();
UserPrincipal userPrincipal = (UserPrincipal) user.getPrincipal();
String destination = StompHeaderAccessor.wrap(event.getMessage()).getDestination().substring(TOPIC_NAME_LENGTH);
participantRepository.disconnect(userPrincipal.getUserId(), destination);
log.info("Stomp Unsubscribe Event :: {}", userPrincipal.getUsername());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

import static org.springframework.http.HttpHeaders.AUTHORIZATION;
import static org.springframework.messaging.simp.stomp.StompCommand.CONNECT;
import static org.springframework.messaging.simp.stomp.StompCommand.SEND;
import static org.springframework.messaging.simp.stomp.StompCommand.SUBSCRIBE;
import static org.springframework.messaging.simp.stomp.StompCommand.UNSUBSCRIBE;

import java.util.List;
import org.springframework.messaging.Message;
Expand All @@ -13,46 +10,25 @@
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.stereotype.Component;
import toy.bookchat.bookchat.domain.participant.repository.ParticipantRepository;
import toy.bookchat.bookchat.security.token.jwt.JwtTokenManager;
import toy.bookchat.bookchat.security.user.TokenPayload;

@Component
public class WebSocketTokenValidationInterceptor implements ChannelInterceptor {

public static final int TOPIC_NAME_LENGTH = 7;
private final JwtTokenManager jwtTokenManager;
private final ParticipantRepository participantRepository;

public WebSocketTokenValidationInterceptor(JwtTokenManager jwtTokenManager,
ParticipantRepository participantRepository) {
public WebSocketTokenValidationInterceptor(JwtTokenManager jwtTokenManager) {
this.jwtTokenManager = jwtTokenManager;
this.participantRepository = participantRepository;
}

@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message);

if (CONNECT.equals(accessor.getCommand()) || SUBSCRIBE.equals(
accessor.getCommand()) || SEND.equals(accessor.getCommand())) {
if (CONNECT.equals(accessor.getCommand())) {
try {
String bearerToken = jwtTokenManager.extractTokenFromAuthorizationHeader(
getAuthorizationHeader(accessor));
TokenPayload payload = jwtTokenManager.getTokenPayloadFromToken(
bearerToken);

if (SUBSCRIBE.equals(accessor.getCommand())) {
Long userId = payload.getUserId();
String destination = accessor.getDestination().substring(TOPIC_NAME_LENGTH);
participantRepository.connect(userId, destination);
}

if (UNSUBSCRIBE.equals(accessor.getCommand())) {
Long userId = payload.getUserId();
String destination = accessor.getDestination().substring(TOPIC_NAME_LENGTH);
participantRepository.disconnect(userId, destination);
}
String bearerToken = jwtTokenManager.extractTokenFromAuthorizationHeader(getAuthorizationHeader(accessor));
jwtTokenManager.getTokenPayloadFromToken(bearerToken);
} catch (Exception exception) {
throw new MessageDeliveryException("Access Denied");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ public ChatController(ChatService chatService) {
}

@MessageMapping("/send/chatrooms/{roomId}")
public void sendMessage(@Valid MessageDto messageDto, @UserPayload TokenPayload tokenPayload,
@DestinationVariable Long roomId) {
public void sendMessage(@Valid MessageDto messageDto, @DestinationVariable Long roomId, @UserPayload TokenPayload tokenPayload) {
chatService.sendMessage(tokenPayload.getUserId(), roomId, messageDto);
}

@GetMapping("/v1/api/chatrooms/{roomId}/chats")
public ChatRoomChatsResponse getChatRoomChats(@PathVariable Long roomId, Long postCursorId,
Pageable pageable, @UserPayload TokenPayload tokenPayload) {
return chatService.getChatRoomChats(roomId, postCursorId, pageable,
tokenPayload.getUserId());
return chatService.getChatRoomChats(roomId, postCursorId, pageable, tokenPayload.getUserId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Optional<Participant> findWithPessimisticLockByUserIdAndChatRoomId(Long userId,
Long countSubHostByRoomId(Long roomId);

@Transactional
void disconnectAll(String name);
void disconnectAllByUserId(Long userId);

@Transactional
void connect(Long userId, String roomSid);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,10 @@ public Long countSubHostByRoomId(Long roomId) {
}

@Override
public void disconnectAll(String name) {
public void disconnectAllByUserId(Long userId) {
queryFactory.update(participant)
.set(participant.isConnected, false)
.where(participant.user.id.eq(
JPAExpressions.select(user.id).from(user).where(user.name.eq(name))))
.where(participant.user.id.eq(userId))
.execute();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ public JwtAuthenticationFilter(JwtTokenManager jwtTokenManager) {
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
try {
authentication(request);
} catch (RuntimeException exception) {
Expand All @@ -33,8 +32,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
}

private void authentication(HttpServletRequest request) {
String bearerToken = jwtTokenManager.extractTokenFromAuthorizationHeader(
request.getHeader(HttpHeaders.AUTHORIZATION));
String bearerToken = jwtTokenManager.extractTokenFromAuthorizationHeader(request.getHeader(HttpHeaders.AUTHORIZATION));
TokenPayload tokenPayload = jwtTokenManager.getTokenPayloadFromToken(bearerToken);

registerUserAuthenticationOnSecurityContext(tokenPayload);
Expand All @@ -45,7 +43,6 @@ private void registerUserAuthenticationOnSecurityContext(TokenPayload tokenPaylo

SecurityContextHolder
.getContext()
.setAuthentication(new UsernamePasswordAuthenticationToken(userPrincipal, null,
userPrincipal.getAuthorities()));
.setAuthentication(new UsernamePasswordAuthenticationToken(userPrincipal, null, userPrincipal.getAuthorities()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,8 @@ public boolean isCredentialsNonExpired() {
public boolean isEnabled() {
return true;
}

public Long getUserId() {
return this.tokenPayload.getUserId();
}
}

0 comments on commit e6c8609

Please sign in to comment.