diff --git a/s3stream/src/main/java/com/automq/stream/s3/wal/impl/block/BlockWALService.java b/s3stream/src/main/java/com/automq/stream/s3/wal/impl/block/BlockWALService.java index 6111bb44ae..06ca1f8458 100644 --- a/s3stream/src/main/java/com/automq/stream/s3/wal/impl/block/BlockWALService.java +++ b/s3stream/src/main/java/com/automq/stream/s3/wal/impl/block/BlockWALService.java @@ -46,6 +46,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Lock; import java.util.function.Function; import org.apache.commons.lang3.StringUtils; @@ -116,7 +117,8 @@ public class BlockWALService implements WriteAheadLog { public static final int WAL_HEADER_CAPACITY = WALUtil.BLOCK_SIZE; public static final int WAL_HEADER_TOTAL_CAPACITY = WAL_HEADER_CAPACITY * WAL_HEADER_COUNT; private static final Logger LOGGER = LoggerFactory.getLogger(BlockWALService.class); - private final AtomicBoolean started = new AtomicBoolean(false); + @SuppressWarnings("checkstyle:MemberName") + private final AtomicReference state = new AtomicReference<>(WalState.INIT); private final AtomicBoolean resetFinished = new AtomicBoolean(false); private final AtomicLong writeHeaderRoundTimes = new AtomicLong(0); private final ExecutorService walHeaderFlusher = Threads.newFixedThreadPool(1, ThreadUtils.createThreadFactory("flush-wal-header-thread-%d", true), LOGGER); @@ -276,10 +278,36 @@ private void parseRecordBody(long recoverStartOffset, RecordHeader readRecordHea @Override public WriteAheadLog start() throws IOException { - if (started.get()) { - LOGGER.warn("block WAL service already started"); - return this; + switch (state.get()) { + case INIT: + if (state.compareAndSet(WalState.INIT, WalState.STARTING)) { + try { + doStart(); + state.set(WalState.STARTED); + } finally { + if (state.get() != WalState.STARTED) { + state.compareAndSet(WalState.STARTING, WalState.INIT); + LOGGER.warn("block WAL service started fail"); + } + } + } + break; + case STARTING: + LOGGER.warn("block WAL service is starting"); + break; + case STARTED: + LOGGER.warn("block WAL service already started"); + break; + case SHUTTING_DOWN: + case SHUTDOWN: + throw new IllegalStateException("block WAL service already shutdown"); + default: + throw new IllegalStateException("invalid WAL state"); } + return this; + } + + public void doStart() throws IOException { StopWatch stopWatch = StopWatch.createStarted(); walChannel.open(channel -> Optional.ofNullable(tryReadWALHeader(walChannel)) @@ -298,10 +326,7 @@ public WriteAheadLog start() throws IOException { header.setShutdownType(ShutdownType.UNGRACEFULLY); walHeaderReady(header); - - started.set(true); LOGGER.info("block WAL service started, cost: {} ms", stopWatch.getTime(TimeUnit.MILLISECONDS)); - return this; } private void registerMetrics() { @@ -373,12 +398,38 @@ private void walHeaderReady(BlockWALHeader header) { @Override public void shutdownGracefully() { - StopWatch stopWatch = StopWatch.createStarted(); + for (; ; ) { + WalState state = this.state.get(); + if (state == WalState.SHUTDOWN || this.state.compareAndSet(WalState.INIT, WalState.SHUTDOWN)) { + LOGGER.warn("block WAL service already shutdown or not started yet"); + return; + } + if (state == WalState.STARTING) { + Thread.yield(); + continue; + } + if (state == WalState.SHUTTING_DOWN + || this.state.compareAndSet(state, WalState.SHUTTING_DOWN)) { + break; + } + } - if (!started.getAndSet(false)) { - LOGGER.warn("block WAL service already shutdown or not started yet"); - return; + if (state.compareAndSet(WalState.SHUTTING_DOWN, WalState.SHUTDOWN)) { + boolean success = false; + try { + doShutdown(); + success = true; + } finally { + if (!success) { + LOGGER.warn("block WAL service shutdown fail"); + state.compareAndSet(WalState.SHUTDOWN, WalState.SHUTTING_DOWN); + } + } } + } + + private void doShutdown() { + StopWatch stopWatch = StopWatch.createStarted(); walHeaderFlusher.shutdown(); try { if (!walHeaderFlusher.awaitTermination(5, TimeUnit.SECONDS)) { @@ -509,7 +560,7 @@ private CompletableFuture trim(long offset, boolean internal) { } private void checkStarted() { - if (!started.get()) { + if (state.get() != WalState.STARTED) { throw new IllegalStateException("WriteAheadLog has not been started yet"); } } @@ -530,6 +581,14 @@ private SlidingWindowService.WALHeaderFlusher flusher() { return () -> flushWALHeader(ShutdownType.UNGRACEFULLY); } + private enum WalState { + INIT, + STARTING, + STARTED, + SHUTTING_DOWN, + SHUTDOWN, + } + public static class BlockWALServiceBuilder { private final String blockDevicePath; private long blockDeviceCapacityWant = CAPACITY_NOT_SET; diff --git a/s3stream/src/test/java/com/automq/stream/s3/failover/FailoverTest.java b/s3stream/src/test/java/com/automq/stream/s3/failover/FailoverTest.java index 14b4db1e3b..f73d7c652a 100644 --- a/s3stream/src/test/java/com/automq/stream/s3/failover/FailoverTest.java +++ b/s3stream/src/test/java/com/automq/stream/s3/failover/FailoverTest.java @@ -61,7 +61,8 @@ public void test() throws IOException, ExecutionException, InterruptedException, request.setDevice(path); request.setVolumeId("test_volume_id"); - when(failoverFactory.getWal(any())).thenReturn(BlockWALService.builder(path, 1024 * 1024).nodeId(233).epoch(100).build()); + when(failoverFactory.getWal(any())).thenAnswer(s -> + BlockWALService.builder(path, 1024 * 1024).nodeId(233).epoch(100).build()); boolean exceptionThrown = false; try {