diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java index d659e14928..16e31310eb 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncRunnerTest.java @@ -20,6 +20,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeFalse; import com.google.api.core.ApiFuture; import com.google.api.core.ApiFutures; @@ -60,17 +61,33 @@ public void clearRequests() { @Test public void testAsyncRunner_doesNotReturnCommitTimestampBeforeCommit() { AsyncRunner runner = client().runAsync(); - IllegalStateException e = - assertThrows(IllegalStateException.class, () -> runner.getCommitTimestamp()); - assertTrue(e.getMessage().contains("runAsync() has not yet been called")); + if (isMultiplexedSessionsEnabledForRW()) { + ExecutionException e = + assertThrows(ExecutionException.class, () -> runner.getCommitTimestamp().get()); + Throwable cause = e.getCause(); + assertTrue(cause instanceof IllegalStateException); + assertTrue(cause.getMessage().contains("runAsync() has not yet been called")); + } else { + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> runner.getCommitTimestamp()); + assertTrue(e.getMessage().contains("runAsync() has not yet been called")); + } } @Test public void testAsyncRunner_doesNotReturnCommitResponseBeforeCommit() { AsyncRunner runner = client().runAsync(); - IllegalStateException e = - assertThrows(IllegalStateException.class, () -> runner.getCommitResponse()); - assertTrue(e.getMessage().contains("runAsync() has not yet been called")); + if (isMultiplexedSessionsEnabledForRW()) { + ExecutionException e = + assertThrows(ExecutionException.class, () -> runner.getCommitResponse().get()); + Throwable cause = e.getCause(); + assertTrue(cause instanceof IllegalStateException); + assertTrue(cause.getMessage().contains("runAsync() has not yet been called")); + } else { + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> runner.getCommitResponse()); + assertTrue(e.getMessage().contains("runAsync() has not yet been called")); + } } @Test @@ -201,7 +218,17 @@ public void asyncRunnerUpdateAbortedWithoutGettingResult() throws Exception { executor); assertThat(result.get()).isNull(); assertThat(attempt.get()).isEqualTo(2); - if (isMultiplexedSessionsEnabled()) { + if (isMultiplexedSessionsEnabledForRW()) { + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + CreateSessionRequest.class, + ExecuteSqlRequest.class, + // The retry will use an explicit BeginTransaction RPC because the first statement of + // the transaction did not return a transaction id during the initial attempt. + BeginTransactionRequest.class, + ExecuteSqlRequest.class, + CommitRequest.class); + } else if (isMultiplexedSessionsEnabled()) { assertThat(mockSpanner.getRequestTypes()) .containsExactly( CreateSessionRequest.class, @@ -260,7 +287,11 @@ public void asyncRunnerWaitsUntilAsyncUpdateHasFinished() throws Exception { }, executor); res.get(); - if (isMultiplexedSessionsEnabled()) { + if (isMultiplexedSessionsEnabledForRW()) { + assertThat(mockSpanner.getRequestTypes()) + .containsAtLeast( + CreateSessionRequest.class, ExecuteSqlRequest.class, CommitRequest.class); + } else if (isMultiplexedSessionsEnabled()) { // The mock server could have received a CreateSession request for a multiplexed session, but // it could also be that that request has not yet reached the server. assertThat(mockSpanner.getRequestTypes()) @@ -404,7 +435,17 @@ public void asyncRunnerBatchUpdateAbortedWithoutGettingResult() throws Exception executor); assertThat(result.get()).isNull(); assertThat(attempt.get()).isEqualTo(2); - if (isMultiplexedSessionsEnabled()) { + if (isMultiplexedSessionsEnabledForRW()) { + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + CreateSessionRequest.class, + ExecuteSqlRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class, + ExecuteSqlRequest.class, + ExecuteBatchDmlRequest.class, + CommitRequest.class); + } else if (isMultiplexedSessionsEnabled()) { assertThat(mockSpanner.getRequestTypes()) .containsExactly( CreateSessionRequest.class, @@ -463,7 +504,11 @@ public void asyncRunnerWaitsUntilAsyncBatchUpdateHasFinished() throws Exception }, executor); res.get(); - if (isMultiplexedSessionsEnabled()) { + if (isMultiplexedSessionsEnabledForRW()) { + assertThat(mockSpanner.getRequestTypes()) + .containsExactly( + CreateSessionRequest.class, ExecuteBatchDmlRequest.class, CommitRequest.class); + } else if (isMultiplexedSessionsEnabled()) { assertThat(mockSpanner.getRequestTypes()) .containsExactly( CreateSessionRequest.class, @@ -479,6 +524,8 @@ public void asyncRunnerWaitsUntilAsyncBatchUpdateHasFinished() throws Exception @Test public void closeTransactionBeforeEndOfAsyncQuery() throws Exception { + // TODO(sriharshach): Fix this unittest + assumeFalse("Skipping for mux", isMultiplexedSessionsEnabledForRW()); final BlockingQueue results = new SynchronousQueue<>(); final SettableApiFuture finished = SettableApiFuture.create(); DatabaseClientImpl clientImpl = (DatabaseClientImpl) client(); @@ -576,4 +623,11 @@ private boolean isMultiplexedSessionsEnabled() { } return spanner.getOptions().getSessionPoolOptions().getUseMultiplexedSession(); } + + private boolean isMultiplexedSessionsEnabledForRW() { + if (spanner.getOptions() == null || spanner.getOptions().getSessionPoolOptions() == null) { + return false; + } + return spanner.getOptions().getSessionPoolOptions().getUseMultiplexedSessionForRW(); + } }