From aae94ffae32b61f5a3a669d27ad94ba743631d3e Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Mon, 27 Jan 2025 19:13:52 -0600 Subject: [PATCH 01/12] Add a timer to waitForTransferArchive --- .../controllers/DeviceController.java | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index ef2639b7b..6936d2d44 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -111,6 +111,10 @@ public class DeviceController { private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAME = MetricsUtil.name(DeviceController.class, "waitForLinkedDeviceDuration"); + private static final String WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME = + MetricsUtil.name(DeviceController.class, "waitForTransferArchiveDuration"); + + @VisibleForTesting static final int MIN_TOKEN_IDENTIFIER_LENGTH = 32; @@ -565,7 +569,11 @@ public CompletionStage waitForTransferArchive(@ReadOnly @Auth final Au description = """ The amount of time (in seconds) to wait for a response. If a transfer archive for the authenticated device is not available within the given amount of time, this endpoint will return a status of HTTP/204. - """) final int timeoutSeconds) { + """) final int timeoutSeconds, + + @HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent) { + + final Timer.Sample sample = Timer.start(); final String rateLimiterKey = authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI) + ":" + authenticatedDevice.getAuthenticatedDevice().getId(); @@ -575,7 +583,20 @@ The amount of time (in seconds) to wait for a response. If a transfer archive fo authenticatedDevice.getAuthenticatedDevice(), Duration.ofSeconds(timeoutSeconds))) .thenApply(maybeTransferArchive -> maybeTransferArchive - .map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build()) - .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())); + .map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) + .whenComplete((response, throwable) -> { + if (response == null) { + return; + } + sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME) + .publishPercentileHistogram(true) + .tags(Tags.of( + UserAgentTagUtil.getPlatformTag(userAgent), + io.micrometer.core.instrument.Tag.of( + "archiveUploaded", + String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) + .register(Metrics.globalRegistry)); + }); } } From 1446d1acf813ffe450979154aa53d7f586a92f31 Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Mon, 27 Jan 2025 19:24:47 -0600 Subject: [PATCH 02/12] Fix blocking call in waitForLinkedDevice --- .../controllers/DeviceController.java | 59 +++++++++---------- .../controllers/DeviceControllerTest.java | 11 +++- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 6936d2d44..291d7bc92 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -343,7 +343,7 @@ public LinkDeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) Bas @ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed; clients may repeat the call to continue waiting") @ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid") @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") - public CompletableFuture waitForLinkedDevice( + public CompletionStage waitForLinkedDevice( @ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, @PathParam("tokenIdentifier") @@ -363,40 +363,35 @@ The amount of time (in seconds) to wait for a response. If the expected device i given amount of time, this endpoint will return a status of HTTP/204. """) final int timeoutSeconds, - @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException { - - rateLimiters.getWaitForLinkedDeviceLimiter().validate(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)); - + @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent); linkedDeviceListenerCounter.incrementAndGet(); - final Timer.Sample sample = Timer.start(); - try { - return accounts.waitForNewLinkedDevice(authenticatedDevice.getAccount().getUuid(), - authenticatedDevice.getAuthenticatedDevice(), tokenIdentifier, Duration.ofSeconds(timeoutSeconds)) - .thenApply(maybeDeviceInfo -> maybeDeviceInfo - .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) - .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) - .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class, - e -> Response.status(Response.Status.BAD_REQUEST).build())) - .whenComplete((response, throwable) -> { - linkedDeviceListenerCounter.decrementAndGet(); - - if (response != null) { - sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) - .publishPercentileHistogram(true) - .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), - io.micrometer.core.instrument.Tag.of("deviceFound", - String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) - .register(Metrics.globalRegistry)); - } - }); - } catch (final RedisException e) { - // `waitForNewLinkedDevice` could fail synchronously if the Redis circuit breaker is open; prevent counter drift - // if that happens - linkedDeviceListenerCounter.decrementAndGet(); - throw e; - } + + return rateLimiters.getWaitForLinkedDeviceLimiter() + .validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) + .thenCompose(ignored -> accounts.waitForNewLinkedDevice( + authenticatedDevice.getAccount().getUuid(), + authenticatedDevice.getAuthenticatedDevice(), + tokenIdentifier, + Duration.ofSeconds(timeoutSeconds))) + .thenApply(maybeDeviceInfo -> maybeDeviceInfo + .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) + .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class, + e -> Response.status(Response.Status.BAD_REQUEST).build())) + .whenComplete((response, throwable) -> { + linkedDeviceListenerCounter.decrementAndGet(); + + if (response != null) { + sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) + .publishPercentileHistogram(true) + .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), + io.micrometer.core.instrument.Tag.of("deviceFound", + String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) + .register(Metrics.globalRegistry)); + } + }); } private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 2a302c4a4..c735cec6e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -955,6 +955,8 @@ void waitForLinkedDevice() { .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo))); + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); + try (final Response response = resources.getJerseyTest() .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .request() @@ -979,6 +981,8 @@ void waitForLinkedDeviceNoDeviceLinked() { .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); + try (final Response response = resources.getJerseyTest() .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .request() @@ -997,6 +1001,8 @@ void waitForLinkedDeviceBadTokenIdentifier() { .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException())); + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); + try (final Response response = resources.getJerseyTest() .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .request() @@ -1042,10 +1048,11 @@ private static List waitForLinkedDeviceBadTokenIdentifierLength() { } @Test - void waitForLinkedDeviceRateLimited() throws RateLimitExceededException { + void waitForLinkedDeviceRateLimited() { final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); - doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID); + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)) + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null))); try (final Response response = resources.getJerseyTest() .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) From 282bcf6f343fa9c1d00dd115c2286c0c8a2f18e8 Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Tue, 28 Jan 2025 13:51:07 -0600 Subject: [PATCH 03/12] Add persistent timer utility backed by redis --- .../textsecuregcm/WhisperServerService.java | 4 +- .../controllers/DeviceController.java | 89 ++++++++------- .../storage/PersistentTimer.java | 104 ++++++++++++++++++ .../controllers/DeviceControllerTest.java | 7 +- .../storage/PersistentTimerTest.java | 100 +++++++++++++++++ 5 files changed, 257 insertions(+), 47 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/PersistentTimer.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/PersistentTimerTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 25f280d2c..61def05fb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -227,6 +227,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.OneTimeDonationsManager; +import org.whispersystems.textsecuregcm.storage.PersistentTimer; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.Profiles; import org.whispersystems.textsecuregcm.storage.ProfilesManager; @@ -1097,6 +1098,7 @@ protected void configureServer(final ServerBuilder serverBuilder) { log.info("Registered spam filter: {}", filter.getClass().getName()); }); + final PersistentTimer persistentTimer = new PersistentTimer(rateLimitersCluster, clock); final PhoneVerificationTokenManager phoneVerificationTokenManager = new PhoneVerificationTokenManager( phoneNumberIdentifiers, registrationServiceClient, registrationRecoveryPasswordsManager, registrationRecoveryChecker); @@ -1115,7 +1117,7 @@ protected void configureServer(final ServerBuilder serverBuilder) { config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), zkAuthOperations, callingGenericZkSecretParams, clock), new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker), - new DeviceController(accountsManager, clientPublicKeysManager, rateLimiters, config.getMaxDevices()), + new DeviceController(accountsManager, clientPublicKeysManager, rateLimiters, persistentTimer, config.getMaxDevices()), new DeviceCheckController(clock, backupAuthManager, appleDeviceCheckManager, rateLimiters, config.getDeviceCheck().backupRedemptionLevel(), config.getDeviceCheck().backupRedemptionDuration()), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 291d7bc92..1a90916e8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -7,7 +7,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HttpHeaders; import io.dropwizard.auth.Auth; -import io.lettuce.core.RedisException; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Timer; @@ -81,6 +80,7 @@ import org.whispersystems.textsecuregcm.storage.DeviceCapability; import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException; +import org.whispersystems.textsecuregcm.storage.PersistentTimer; import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter; import org.whispersystems.textsecuregcm.util.EnumMapUtil; import org.whispersystems.textsecuregcm.util.ExceptionUtils; @@ -100,6 +100,7 @@ public class DeviceController { private final AccountsManager accounts; private final ClientPublicKeysManager clientPublicKeysManager; private final RateLimiters rateLimiters; + private final PersistentTimer persistentTimer; private final Map maxDeviceConfiguration; private final EnumMap linkedDeviceListenersByPlatform; @@ -108,9 +109,11 @@ public class DeviceController { private static final String LINKED_DEVICE_LISTENER_GAUGE_NAME = MetricsUtil.name(DeviceController.class, "linkedDeviceListeners"); + private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE = "wait_for_linked_device"; private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAME = MetricsUtil.name(DeviceController.class, "waitForLinkedDeviceDuration"); + private static final String WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE = "wait_for_transfer_archive"; private static final String WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME = MetricsUtil.name(DeviceController.class, "waitForTransferArchiveDuration"); @@ -124,11 +127,13 @@ public class DeviceController { public DeviceController(final AccountsManager accounts, final ClientPublicKeysManager clientPublicKeysManager, final RateLimiters rateLimiters, + final PersistentTimer persistentTimer, final Map maxDeviceConfiguration) { this.accounts = accounts; this.clientPublicKeysManager = clientPublicKeysManager; this.rateLimiters = rateLimiters; + this.persistentTimer = persistentTimer; this.maxDeviceConfiguration = maxDeviceConfiguration; linkedDeviceListenersByPlatform = @@ -366,32 +371,30 @@ The amount of time (in seconds) to wait for a response. If the expected device i @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent); linkedDeviceListenerCounter.incrementAndGet(); - final Timer.Sample sample = Timer.start(); return rateLimiters.getWaitForLinkedDeviceLimiter() .validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) - .thenCompose(ignored -> accounts.waitForNewLinkedDevice( - authenticatedDevice.getAccount().getUuid(), - authenticatedDevice.getAuthenticatedDevice(), - tokenIdentifier, - Duration.ofSeconds(timeoutSeconds))) - .thenApply(maybeDeviceInfo -> maybeDeviceInfo - .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) - .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) - .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class, - e -> Response.status(Response.Status.BAD_REQUEST).build())) - .whenComplete((response, throwable) -> { - linkedDeviceListenerCounter.decrementAndGet(); - - if (response != null) { - sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) - .publishPercentileHistogram(true) - .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), - io.micrometer.core.instrument.Tag.of("deviceFound", - String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) - .register(Metrics.globalRegistry)); - } - }); + .thenCompose(ignored -> persistentTimer.start(WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE, tokenIdentifier)) + .thenCompose(sample -> accounts.waitForNewLinkedDevice( + authenticatedDevice.getAccount().getUuid(), + authenticatedDevice.getAuthenticatedDevice(), + tokenIdentifier, + Duration.ofSeconds(timeoutSeconds)) + .thenApply(maybeDeviceInfo -> maybeDeviceInfo + .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) + .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class, + e -> Response.status(Response.Status.BAD_REQUEST).build())) + .whenComplete((response, throwable) -> { + linkedDeviceListenerCounter.decrementAndGet(); + + if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) { + sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) + .publishPercentileHistogram(true) + .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))) + .register(Metrics.globalRegistry)); + } + })); } private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) { @@ -529,7 +532,8 @@ The amount of time (in seconds) to wait for a response. If a transfer archive fo public CompletionStage recordTransferArchiveUploaded(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, @NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) { - return rateLimiters.getUploadTransferArchiveLimiter().validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) + return rateLimiters.getUploadTransferArchiveLimiter() + .validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) .thenCompose(ignored -> accounts.recordTransferArchiveUpload(authenticatedDevice.getAccount(), transferArchiveUploadedRequest.destinationDeviceId(), Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()), @@ -568,30 +572,25 @@ The amount of time (in seconds) to wait for a response. If a transfer archive fo @HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent) { - final Timer.Sample sample = Timer.start(); final String rateLimiterKey = authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI) + ":" + authenticatedDevice.getAuthenticatedDevice().getId(); return rateLimiters.getWaitForTransferArchiveLimiter().validateAsync(rateLimiterKey) - .thenCompose(ignored -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(), - authenticatedDevice.getAuthenticatedDevice(), - Duration.ofSeconds(timeoutSeconds))) - .thenApply(maybeTransferArchive -> maybeTransferArchive - .map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build()) - .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) - .whenComplete((response, throwable) -> { - if (response == null) { - return; - } - sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME) - .publishPercentileHistogram(true) - .tags(Tags.of( - UserAgentTagUtil.getPlatformTag(userAgent), - io.micrometer.core.instrument.Tag.of( - "archiveUploaded", - String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) - .register(Metrics.globalRegistry)); - }); + .thenCompose(ignored -> persistentTimer.start(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE, rateLimiterKey)) + .thenCompose(sample -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(), + authenticatedDevice.getAuthenticatedDevice(), + Duration.ofSeconds(timeoutSeconds)) + .thenApply(maybeTransferArchive -> maybeTransferArchive + .map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) + .whenComplete((response, throwable) -> { + if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) { + sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME) + .publishPercentileHistogram(true) + .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))) + .register(Metrics.globalRegistry)); + } + })); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PersistentTimer.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PersistentTimer.java new file mode 100644 index 000000000..105ffb61b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PersistentTimer.java @@ -0,0 +1,104 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import com.google.common.annotations.VisibleForTesting; +import io.lettuce.core.SetArgs; +import io.micrometer.core.instrument.Timer; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import javax.annotation.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; +import org.whispersystems.textsecuregcm.util.Util; + +/** + * Timers for operations that may span machines or requests and require a persistently stored timer start itme + */ +public class PersistentTimer { + + private static final Logger logger = LoggerFactory.getLogger(PersistentTimer.class); + + private static String TIMER_NAMESPACE = "persistent_timer"; + @VisibleForTesting + static final Duration TIMER_TTL = Duration.ofHours(1); + + private final FaultTolerantRedisClusterClient redisClient; + private final Clock clock; + + + public PersistentTimer(final FaultTolerantRedisClusterClient redisClient, final Clock clock) { + this.redisClient = redisClient; + this.clock = clock; + } + + public class Sample { + + private final Instant start; + private final String redisKey; + + public Sample(final Instant start, final String redisKey) { + this.start = start; + this.redisKey = redisKey; + } + + /** + * Stop the timer, recording the duration between now and the first call to start. This deletes the persistent timer. + * + * @param timer The micrometer timer to record the duration to + * @return A future that completes when the resources associated with the persistent timer have been destroyed + */ + public CompletableFuture stop(Timer timer) { + Duration duration = Duration.between(start, clock.instant()); + timer.record(duration); + return redisClient.withCluster(connection -> connection.async().del(redisKey)) + .toCompletableFuture() + .thenRun(Util.NOOP); + } + } + + /** + * Start the timer if a timer with the provided namespaced key has not already been started, otherwise return the + * existing sample. + * + * @param namespace A namespace prefix to use for the timer + * @param key The unique key within the namespace that identifies the timer + * @return A future that completes with a {@link Sample} that can later be used to record the final duration. + */ + public CompletableFuture start(final String namespace, final String key) { + final Instant now = clock.instant(); + final String redisKey = redisKey(namespace, key); + + return redisClient.withCluster(connection -> + connection.async().setGet(redisKey, String.valueOf(now.getEpochSecond()), SetArgs.Builder.nx().ex(TIMER_TTL))) + .toCompletableFuture() + .thenApply(serialized -> new Sample(parseStoredTimestamp(serialized).orElse(now), redisKey)); + } + + @VisibleForTesting + String redisKey(final String namespace, final String key) { + return String.format("%s::%s::%s", TIMER_NAMESPACE, namespace, key); + } + + private static Optional parseStoredTimestamp(final @Nullable String serialized) { + return Optional + .ofNullable(serialized) + .flatMap(s -> { + try { + return Optional.of(Long.parseLong(s)); + } catch (NumberFormatException e) { + logger.warn("Failed to parse stored timestamp {}", s, e); + return Optional.empty(); + } + }) + .map(Instant::ofEpochSecond); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index c735cec6e..0cac6fce1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -12,7 +12,6 @@ import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.clearInvocations; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -90,6 +89,7 @@ import org.whispersystems.textsecuregcm.storage.DeviceCapability; import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException; +import org.whispersystems.textsecuregcm.storage.PersistentTimer; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; @@ -104,6 +104,7 @@ class DeviceControllerTest { private static final AccountsManager accountsManager = mock(AccountsManager.class); private static final ClientPublicKeysManager clientPublicKeysManager = mock(ClientPublicKeysManager.class); + private static final PersistentTimer persistentTimer = mock(PersistentTimer.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiter rateLimiter = mock(RateLimiter.class); @SuppressWarnings("unchecked") @@ -123,6 +124,7 @@ class DeviceControllerTest { accountsManager, clientPublicKeysManager, rateLimiters, + persistentTimer, deviceConfiguration); @RegisterExtension @@ -161,6 +163,9 @@ void setup() { when(clientPublicKeysManager.setPublicKey(any(), anyByte(), any())) .thenReturn(CompletableFuture.completedFuture(null)); + when(persistentTimer.start(anyString(), anyString())) + .thenReturn(CompletableFuture.completedFuture(mock(PersistentTimer.Sample.class))); + AccountsHelper.setupMockUpdate(accountsManager); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/PersistentTimerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PersistentTimerTest.java new file mode 100644 index 000000000..02b5006ec --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PersistentTimerTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.storage; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import io.micrometer.core.instrument.Timer; +import java.time.Duration; +import java.time.Instant; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.util.TestClock; + +class PersistentTimerTest { + + private static final String NAMESPACE = "namespace"; + private static final String KEY = "key"; + + @RegisterExtension + private static final RedisClusterExtension CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + private TestClock clock; + private PersistentTimer timer; + + @BeforeEach + public void setup() { + clock = TestClock.pinned(Instant.ofEpochSecond(10)); + timer = new PersistentTimer(CLUSTER_EXTENSION.getRedisCluster(), clock); + } + + @Test + public void testStop() { + PersistentTimer.Sample sample = timer.start(NAMESPACE, KEY).join(); + final String redisKey = timer.redisKey(NAMESPACE, KEY); + + final String actualStartString = CLUSTER_EXTENSION.getRedisCluster() + .withCluster(conn -> conn.sync().get(redisKey)); + final Instant actualStart = Instant.ofEpochSecond(Long.parseLong(actualStartString)); + assertThat(actualStart).isEqualTo(clock.instant()); + + final long ttl = CLUSTER_EXTENSION.getRedisCluster() + .withCluster(conn -> conn.sync().ttl(redisKey)); + + assertThat(ttl).isBetween(0L, PersistentTimer.TIMER_TTL.getSeconds()); + + Timer mockTimer = mock(Timer.class); + clock.pin(clock.instant().plus(Duration.ofSeconds(5))); + sample.stop(mockTimer).join(); + verify(mockTimer).record(Duration.ofSeconds(5)); + + final String afterDeletion = CLUSTER_EXTENSION.getRedisCluster() + .withCluster(conn -> conn.sync().get(redisKey)); + + assertThat(afterDeletion).isNull(); + } + + @Test + public void testNamespace() { + Timer mockTimer = mock(Timer.class); + + clock.pin(Instant.ofEpochSecond(10)); + PersistentTimer.Sample timer1 = timer.start("n1", KEY).join(); + clock.pin(Instant.ofEpochSecond(20)); + PersistentTimer.Sample timer2 = timer.start("n2", KEY).join(); + clock.pin(Instant.ofEpochSecond(30)); + + timer2.stop(mockTimer).join(); + verify(mockTimer).record(Duration.ofSeconds(10)); + + timer1.stop(mockTimer).join(); + verify(mockTimer).record(Duration.ofSeconds(20)); + } + + @Test + public void testMultipleStart() { + Timer mockTimer = mock(Timer.class); + + clock.pin(Instant.ofEpochSecond(10)); + PersistentTimer.Sample timer1 = timer.start(NAMESPACE, KEY).join(); + clock.pin(Instant.ofEpochSecond(11)); + PersistentTimer.Sample timer2 = timer.start(NAMESPACE, KEY).join(); + clock.pin(Instant.ofEpochSecond(12)); + PersistentTimer.Sample timer3 = timer.start(NAMESPACE, KEY).join(); + + clock.pin(Instant.ofEpochSecond(20)); + timer2.stop(mockTimer).join(); + verify(mockTimer).record(Duration.ofSeconds(10)); + + assertThatNoException().isThrownBy(() -> timer1.stop(mockTimer).join()); + assertThatNoException().isThrownBy(() -> timer3.stop(mockTimer).join()); + } + + +} From 48ada8e8ca59e74b02eb1b61b8c8ee475b7d510c Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Fri, 31 Jan 2025 10:24:50 -0500 Subject: [PATCH 04/12] Clarify roles/responsibilities of components in the message-handling pathway --- .../textsecuregcm/WhisperServerService.java | 6 +- .../auth/UnidentifiedAccessUtil.java | 41 + .../controllers/MessageController.java | 365 ++--- .../MultiRecipientMessageProvider.java | 6 +- .../textsecuregcm/push/MessageSender.java | 97 +- .../textsecuregcm/push/ReceiptSender.java | 17 +- .../storage/ChangeNumberManager.java | 59 +- .../textsecuregcm/storage/MessagesCache.java | 22 +- .../storage/MessagesCacheInsertScript.java | 6 +- ...edMultiRecipientPayloadAndViewsScript.java | 7 +- .../storage/MessagesManager.java | 120 +- .../storage/ReportMessageDynamoDb.java | 18 +- .../storage/ReportMessageManager.java | 5 +- .../util/DestinationDeviceValidator.java | 10 +- .../workers/CommandDependencies.java | 4 +- .../controllers/MessageControllerTest.java | 1239 +++++++---------- .../textsecuregcm/push/MessageSenderTest.java | 81 +- .../storage/ChangeNumberManagerTest.java | 60 +- .../MessagePersisterIntegrationTest.java | 4 +- .../storage/MessagePersisterTest.java | 2 +- .../MessagesCacheGetItemsScriptTest.java | 2 +- .../MessagesCacheInsertScriptTest.java | 14 +- ...ltiRecipientPayloadAndViewsScriptTest.java | 22 +- .../MessagesCacheRemoveByGuidScriptTest.java | 2 +- .../MessagesCacheRemoveQueueScriptTest.java | 2 +- ...oveRecipientViewFromMrmDataScriptTest.java | 7 +- .../storage/MessagesCacheTest.java | 44 +- .../storage/MessagesManagerTest.java | 135 +- .../storage/ReportMessageDynamoDbTest.java | 5 +- .../storage/ReportMessageManagerTest.java | 5 +- .../util/MultiRecipientMessageHelper.java | 92 ++ .../tests/util/TestRecipient.java | 22 + .../WebSocketConnectionIntegrationTest.java | 12 +- 33 files changed, 1336 insertions(+), 1197 deletions(-) create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MultiRecipientMessageHelper.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestRecipient.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 61def05fb..f96108adf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -431,7 +431,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro config.getDynamoDbTables().getRemoteConfig().getTableName()); PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(dynamoDbClient, config.getDynamoDbTables().getPushChallenge().getTableName()); - ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, + ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, dynamoDbAsyncClient, config.getDynamoDbTables().getReportMessage().getTableName(), config.getReportMessageConfiguration().getReportTtl()); RegistrationRecoveryPasswords registrationRecoveryPasswords = new RegistrationRecoveryPasswords( @@ -618,7 +618,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster, config.getReportMessageConfiguration().getCounterTtl()); MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, - messageDeletionAsyncExecutor); + messageDeletionAsyncExecutor, Clock.systemUTC()); AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient, config.getDynamoDbTables().getDeletedAccountsLock().getTableName()); ClientPublicKeysManager clientPublicKeysManager = @@ -1128,7 +1128,7 @@ protected void configureServer(final ServerBuilder serverBuilder) { new KeyTransparencyController(keyTransparencyServiceClient), new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender, accountsManager, messagesManager, phoneNumberIdentifiers, pushNotificationManager, pushNotificationScheduler, - reportMessageManager, multiRecipientMessageExecutor, messageDeliveryScheduler, clientReleaseManager, + reportMessageManager, messageDeliveryScheduler, clientReleaseManager, dynamicConfigurationManager, zkSecretParams, spamChecker, messageMetrics, messageDeliveryLoopMonitor, Clock.systemUTC()), new PaymentsController(currencyManager, paymentsCredentialsGenerator), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java index dccfab6e9..35c6f8c75 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java @@ -7,6 +7,9 @@ import org.whispersystems.textsecuregcm.storage.Account; import java.security.MessageDigest; +import java.util.Collection; +import java.util.function.Predicate; +import java.util.stream.IntStream; public class UnidentifiedAccessUtil { @@ -31,4 +34,42 @@ public static boolean checkUnidentifiedAccess(final Account targetAccount, final .map(targetUnidentifiedAccessKey -> MessageDigest.isEqual(targetUnidentifiedAccessKey, unidentifiedAccessKey)) .orElse(false); } + + /** + * Checks whether an action (e.g. sending a message or retrieving pre-keys) may be taken on the collection of target + * accounts by an actor presenting the given combined unidentified access key. + * + * @param targetAccounts the accounts on which an actor wishes to take an action + * @param combinedUnidentifiedAccessKey the unidentified access key presented by the actor + * + * @return {@code true} if an actor presenting the given unidentified access key has permission to take an action on + * the target accounts or {@code false} otherwise + */ + public static boolean checkUnidentifiedAccess(final Collection targetAccounts, final byte[] combinedUnidentifiedAccessKey) { + return MessageDigest.isEqual(getCombinedUnidentifiedAccessKey(targetAccounts), combinedUnidentifiedAccessKey); + } + + /** + * Calculates a combined unidentified access key for the given collection of accounts. + * + * @param accounts the accounts from which to derive a combined unidentified access key + * @return a combined unidentified access key + * + * @throws IllegalArgumentException if one or more of the given accounts had an unidentified access key with an + * unexpected length + */ + public static byte[] getCombinedUnidentifiedAccessKey(final Collection accounts) { + return accounts.stream() + .filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess)) + .map(account -> + account.getUnidentifiedAccessKey() + .filter(b -> b.length == UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH) + .orElseThrow(IllegalArgumentException::new)) + .reduce(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH], + (a, b) -> { + final byte[] xor = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]; + IntStream.range(0, UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH).forEach(i -> xor[i] = (byte) (a[i] ^ b[i])); + return xor; + }); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 3cd71c243..02a57caa1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -9,10 +9,8 @@ import com.codahale.metrics.annotation.Timed; import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HttpHeaders; -import com.google.protobuf.ByteString; import io.dropwizard.auth.Auth; import io.dropwizard.util.DataSize; -import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; @@ -47,33 +45,26 @@ import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response.Status; -import java.security.MessageDigest; import java.time.Clock; import java.time.Duration; import java.util.ArrayList; -import java.util.Base64; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.function.Predicate; import java.util.stream.Collectors; -import java.util.stream.IntStream; import java.util.stream.Stream; -import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.server.ManagedAsync; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; -import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient; import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.zkgroup.ServerSecretParams; @@ -135,6 +126,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; +import reactor.util.function.Tuple2; import reactor.util.function.Tuples; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @@ -142,14 +134,6 @@ @io.swagger.v3.oas.annotations.tags.Tag(name = "Messages") public class MessageController { - - private record MultiRecipientDeliveryData( - ServiceIdentifier serviceIdentifier, - Account account, - Recipient recipient, - Map deviceIdToRegistrationId) { - } - private static final Logger logger = LoggerFactory.getLogger(MessageController.class); private final RateLimiters rateLimiters; @@ -162,7 +146,6 @@ private record MultiRecipientDeliveryData( private final PushNotificationManager pushNotificationManager; private final PushNotificationScheduler pushNotificationScheduler; private final ReportMessageManager reportMessageManager; - private final ExecutorService multiRecipientMessageExecutor; private final Scheduler messageDeliveryScheduler; private final ClientReleaseManager clientReleaseManager; private final DynamicConfigurationManager dynamicConfigurationManager; @@ -229,7 +212,6 @@ public MessageController( PushNotificationManager pushNotificationManager, PushNotificationScheduler pushNotificationScheduler, ReportMessageManager reportMessageManager, - @Nonnull ExecutorService multiRecipientMessageExecutor, Scheduler messageDeliveryScheduler, final ClientReleaseManager clientReleaseManager, final DynamicConfigurationManager dynamicConfigurationManager, @@ -248,7 +230,6 @@ public MessageController( this.pushNotificationManager = pushNotificationManager; this.pushNotificationScheduler = pushNotificationScheduler; this.reportMessageManager = reportMessageManager; - this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor); this.messageDeliveryScheduler = messageDeliveryScheduler; this.clientReleaseManager = clientReleaseManager; this.dynamicConfigurationManager = dynamicConfigurationManager; @@ -332,15 +313,15 @@ public Response sendMessage(@ReadOnly @Auth final Optional throw new WebApplicationException(Status.FORBIDDEN); } - final Optional destination; + final Optional maybeDestination; if (!isSyncMessage) { - destination = accountsManager.getByServiceIdentifier(destinationIdentifier); + maybeDestination = accountsManager.getByServiceIdentifier(destinationIdentifier); } else { - destination = source.map(AuthenticatedDevice::getAccount); + maybeDestination = source.map(AuthenticatedDevice::getAccount); } final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam( - context, source, destination, Optional.of(destinationIdentifier)); + context, source, maybeDestination, Optional.of(destinationIdentifier)); final Optional reportSpamToken; switch (spamCheck) { case final SpamChecker.Spam spam: return spam.response(); @@ -376,11 +357,11 @@ public Response sendMessage(@ReadOnly @Auth final Optional // Stories will be checked by the client; we bypass access checks here for stories. } else if (groupSendToken != null) { checkGroupSendToken(List.of(destinationIdentifier.toLibsignal()), groupSendToken); - if (destination.isEmpty()) { + if (maybeDestination.isEmpty()) { throw new NotFoundException(); } } else { - OptionalAccess.verify(source.map(AuthenticatedDevice::getAccount), accessKey, destination, + OptionalAccess.verify(source.map(AuthenticatedDevice::getAccount), accessKey, maybeDestination, destinationIdentifier); } @@ -389,20 +370,20 @@ public Response sendMessage(@ReadOnly @Auth final Optional // We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify // we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from // these requests. - if (isStory && destination.isEmpty()) { + if (isStory && maybeDestination.isEmpty()) { return Response.ok(new SendMessageResponse(needsSync)).build(); } // if destination is empty we would either throw an exception in OptionalAccess.verify when isStory is false // or else return a 200 response when isStory is true. - assert destination.isPresent(); + final Account destination = maybeDestination.orElseThrow(); if (source.isPresent() && !isSyncMessage) { - checkMessageRateLimit(source.get(), destination.get(), userAgent); + checkMessageRateLimit(source.get(), destination, userAgent); } if (isStory) { - rateLimiters.getStoriesLimiter().validate(destination.get().getUuid()); + rateLimiters.getStoriesLimiter().validate(destination.getUuid()); } final Set excludedDeviceIds; @@ -413,15 +394,32 @@ public Response sendMessage(@ReadOnly @Auth final Optional excludedDeviceIds = Collections.emptySet(); } - DestinationDeviceValidator.validateCompleteDeviceList(destination.get(), - messages.messages().stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()), + final Map messagesByDeviceId = messages.messages().stream() + .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> { + try { + return message.toEnvelope( + destinationIdentifier, + source.map(AuthenticatedDevice::getAccount).orElse(null), + source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null), + messages.timestamp() == 0 ? System.currentTimeMillis() : messages.timestamp(), + isStory, + messages.urgent(), + reportSpamToken.orElse(null)); + } catch (final IllegalArgumentException e) { + logger.warn("Received bad envelope type {} from {}", message.type(), userAgent); + throw new BadRequestException(e); + } + })); + + DestinationDeviceValidator.validateCompleteDeviceList(destination, + messagesByDeviceId.keySet(), excludedDeviceIds); - DestinationDeviceValidator.validateRegistrationIds(destination.get(), + DestinationDeviceValidator.validateRegistrationIds(destination, messages.messages(), IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId, - destination.get().getPhoneNumberIdentifier().equals(destinationIdentifier.uuid())); + destination.getPhoneNumberIdentifier().equals(destinationIdentifier.uuid())); final String authType; if (SENDER_TYPE_IDENTIFIED.equals(senderType)) { @@ -434,31 +432,15 @@ public Response sendMessage(@ReadOnly @Auth final Optional authType = AUTH_TYPE_ACCESS_KEY; } - final List tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), + messageSender.sendMessages(destination, messagesByDeviceId); + + Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE), Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.online())), Tag.of(SENDER_TYPE_TAG_NAME, senderType), Tag.of(AUTH_TYPE_TAG_NAME, authType), - Tag.of(IDENTITY_TYPE_TAG_NAME, destinationIdentifier.identityType().name())); - - for (final IncomingMessage incomingMessage : messages.messages()) { - destination.get().getDevice(incomingMessage.destinationDeviceId()) - .ifPresent(destinationDevice -> { - Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); - sendIndividualMessage( - source, - destination.get(), - destinationDevice, - destinationIdentifier, - messages.timestamp(), - messages.online(), - isStory, - messages.urgent(), - incomingMessage, - userAgent, - reportSpamToken); - }); - } + Tag.of(IDENTITY_TYPE_TAG_NAME, destinationIdentifier.identityType().name()))) + .increment(messagesByDeviceId.size()); return Response.ok(new SendMessageResponse(needsSync)).build(); } catch (final MismatchedDevicesException e) { @@ -481,34 +463,6 @@ public Response sendMessage(@ReadOnly @Auth final Optional } } - - /** - * Build mapping of service IDs to resolved accounts and device/registration IDs - */ - private Map buildRecipientMap( - SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) { - return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet()) - .switchIfEmpty(Flux.error(BadRequestException::new)) - .map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue())) - .flatMap( - t -> Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(t.getT1())) - .flatMap(Mono::justOrEmpty) - .switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new)) - .map( - account -> - new MultiRecipientDeliveryData( - t.getT1(), - account, - t.getT2(), - t.getT2().getDevicesAndRegistrationIds().collect( - Collectors.toMap(Pair::first, Pair::second)))) - // IllegalStateException is thrown by Collectors#toMap when we have multiple entries for the same device - .onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e), - MAX_FETCH_ACCOUNT_CONCURRENCY) - .collectMap(MultiRecipientDeliveryData::serviceIdentifier) - .block(); - } - @Timed @Path("/multi_recipient") @PUT @@ -565,6 +519,32 @@ public Response sendMultiRecipientMessage( throw new BadRequestException("Illegal timestamp"); } + if (multiRecipientMessage.getRecipients().isEmpty()) { + throw new BadRequestException("Recipient list is empty"); + } + + // Verify that the message isn't too large before performing more expensive validations + multiRecipientMessage.getRecipients().values().forEach(recipient -> + validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipient), true, userAgent)); + + // Check that the request is well-formed and doesn't contain repeated entries for the same device for the same + // recipient + { + final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID]; + + for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) { + Arrays.fill(usedDeviceIds, false); + + for (final byte deviceId : recipient.getDevices()) { + if (usedDeviceIds[deviceId]) { + throw new BadRequestException(); + } + + usedDeviceIds[deviceId] = true; + } + } + } + final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam(context, Optional.empty(), Optional.empty(), Optional.empty()); if (spamCheck instanceof final SpamChecker.Spam spam) { return spam.response(); @@ -584,28 +564,43 @@ public Response sendMultiRecipientMessage( if (groupSendToken != null) { // Group send endorsements are checked before we even attempt to resolve any accounts, since // the lists of service IDs in the envelope are all that we need to check against - checkGroupSendToken( - multiRecipientMessage.getRecipients().keySet(), groupSendToken); + checkGroupSendToken(multiRecipientMessage.getRecipients().keySet(), groupSendToken); } - final Map recipients = buildRecipientMap(multiRecipientMessage, isStory); + // At this point, the caller has at least superficially provided the information needed to send a multi-recipient + // message. Attempt to resolve the destination service identifiers to Signal accounts. + final Map resolvedRecipients = + Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet()) + .flatMap(serviceIdAndRecipient -> { + final ServiceIdentifier serviceIdentifier = + ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey()); + + return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .flatMap(Mono::justOrEmpty) + .switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new)) + .map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account)); + }, MAX_FETCH_ACCOUNT_CONCURRENCY) + .collectMap(Tuple2::getT1, Tuple2::getT2) + .blockOptional() + .orElse(Collections.emptyMap()); // Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above. // Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them if (groupSendToken == null && !isStory) { - checkAccessKeys(accessKeys, recipients.values()); + checkAccessKeys(accessKeys, multiRecipientMessage, resolvedRecipients); } + // We might filter out all the recipients of a story (if none exist). // In this case there is no error so we should just return 200 now. if (isStory) { - if (recipients.isEmpty()) { + if (resolvedRecipients.isEmpty()) { return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build(); } try { - CompletableFuture.allOf(recipients.values() + CompletableFuture.allOf(resolvedRecipients.values() .stream() - .map(recipient -> recipient.account().getUuid()) + .map(account -> account.getIdentifier(IdentityType.ACI)) .map(accountIdentifier -> rateLimiters.getStoriesLimiter().validateAsync(accountIdentifier).toCompletableFuture()) .toList() @@ -620,31 +615,42 @@ public Response sendMultiRecipientMessage( } } - Collection accountMismatchedDevices = new ArrayList<>(); - Collection accountStaleDevices = new ArrayList<>(); - recipients.values().forEach(recipient -> { - final Account account = recipient.account(); + final Collection accountMismatchedDevices = new ArrayList<>(); + final Collection accountStaleDevices = new ArrayList<>(); + + multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> { + if (!resolvedRecipients.containsKey(recipient)) { + // When sending stories, we might not be able to resolve all recipients to existing accounts. That's okay! We + // can just skip them. + return; + } + + final Account account = resolvedRecipients.get(recipient); try { - DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), + final Map deviceIdsToRegistrationIds = recipient.getDevicesAndRegistrationIds() + .collect(Collectors.toMap(Pair::first, Pair::second)); + + DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIdsToRegistrationIds.keySet(), Collections.emptySet()); DestinationDeviceValidator.validateRegistrationIds( account, - recipient.deviceIdToRegistrationId().entrySet(), + deviceIdsToRegistrationIds.entrySet(), Map.Entry::getKey, e -> Integer.valueOf(e.getValue()), - recipient.serviceIdentifier().identityType() == IdentityType.PNI); - } catch (MismatchedDevicesException e) { + serviceId instanceof ServiceId.Pni); + } catch (final MismatchedDevicesException e) { accountMismatchedDevices.add( new AccountMismatchedDevices( - recipient.serviceIdentifier(), + ServiceIdentifier.fromLibsignal(serviceId), new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); - } catch (StaleDevicesException e) { + } catch (final StaleDevicesException e) { accountStaleDevices.add( - new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices()))); + new AccountStaleDevices(ServiceIdentifier.fromLibsignal(serviceId), new StaleDevices(e.getStaleDevices()))); } }); + if (!accountMismatchedDevices.isEmpty()) { return Response .status(409) @@ -670,39 +676,30 @@ public Response sendMultiRecipientMessage( } try { - final byte[] sharedMrmKey = messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage); - - CompletableFuture.allOf( - recipients.values().stream() - .flatMap(recipientData -> { - final Counter sentMessageCounter = Metrics.counter(SENT_MESSAGE_COUNTER_NAME, Tags.of( - UserAgentTagUtil.getPlatformTag(userAgent), - Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_MULTI), - Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)), - Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED), - Tag.of(AUTH_TYPE_TAG_NAME, authType), - Tag.of(IDENTITY_TYPE_TAG_NAME, recipientData.serviceIdentifier().identityType().name()))); - - validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipientData.recipient()), true, userAgent); - - return recipientData.deviceIdToRegistrationId().keySet().stream().map( - deviceId -> CompletableFuture.runAsync( - () -> { - final Account destinationAccount = recipientData.account(); - final byte[] payload = multiRecipientMessage.messageForRecipient(recipientData.recipient()); - - // we asserted this must exist in validateCompleteDeviceList - final Device destinationDevice = destinationAccount.getDevice(deviceId).orElseThrow(); - - sentMessageCounter.increment(); - sendCommonPayloadMessage( - destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, - online, isStory, isUrgent, payload, sharedMrmKey); - }, - multiRecipientMessageExecutor)); - }) - .toArray(CompletableFuture[]::new)) - .get(); + messageSender.sendMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, timestamp, isStory, online, isUrgent).get(); + + multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> { + if (!resolvedRecipients.containsKey(recipient)) { + // We skipped sending to this recipient because we're sending a story and couldn't resolve the recipient to + // an existing account; don't increment the counter for this recipient. + return; + } + + final String identityType = switch (serviceId) { + case ServiceId.Aci ignored -> "ACI"; + case ServiceId.Pni ignored -> "PNI"; + default -> "unknown"; + }; + + Metrics.counter(SENT_MESSAGE_COUNTER_NAME, Tags.of( + UserAgentTagUtil.getPlatformTag(userAgent), + Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_MULTI), + Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)), + Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED), + Tag.of(AUTH_TYPE_TAG_NAME, authType), + Tag.of(IDENTITY_TYPE_TAG_NAME, identityType))) + .increment(recipient.getDevices().length); + }); } catch (InterruptedException e) { logger.error("interrupted while delivering multi-recipient messages", e); throw new InternalServerErrorException("interrupted during delivery"); @@ -729,29 +726,21 @@ private void checkGroupSendToken( private void checkAccessKeys( final @NotNull CombinedUnidentifiedSenderAccessKeys accessKeys, - final Collection destinations) { - final int keyLength = UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH; + final SealedSenderMultiRecipientMessage multiRecipientMessage, + final Map resolvedRecipients) { + + if (multiRecipientMessage.getRecipients().keySet().stream() + .anyMatch(serviceId -> serviceId instanceof ServiceId.Pni)) { - if (destinations.stream() - .anyMatch(destination -> IdentityType.PNI.equals(destination.serviceIdentifier.identityType()))) { throw new WebApplicationException("Multi-recipient messages must be addressed to ACI service IDs", Status.UNAUTHORIZED); } - final byte[] combinedUnidentifiedAccessKeys = destinations.stream() - .map(MultiRecipientDeliveryData::account) - .filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess)) - .map(account -> - account.getUnidentifiedAccessKey() - .filter(b -> b.length == keyLength) - .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED))) - .reduce(new byte[keyLength], - (a, b) -> { - final byte[] xor = new byte[keyLength]; - IntStream.range(0, keyLength).forEach(i -> xor[i] = (byte) (a[i] ^ b[i])); - return xor; - }); - if (!MessageDigest.isEqual(combinedUnidentifiedAccessKeys, accessKeys.getAccessKeys())) { + try { + if (!UnidentifiedAccessUtil.checkUnidentifiedAccess(resolvedRecipients.values(), accessKeys.getAccessKeys())) { + throw new WebApplicationException(Status.UNAUTHORIZED); + } + } catch (final IllegalArgumentException ignored) { throw new WebApplicationException(Status.UNAUTHORIZED); } } @@ -912,65 +901,6 @@ public Response reportSpamMessage( .build(); } - private void sendIndividualMessage( - Optional source, - Account destinationAccount, - Device destinationDevice, - ServiceIdentifier destinationIdentifier, - long timestamp, - boolean online, - boolean story, - boolean urgent, - IncomingMessage incomingMessage, - String userAgentString, - Optional spamReportToken) { - - final Envelope envelope; - - try { - final Account sourceAccount = source.map(AuthenticatedDevice::getAccount).orElse(null); - final Byte sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null); - envelope = incomingMessage.toEnvelope( - destinationIdentifier, - sourceAccount, - sourceDeviceId, - timestamp == 0 ? System.currentTimeMillis() : timestamp, - story, - urgent, - spamReportToken.orElse(null)); - } catch (final IllegalArgumentException e) { - logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString); - throw new BadRequestException(e); - } - - messageSender.sendMessage(destinationAccount, destinationDevice, envelope, online); - } - - private void sendCommonPayloadMessage(Account destinationAccount, - Device destinationDevice, - ServiceIdentifier serviceIdentifier, - long timestamp, - boolean online, - boolean story, - boolean urgent, - byte[] payload, - byte[] sharedMrmKey) { - - final Envelope.Builder messageBuilder = Envelope.newBuilder(); - final long serverTimestamp = System.currentTimeMillis(); - - messageBuilder - .setType(Type.UNIDENTIFIED_SENDER) - .setClientTimestamp(timestamp == 0 ? serverTimestamp : timestamp) - .setServerTimestamp(serverTimestamp) - .setStory(story) - .setUrgent(urgent) - .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) - .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)); - - messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); - } - private void checkMessageRateLimit(AuthenticatedDevice source, Account destination, String userAgent) throws RateLimitExceededException { final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber()); @@ -1020,15 +950,4 @@ private void validateEnvelopeType(final int type, final String userAgent) { throw new BadRequestException("reserved envelope type"); } } - - public static Optional getMessageContent(IncomingMessage message) { - if (StringUtils.isEmpty(message.content())) return Optional.empty(); - - try { - return Optional.of(Base64.getDecoder().decode(message.content())); - } catch (IllegalArgumentException e) { - logger.debug("Bad B64", e); - return Optional.empty(); - } - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java index 5e99a5a09..b18b08f35 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java @@ -55,11 +55,7 @@ public SealedSenderMultiRecipientMessage readFrom(Class message.messageSizeForRecipient(r) > MAX_MESSAGE_SIZE)) { - throw new BadRequestException("message payload too large"); - } + RECIPIENT_COUNT_DISTRIBUTION.record(message.getRecipients().size()); return message; } catch (InvalidMessageException | InvalidVersionException e) { throw new BadRequestException(e); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java index 3c87479b3..5cd249a3a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java @@ -9,9 +9,14 @@ import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.Metrics; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; +import org.whispersystems.textsecuregcm.util.Util; /** * A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages, @@ -42,26 +47,82 @@ public MessageSender(final MessagesManager messagesManager, final PushNotificati this.pushNotificationManager = pushNotificationManager; } - public void sendMessage(final Account account, final Device device, final Envelope message, final boolean online) { - final boolean destinationPresent = messagesManager.insert(account.getUuid(), - device.getId(), - online ? message.toBuilder().setEphemeral(true).build() : message); + /** + * Sends messages to devices associated with the given destination account. If a destination device has a valid push + * notification token and does not have an active connection to a Signal server, then this method will also send a + * push notification to that device to announce the availability of new messages. + * + * @param account the account to which to send messages + * @param messagesByDeviceId a map of device IDs to message payloads + */ + public void sendMessages(final Account account, final Map messagesByDeviceId) { + messagesManager.insert(account.getIdentifier(IdentityType.ACI), messagesByDeviceId) + .forEach((deviceId, destinationPresent) -> { + final Envelope message = messagesByDeviceId.get(deviceId); - if (!destinationPresent && !online) { - try { - pushNotificationManager.sendNewMessageNotification(account, device.getId(), message.getUrgent()); - } catch (final NotPushRegisteredException ignored) { - } - } + if (!destinationPresent && !message.getEphemeral()) { + try { + pushNotificationManager.sendNewMessageNotification(account, deviceId, message.getUrgent()); + } catch (final NotPushRegisteredException ignored) { + } + } + + Metrics.counter(SEND_COUNTER_NAME, + CHANNEL_TAG_NAME, account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"), + EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()), + CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent), + URGENT_TAG_NAME, String.valueOf(message.getUrgent()), + STORY_TAG_NAME, String.valueOf(message.getStory()), + SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId())) + .increment(); + }); + } + + /** + * Sends messages to a group of recipients. If a destination device has a valid push notification token and does not + * have an active connection to a Signal server, then this method will also send a push notification to that device to + * announce the availability of new messages. + * + * @param multiRecipientMessage the multi-recipient message to send to the given recipients + * @param resolvedRecipients a map of recipients to resolved Signal accounts + * @param clientTimestamp the time at which the sender reports the message was sent + * @param isStory {@code true} if the message is a story or {@code false otherwise} + * @param isEphemeral {@code true} if the message should only be delivered to devices with active connections or + * {@code false otherwise} + * @param isUrgent {@code true} if the message is urgent or {@code false otherwise} + * + * @return a future that completes when all messages have been inserted into delivery queues + */ + public CompletableFuture sendMultiRecipientMessage(final SealedSenderMultiRecipientMessage multiRecipientMessage, + final Map resolvedRecipients, + final long clientTimestamp, + final boolean isStory, + final boolean isEphemeral, + final boolean isUrgent) { + + return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, + isStory, isEphemeral, isUrgent) + .thenAccept(clientPresenceByAccountAndDevice -> + clientPresenceByAccountAndDevice.forEach((account, clientPresenceByDeviceId) -> + clientPresenceByDeviceId.forEach((deviceId, clientPresent) -> { + if (!clientPresent && !isEphemeral) { + try { + pushNotificationManager.sendNewMessageNotification(account, deviceId, isUrgent); + } catch (final NotPushRegisteredException ignored) { + } + } - Metrics.counter(SEND_COUNTER_NAME, - CHANNEL_TAG_NAME, getDeliveryChannelName(device), - EPHEMERAL_TAG_NAME, String.valueOf(online), - CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent), - URGENT_TAG_NAME, String.valueOf(message.getUrgent()), - STORY_TAG_NAME, String.valueOf(message.getStory()), - SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId())) - .increment(); + Metrics.counter(SEND_COUNTER_NAME, + CHANNEL_TAG_NAME, + account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"), + EPHEMERAL_TAG_NAME, String.valueOf(isEphemeral), + CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent), + URGENT_TAG_NAME, String.valueOf(isUrgent), + STORY_TAG_NAME, String.valueOf(isStory), + SEALED_SENDER_TAG_NAME, String.valueOf(true)) + .increment(); + }))) + .thenRun(Util.NOOP); } @VisibleForTesting diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java index 74f623518..09da97e82 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java @@ -8,6 +8,7 @@ import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -43,21 +44,21 @@ public void sendReceipt(ServiceIdentifier sourceIdentifier, byte sourceDeviceId, try { accountManager.getByAccountIdentifier(destinationIdentifier.uuid()).ifPresentOrElse( destinationAccount -> { - final Envelope.Builder message = Envelope.newBuilder() + final Envelope message = Envelope.newBuilder() .setServerTimestamp(System.currentTimeMillis()) .setSourceServiceId(sourceIdentifier.toServiceIdentifierString()) .setSourceDevice(sourceDeviceId) .setDestinationServiceId(destinationIdentifier.toServiceIdentifierString()) .setClientTimestamp(messageId) .setType(Envelope.Type.SERVER_DELIVERY_RECEIPT) - .setUrgent(false); + .setUrgent(false) + .build(); - for (final Device destinationDevice : destinationAccount.getDevices()) { - try { - messageSender.sendMessage(destinationAccount, destinationDevice, message.build(), false); - } catch (final Exception e) { - logger.warn("Could not send delivery receipt", e); - } + try { + messageSender.sendMessages(destinationAccount, destinationAccount.getDevices().stream() + .collect(Collectors.toMap(Device::getId, ignored -> message))); + } catch (final Exception e) { + logger.warn("Could not send delivery receipt", e); } }, () -> logger.info("No longer registered: {}", destinationIdentifier) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index 9119b88fe..d0e7da979 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -4,8 +4,8 @@ */ package org.whispersystems.textsecuregcm.storage; -import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; +import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Optional; @@ -13,10 +13,10 @@ import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.commons.lang3.ObjectUtils; +import org.apache.commons.lang3.StringUtils; import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; @@ -115,40 +115,39 @@ private void validateDeviceMessages(final Account account, private void sendDeviceMessages(final Account account, final List deviceMessages) { try { - deviceMessages.forEach(message -> - sendMessageToSelf(account, account.getDevice(message.destinationDeviceId()), message)); - } catch (RuntimeException e) { + final long serverTimestamp = System.currentTimeMillis(); + + messageSender.sendMessages(account, deviceMessages.stream() + .filter(message -> getMessageContent(message).isPresent()) + .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder() + .setType(Envelope.Type.forNumber(message.type())) + .setClientTimestamp(serverTimestamp) + .setServerTimestamp(serverTimestamp) + .setDestinationServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString()) + .setContent(ByteString.copyFrom(getMessageContent(message).orElseThrow())) + .setSourceServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString()) + .setSourceDevice(Device.PRIMARY_ID) + .setUpdatedPni(account.getPhoneNumberIdentifier().toString()) + .setUrgent(true) + .setEphemeral(false) + .build()))); + } catch (final RuntimeException e) { logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e); throw e; } } - @VisibleForTesting - void sendMessageToSelf( - Account sourceAndDestinationAccount, Optional destinationDevice, IncomingMessage message) { - Optional contents = MessageController.getMessageContent(message); - if (contents.isEmpty()) { - logger.debug("empty message contents sending to self, ignoring"); - return; - } else if (destinationDevice.isEmpty()) { - logger.debug("destination device not present"); - return; + private static Optional getMessageContent(final IncomingMessage message) { + if (StringUtils.isEmpty(message.content())) { + logger.warn("Message has no content"); + return Optional.empty(); } - final long serverTimestamp = System.currentTimeMillis(); - final Envelope envelope = Envelope.newBuilder() - .setType(Envelope.Type.forNumber(message.type())) - .setClientTimestamp(serverTimestamp) - .setServerTimestamp(serverTimestamp) - .setDestinationServiceId( - new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) - .setContent(ByteString.copyFrom(contents.get())) - .setSourceServiceId(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) - .setSourceDevice(Device.PRIMARY_ID) - .setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString()) - .setUrgent(true) - .build(); - - messageSender.sendMessage(sourceAndDestinationAccount, destinationDevice.get(), envelope, false); + try { + return Optional.of(Base64.getDecoder().decode(message.content())); + } catch (final IllegalArgumentException e) { + logger.warn("Failed to parse message content", e); + return Optional.empty(); + } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index a612e3c21..fb78e0cd6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -203,22 +203,28 @@ public MessagesCache(final FaultTolerantRedisClusterClient redisCluster, this.unlockQueueScript = unlockQueueScript; } - public boolean insert(final UUID messageGuid, + public CompletableFuture insert(final UUID messageGuid, final UUID destinationAccountIdentifier, final byte destinationDeviceId, final MessageProtos.Envelope message) { final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(messageGuid.toString()).build(); - return insertTimer.record(() -> insertScript.execute(destinationAccountIdentifier, destinationDeviceId, messageWithGuid)); + final Timer.Sample sample = Timer.start(); + + return insertScript.executeAsync(destinationAccountIdentifier, destinationDeviceId, messageWithGuid) + .whenComplete((ignored, throwable) -> sample.stop(insertTimer)); } - public byte[] insertSharedMultiRecipientMessagePayload( + public CompletableFuture insertSharedMultiRecipientMessagePayload( final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { - return insertSharedMrmPayloadTimer.record(() -> { - final byte[] sharedMrmKey = getSharedMrmKey(UUID.randomUUID()); - insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage); - return sharedMrmKey; - }); + + final Timer.Sample sample = Timer.start(); + + final byte[] sharedMrmKey = getSharedMrmKey(UUID.randomUUID()); + + return insertMrmScript.executeAsync(sharedMrmKey, sealedSenderMultiRecipientMessage) + .thenApply(ignored -> sharedMrmKey) + .whenComplete((ignored, throwable) -> sample.stop(insertSharedMrmPayloadTimer)); } public CompletableFuture> remove(final UUID destinationUuid, final byte destinationDevice, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java index 784d2a264..7e302900d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScript.java @@ -12,6 +12,7 @@ import java.util.Arrays; import java.util.List; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.push.ClientEvent; import org.whispersystems.textsecuregcm.push.NewMessageAvailableEvent; @@ -44,7 +45,7 @@ class MessagesCacheInsertScript { * @return {@code true} if the destination device had a registered "presence"/event subscriber or {@code false} * otherwise */ - boolean execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) { + CompletableFuture executeAsync(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) { assert envelope.hasServerGuid(); assert envelope.hasServerTimestamp(); @@ -62,6 +63,7 @@ boolean execute(final UUID destinationUuid, final byte destinationDevice, final NEW_MESSAGE_EVENT_BYTES // eventPayload )); - return (boolean) insertScript.executeBinary(keys, args); + return insertScript.executeBinaryAsync(keys, args) + .thenApply(result -> (boolean) result); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java index ce05af209..b8da00bd7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.java @@ -9,9 +9,11 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; +import org.whispersystems.textsecuregcm.util.Util; /** * Inserts the shared multi-recipient message payload into the cache. The list of recipients and views will be set as @@ -31,7 +33,7 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript { ScriptOutputType.INTEGER); } - void execute(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) { + CompletableFuture executeAsync(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) { final List keys = List.of( sharedMrmKey // sharedMrmKey ); @@ -47,6 +49,7 @@ void execute(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage } }); - script.executeBinary(keys, args); + return script.executeBinaryAsync(keys, args) + .thenRun(Util.NOOP); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 838e22a21..c272fc895 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -6,18 +6,23 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; +import com.google.protobuf.ByteString; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; +import java.time.Clock; import java.time.Instant; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; +import java.util.stream.IntStream; import javax.annotation.Nullable; import org.reactivestreams.Publisher; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; @@ -25,6 +30,8 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.util.Pair; import reactor.core.observability.micrometer.Micrometer; @@ -48,41 +55,120 @@ public class MessagesManager { private final MessagesCache messagesCache; private final ReportMessageManager reportMessageManager; private final ExecutorService messageDeletionExecutor; + private final Clock clock; public MessagesManager( final MessagesDynamoDb messagesDynamoDb, final MessagesCache messagesCache, final ReportMessageManager reportMessageManager, - final ExecutorService messageDeletionExecutor) { + final ExecutorService messageDeletionExecutor, + final Clock clock) { + this.messagesDynamoDb = messagesDynamoDb; this.messagesCache = messagesCache; this.reportMessageManager = reportMessageManager; this.messageDeletionExecutor = messageDeletionExecutor; + this.clock = clock; } /** - * Inserts a message into a target device's message queue and notifies registered listeners that a new message is - * available. + * Inserts messages into the message queues for devices associated with the identified account. * - * @param destinationUuid the account identifier for the destination queue - * @param destinationDeviceId the device ID for the destination queue - * @param message the message to insert into the queue + * @param accountIdentifier the account identifier for the destination queue + * @param messagesByDeviceId a map of device IDs to messages * - * @return {@code true} if the destination device is "present" (i.e. has an active event listener) or {@code false} - * otherwise + * @return a map of device IDs to a device's presence state (i.e. if the device has an active event listener) * * @see org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager */ - public boolean insert(final UUID destinationUuid, final byte destinationDeviceId, final Envelope message) { - final UUID messageGuid = UUID.randomUUID(); - - final boolean destinationPresent = messagesCache.insert(messageGuid, destinationUuid, destinationDeviceId, message); + public Map insert(final UUID accountIdentifier, final Map messagesByDeviceId) { + return insertAsync(accountIdentifier, messagesByDeviceId).join(); + } - if (message.hasSourceServiceId() && !destinationUuid.toString().equals(message.getSourceServiceId())) { - reportMessageManager.store(message.getSourceServiceId(), messageGuid); - } + private CompletableFuture> insertAsync(final UUID accountIdentifier, final Map messagesByDeviceId) { + final Map devicePresenceById = new ConcurrentHashMap<>(); + + return CompletableFuture.allOf(messagesByDeviceId.entrySet().stream() + .map(deviceIdAndMessage -> { + final byte deviceId = deviceIdAndMessage.getKey(); + final Envelope message = deviceIdAndMessage.getValue(); + final UUID messageGuid = UUID.randomUUID(); + + return messagesCache.insert(messageGuid, accountIdentifier, deviceId, message) + .thenAccept(present -> { + if (message.hasSourceServiceId() && !accountIdentifier.toString() + .equals(message.getSourceServiceId())) { + // Note that this is an asynchronous, best-effort, fire-and-forget operation + reportMessageManager.store(message.getSourceServiceId(), messageGuid); + } + + devicePresenceById.put(deviceId, present); + }); + }) + .toArray(CompletableFuture[]::new)) + .thenApply(ignored -> devicePresenceById); + } - return destinationPresent; + /** + * Inserts messages into the message queues for devices associated with the identified accounts. + * + * @param multiRecipientMessage the multi-recipient message to insert into destination queues + * @param resolvedRecipients a map of multi-recipient message {@code Recipient} entities to their corresponding + * Signal accounts; messages will not be delivered to unresolved recipients + * @param clientTimestamp the timestamp for the message as reported by the sending party + * @param isStory {@code true} if the given message is a story or {@code false} otherwise + * @param isEphemeral {@code true} if the given message should only be delivered to devices with active + * connections to a Signal server or {@code false} otherwise + * @param isUrgent {@code true} if the given message is urgent or {@code false} otherwise + * + * @return a map of accounts to maps of device IDs to a device's presence state (i.e. if the device has an active + * event listener) + * + * @see org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager + */ + public CompletableFuture>> insertMultiRecipientMessage( + final SealedSenderMultiRecipientMessage multiRecipientMessage, + final Map resolvedRecipients, + final long clientTimestamp, + final boolean isStory, + final boolean isEphemeral, + final boolean isUrgent) { + + final long serverTimestamp = clock.millis(); + + return insertSharedMultiRecipientMessagePayload(multiRecipientMessage) + .thenCompose(sharedMrmKey -> { + final Envelope prototypeMessage = Envelope.newBuilder() + .setType(Envelope.Type.UNIDENTIFIED_SENDER) + .setClientTimestamp(clientTimestamp == 0 ? serverTimestamp : clientTimestamp) + .setServerTimestamp(serverTimestamp) + .setStory(isStory) + .setEphemeral(isEphemeral) + .setUrgent(isUrgent) + .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) + .build(); + + final Map> clientPresenceByAccountAndDevice = new ConcurrentHashMap<>(); + + return CompletableFuture.allOf(multiRecipientMessage.getRecipients().entrySet().stream() + .filter(serviceIdAndRecipient -> resolvedRecipients.containsKey(serviceIdAndRecipient.getValue())) + .map(serviceIdAndRecipient -> { + final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey()); + final SealedSenderMultiRecipientMessage.Recipient recipient = serviceIdAndRecipient.getValue(); + final byte[] devices = recipient.getDevices(); + + return insertAsync(resolvedRecipients.get(recipient).getIdentifier(IdentityType.ACI), + IntStream.range(0, devices.length).mapToObj(i -> devices[i]) + .collect(Collectors.toMap(deviceId -> deviceId, deviceId -> prototypeMessage.toBuilder() + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .build()))) + .thenAccept(clientPresenceByDeviceId -> + clientPresenceByAccountAndDevice.put(resolvedRecipients.get(recipient), + clientPresenceByDeviceId)); + }) + .toArray(CompletableFuture[]::new)) + .thenApply(ignored -> clientPresenceByAccountAndDevice); + }); } public CompletableFuture mayHavePersistedMessages(final UUID destinationUuid, final Device destinationDevice) { @@ -217,7 +303,7 @@ public CompletableFuture> getEarliestUndeliveredTimestampForDe * @return a key where the shared data is stored * @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript */ - public byte[] insertSharedMultiRecipientMessagePayload( + private CompletableFuture insertSharedMultiRecipientMessagePayload( final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { return messagesCache.insertSharedMultiRecipientMessagePayload(sealedSenderMultiRecipientMessage); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDb.java index 49111f991..b5185201f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDb.java @@ -3,6 +3,8 @@ import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import org.whispersystems.textsecuregcm.util.AttributeValues; +import org.whispersystems.textsecuregcm.util.Util; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; @@ -11,6 +13,7 @@ import java.time.Duration; import java.time.Instant; import java.util.Map; +import java.util.concurrent.CompletableFuture; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; @@ -20,6 +23,7 @@ public class ReportMessageDynamoDb { static final String ATTR_TTL = "E"; private final DynamoDbClient db; + private final DynamoDbAsyncClient dynamoDbAsyncClient; private final String tableName; private final Duration ttl; @@ -30,20 +34,26 @@ public class ReportMessageDynamoDb { .distributionStatisticExpiry(Duration.ofDays(1)) .register(Metrics.globalRegistry); - public ReportMessageDynamoDb(final DynamoDbClient dynamoDB, final String tableName, final Duration ttl) { + public ReportMessageDynamoDb(final DynamoDbClient dynamoDB, + final DynamoDbAsyncClient dynamoDbAsyncClient, + final String tableName, + final Duration ttl) { + this.db = dynamoDB; + this.dynamoDbAsyncClient = dynamoDbAsyncClient; this.tableName = tableName; this.ttl = ttl; } - public void store(byte[] hash) { - db.putItem(PutItemRequest.builder() + public CompletableFuture store(byte[] hash) { + return dynamoDbAsyncClient.putItem(PutItemRequest.builder() .tableName(tableName) .item(Map.of( KEY_HASH, AttributeValues.fromByteArray(hash), ATTR_TTL, AttributeValues.fromLong(Instant.now().plus(ttl).getEpochSecond()) )) - .build()); + .build()) + .thenRun(Util.NOOP); } public boolean remove(byte[] hash) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java index 0123a2617..9e2cadcc9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java @@ -54,11 +54,8 @@ public void addListener(final ReportedMessageListener listener) { } public void store(String sourceAci, UUID messageGuid) { - try { - Objects.requireNonNull(sourceAci); - - reportMessageDynamoDb.store(hash(messageGuid, sourceAci)); + reportMessageDynamoDb.store(hash(messageGuid, Objects.requireNonNull(sourceAci))); } catch (final Exception e) { logger.warn("Failed to store hash", e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java index 3e6524471..4ba72f4b9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java @@ -22,13 +22,15 @@ public class DestinationDeviceValidator { /** * @see #validateRegistrationIds(Account, Stream, boolean) */ - public static void validateRegistrationIds(final Account account, final Collection messages, - Function getDeviceId, Function getRegistrationId, boolean usePhoneNumberIdentity) - throws StaleDevicesException { + public static void validateRegistrationIds(final Account account, + final Collection messages, + Function getDeviceId, + Function getRegistrationId, + boolean usePhoneNumberIdentity) throws StaleDevicesException { + validateRegistrationIds(account, messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))), usePhoneNumberIdentity); - } /** diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index b25f8abe1..07cabb86e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -217,13 +217,13 @@ static CommandDependencies build( MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC()); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); - ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, + ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, dynamoDbAsyncClient, configuration.getDynamoDbTables().getReportMessage().getTableName(), configuration.getReportMessageConfiguration().getReportTtl()); ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster, configuration.getReportMessageConfiguration().getCounterTtl()); MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, - reportMessageManager, messageDeletionExecutor); + reportMessageManager, messageDeletionExecutor, Clock.systemUTC()); AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient, configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName()); ClientPublicKeysManager clientPublicKeysManager = diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index f570efc1d..1ccdf2334 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -9,29 +9,27 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.collection.IsEmptyCollection.empty; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; -import static org.whispersystems.textsecuregcm.util.MockUtils.exactly; import com.fasterxml.jackson.core.JsonProcessingException; -import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.ByteString; import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -42,14 +40,11 @@ import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; -import java.io.ByteArrayInputStream; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.time.Duration; import java.time.Instant; +import java.time.LocalDate; +import java.time.ZoneOffset; import java.time.temporal.ChronoUnit; -import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; @@ -60,10 +55,6 @@ import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.IntStream; import java.util.stream.Stream; import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; @@ -78,17 +69,12 @@ import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; -import org.junitpioneer.jupiter.cartesian.ArgumentSets; -import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.mockito.ArgumentCaptor; -import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.zkgroup.ServerSecretParams; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration; -import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; -import org.whispersystems.textsecuregcm.entities.AccountStaleDevices; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos; @@ -96,10 +82,10 @@ import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; -import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; @@ -125,12 +111,14 @@ import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper; +import org.whispersystems.textsecuregcm.tests.util.TestRecipient; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.TestClock; -import org.whispersystems.textsecuregcm.util.UUIDUtil; +import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.websocket.WebsocketHeaders; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; @@ -186,7 +174,6 @@ class MessageControllerTest { private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); private static final PushNotificationScheduler pushNotificationScheduler = mock(PushNotificationScheduler.class); private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); - private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService(); private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); @SuppressWarnings("unchecked") @@ -197,6 +184,8 @@ class MessageControllerTest { private static final TestClock clock = TestClock.now(); + private static final Instant START_OF_DAY = LocalDate.now(clock).atStartOfDay().toInstant(ZoneOffset.UTC); + private static final ResourceExtension resources = ResourceExtension.builder() .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProvider(AuthHelper.getAuthFilter()) @@ -207,7 +196,7 @@ class MessageControllerTest { .addResource( new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager, messagesManager, phoneNumberIdentifiers, pushNotificationManager, pushNotificationScheduler, - reportMessageManager, multiRecipientMessageExecutor, messageDeliveryScheduler, mock(ClientReleaseManager.class), + reportMessageManager, messageDeliveryScheduler, mock(ClientReleaseManager.class), dynamicConfigurationManager, serverSecretParams, SpamChecker.noop(), new MessageMetrics(), mock(MessageDeliveryLoopMonitor.class), clock)) .build(); @@ -216,6 +205,9 @@ reportMessageManager, multiRecipientMessageExecutor, messageDeliveryScheduler, m void setup() { reset(pushNotificationScheduler); + when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean())) + .thenReturn(CompletableFuture.completedFuture(null)); + final List singleDeviceList = List.of( generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, true) ); @@ -311,12 +303,15 @@ void testSingleDeviceCurrent() throws Exception { assertThat("Good Response", response.getStatus(), is(equalTo(200))); - ArgumentCaptor captor = ArgumentCaptor.forClass(Envelope.class); - verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + verify(messageSender).sendMessages(any(), captor.capture()); - assertTrue(captor.getValue().hasSourceServiceId()); - assertTrue(captor.getValue().hasSourceDevice()); - assertTrue(captor.getValue().getUrgent()); + assertEquals(1, captor.getValue().size()); + final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); + + assertTrue(message.hasSourceServiceId()); + assertTrue(message.hasSourceDevice()); + assertTrue(message.getUrgent()); } } @@ -353,12 +348,15 @@ void testSingleDeviceCurrentNotUrgent() throws Exception { assertThat("Good Response", response.getStatus(), is(equalTo(200))); - ArgumentCaptor captor = ArgumentCaptor.forClass(Envelope.class); - verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + verify(messageSender).sendMessages(any(), captor.capture()); + + assertEquals(1, captor.getValue().size()); + final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); - assertTrue(captor.getValue().hasSourceServiceId()); - assertTrue(captor.getValue().hasSourceDevice()); - assertFalse(captor.getValue().getUrgent()); + assertTrue(message.hasSourceServiceId()); + assertTrue(message.hasSourceDevice()); + assertFalse(message.getUrgent()); } } @@ -375,11 +373,14 @@ void testSingleDeviceCurrentByPni() throws Exception { assertThat("Good Response", response.getStatus(), is(equalTo(200))); - ArgumentCaptor captor = ArgumentCaptor.forClass(Envelope.class); - verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + verify(messageSender).sendMessages(any(), captor.capture()); - assertTrue(captor.getValue().hasSourceServiceId()); - assertTrue(captor.getValue().hasSourceDevice()); + assertEquals(1, captor.getValue().size()); + final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); + + assertTrue(message.hasSourceServiceId()); + assertTrue(message.hasSourceDevice()); } } @@ -410,11 +411,14 @@ void testSingleDeviceCurrentUnidentified() throws Exception { assertThat("Good Response", response.getStatus(), is(equalTo(200))); - ArgumentCaptor captor = ArgumentCaptor.forClass(Envelope.class); - verify(messageSender, times(1)).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + verify(messageSender).sendMessages(any(), captor.capture()); + + assertEquals(1, captor.getValue().size()); + final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); - assertFalse(captor.getValue().hasSourceServiceId()); - assertFalse(captor.getValue().hasSourceDevice()); + assertFalse(message.hasSourceServiceId()); + assertFalse(message.hasSourceDevice()); } } @@ -446,9 +450,14 @@ void testSingleDeviceCurrentGroupSendEndorsement( assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse))); if (expectedResponse == 200) { - verify(messageSender).sendMessage( - any(Account.class), any(Device.class), argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()), - eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + verify(messageSender).sendMessages(any(), captor.capture()); + + assertEquals(1, captor.getValue().size()); + final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); + + assertFalse(message.hasSourceServiceId()); + assertFalse(message.hasSourceDevice()); } else { verifyNoMoreInteractions(messageSender); } @@ -607,12 +616,16 @@ void testMultiDevice() throws Exception { assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); - final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(Envelope.class); + @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = + ArgumentCaptor.forClass(Map.class); + + verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture()); - verify(messageSender, times(3)) - .sendMessage(any(Account.class), any(Device.class), envelopeCaptor.capture(), eq(false)); + assertEquals(3, envelopeCaptor.getValue().size()); - envelopeCaptor.getAllValues().forEach(envelope -> assertTrue(envelope.getUrgent())); + envelopeCaptor.getValue().values().forEach(envelope -> { + assertTrue(envelope.getUrgent()); + }); } } @@ -629,12 +642,16 @@ void testMultiDeviceNotUrgent() throws Exception { assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); - final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(Envelope.class); + @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = + ArgumentCaptor.forClass(Map.class); + + verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture()); - verify(messageSender, times(3)) - .sendMessage(any(Account.class), any(Device.class), envelopeCaptor.capture(), eq(false)); + assertEquals(3, envelopeCaptor.getValue().size()); - envelopeCaptor.getAllValues().forEach(envelope -> assertFalse(envelope.getUrgent())); + envelopeCaptor.getValue().values().forEach(envelope -> { + assertFalse(envelope.getUrgent()); + }); } } @@ -651,8 +668,8 @@ void testMultiDeviceByPni() throws Exception { assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); - verify(messageSender, times(3)) - .sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false)); + verify(messageSender).sendMessages(any(Account.class), + argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3)); } } @@ -1084,7 +1101,7 @@ private static Stream testReportMessageByAciWithNullSpamReportToken() } @Test - void testValidateContentLength() throws Exception { + void testValidateContentLength() { final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1); final byte[] contentBytes = new byte[contentLength]; Arrays.fill(contentBytes, (byte) 1); @@ -1101,8 +1118,7 @@ void testValidateContentLength() throws Exception { assertThat("Bad response", response.getStatus(), is(equalTo(413))); - verify(messageSender, never()).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), - anyBoolean()); + verify(messageSender, never()).sendMessages(any(), any()); } } @@ -1120,12 +1136,10 @@ void testValidateEnvelopeType(String payloadFilename, boolean expectOk) throws E if (expectOk) { assertEquals(200, response.getStatus()); - - final ArgumentCaptor captor = ArgumentCaptor.forClass(Envelope.class); - verify(messageSender).sendMessage(any(Account.class), any(Device.class), captor.capture(), eq(false)); + verify(messageSender).sendMessages(any(), any()); } else { assertEquals(400, response.getStatus()); - verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); + verify(messageSender, never()).sendMessages(any(), any()); } } } @@ -1137,740 +1151,427 @@ private static Stream testValidateEnvelopeType() { ); } - private record Recipient(ServiceIdentifier uuid, - Byte[] deviceId, - Integer[] registrationId, - byte[] perRecipientKeyMaterial) { - - Recipient(ServiceIdentifier uuid, - byte deviceId, - int registrationId, - byte[] perRecipientKeyMaterial) { - this(uuid, new Byte[]{deviceId}, new Integer[]{registrationId}, perRecipientKeyMaterial); - } - } - - private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, - final boolean useExplicitIdentifier) { - if (useExplicitIdentifier) { - bb.put(r.uuid().toFixedWidthByteArray()); - } else { - bb.put(UUIDUtil.toBytes(r.uuid().uuid())); - } - - assert (r.deviceId.length == r.registrationId.length); - - for (int i = 0; i < r.deviceId.length; i++) { - final int hasMore = i == r.deviceId.length - 1 ? 0 : 0x8000; - bb.put(r.deviceId()[i]); // device id (1 byte) - bb.putShort((short) (r.registrationId()[i] | hasMore)); // registration id (2 bytes) - } - bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes) - } - - private static InputStream initializeMultiPayload(final List recipients, final byte[] buffer, final boolean explicitIdentifiers) { - return initializeMultiPayload(recipients, buffer, explicitIdentifiers, 39); - } - - private static InputStream initializeMultiPayload(final List recipients, final byte[] buffer, final boolean explicitIdentifiers, final int payloadSize) { - // initialize a binary payload according to our wire format - ByteBuffer bb = ByteBuffer.wrap(buffer); - bb.order(ByteOrder.BIG_ENDIAN); - - // first write the header - bb.put(explicitIdentifiers ? (byte) 0x23 : (byte) 0x22); // version byte - - // count varint - writeVarint(bb, recipients.size()); - - recipients.forEach(recipient -> writeMultiPayloadRecipient(bb, recipient, explicitIdentifiers)); - - // now write the actual message body (empty for now) - assert(payloadSize >= 32); - writeVarint(bb, payloadSize); - bb.put(new byte[payloadSize]); - - // return the input stream - return new ByteArrayInputStream(buffer, 0, bb.position()); - } - - private static void writeVarint(ByteBuffer bb, long n) { - while (n >= 0x80) { - bb.put ((byte) (n & 0x7F | 0x80)); - n = n >> 7; - } - bb.put((byte) (n & 0x7F)); - } - - @Test - void testManyRecipientMessage() { - - when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class))) - .thenReturn(new byte[]{1}); - - final int nRecipients = 999; - final int devicesPerRecipient = 5; - final List recipients = new ArrayList<>(); - - for (int i = 0; i < nRecipients; i++) { - final List devices = - IntStream.range(1, devicesPerRecipient + 1) - .mapToObj( - d -> generateTestDevice( - (byte) d, 100 + d, 10 * d, true)) - .collect(Collectors.toList()); - final UUID aci = new UUID(0L, i); - final UUID pni = new UUID(1L, i); - final String e164 = String.format("+1408555%04d", i); - final Account account = AccountsHelper.generateTestAccount(e164, aci, pni, devices, UNIDENTIFIED_ACCESS_BYTES); - - when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(aci))) - .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - - when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(pni))) - .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - - devices.forEach(d -> recipients.add(new Recipient(new AciServiceIdentifier(aci), d.getId(), d.getRegistrationId(), new byte[48]))); - } - - byte[] buffer = new byte[1048576]; - InputStream stream = initializeMultiPayload(recipients, buffer, true); - Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); - try (final Response response = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("story", true) - .queryParam("urgent", false) - .request() - .header(HttpHeaders.USER_AGENT, "test") - .put(entity)) { - - assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(200))); - verify(messageSender, times(nRecipients * devicesPerRecipient)).sendMessage(any(), any(), any(), eq(true)); - } - } - - // see testMultiRecipientMessageNoPni and testMultiRecipientMessagePni below for actual invocations - private void testMultiRecipientMessage( - Map> destinations, boolean authorize, boolean isStory, boolean urgent, - boolean explicitIdentifier, int expectedStatus, int expectedMessagesSent) throws Exception { - - when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class))) - .thenReturn(new byte[]{1}); - - final List recipients = new ArrayList<>(); - destinations.forEach( - (serviceIdentifier, deviceToRegistrationId) -> - deviceToRegistrationId.forEach( - (deviceId, registrationId) -> - recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48])))); - - // initialize our binary payload and create an input stream - byte[] buffer = new byte[2048]; - InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier); - - // set up the entity to use in our PUT request - Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); - - // build correct or incorrect access header - final String accessHeader; - if (authorize) { - final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count(); - accessHeader = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]); - } else { - accessHeader = "BBBBBBBBBBBBBBBBBBBBBB=="; - } - - // make the PUT request - try (final Response response = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("ts", 1663798405641L) - .queryParam("story", isStory) - .queryParam("urgent", urgent) - .request() - .header(HttpHeaders.USER_AGENT, "test") - .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader) - .put(entity)) { - - assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus))); - verify(messageSender, - exactly(expectedMessagesSent)) - .sendMessage( - any(), - any(), - argThat(env -> env.getUrgent() == urgent && !env.hasSourceServiceId() && !env.hasSourceDevice()), - eq(true)); - if (expectedStatus == 200) { - SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); - assertThat(smrmr.uuids404(), is(empty())); - } - } - } - - @SafeVarargs - private static Map submap(Map map, K... keys) { - return Arrays.stream(keys).collect(Collectors.toMap(Function.identity(), map::get)); - } - - private static Map> multiRecipientTargetMap() { - return Map.of( - SINGLE_DEVICE_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), - SINGLE_DEVICE_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1), - MULTI_DEVICE_ACI_ID, - Map.of( - MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, - MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, - MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3), - MULTI_DEVICE_PNI_ID, - Map.of( - MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1, - MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2, - MULTI_DEVICE_ID3, MULTI_DEVICE_PNI_REG_ID3), - NONEXISTENT_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), - NONEXISTENT_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1) - ); - } - - private record MultiRecipientMessageTestCase( - Map> destinations, - boolean authenticated, - boolean story, - int expectedStatus, - int expectedSentMessages) { - } - - @CartesianTest - @CartesianTest.MethodFactory("testMultiRecipientMessageNoPni") - void testMultiRecipientMessageNoPni(MultiRecipientMessageTestCase testCase, boolean urgent , boolean explicitIdentifier) throws Exception { - testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, explicitIdentifier, testCase.expectedStatus(), testCase.expectedSentMessages()); - } - - @SuppressWarnings("unused") - private static ArgumentSets testMultiRecipientMessageNoPni() { - final Map> targets = multiRecipientTargetMap(); - final Map> singleDeviceAci = submap(targets, SINGLE_DEVICE_ACI_ID); - final Map> multiDeviceAci = submap(targets, MULTI_DEVICE_ACI_ID); - final Map> bothAccountsAci = - submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID); - final Map> realAndFakeAci = - submap( - targets, - SINGLE_DEVICE_ACI_ID, - MULTI_DEVICE_ACI_ID, - NONEXISTENT_ACI_ID); - - final boolean auth = true; - final boolean unauth = false; - final boolean story = true; - final boolean notStory = false; - - return ArgumentSets - .argumentsForFirstParameter( - new MultiRecipientMessageTestCase(singleDeviceAci, unauth, story, 200, 1), - new MultiRecipientMessageTestCase(multiDeviceAci, unauth, story, 200, 3), - new MultiRecipientMessageTestCase(bothAccountsAci, unauth, story, 200, 4), - new MultiRecipientMessageTestCase(realAndFakeAci, unauth, story, 200, 4), - - new MultiRecipientMessageTestCase(singleDeviceAci, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(multiDeviceAci, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(bothAccountsAci, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(realAndFakeAci, unauth, notStory, 404, 0), - - new MultiRecipientMessageTestCase(singleDeviceAci, auth, story, 200, 1), - new MultiRecipientMessageTestCase(multiDeviceAci, auth, story, 200, 3), - new MultiRecipientMessageTestCase(bothAccountsAci, auth, story, 200, 4), - new MultiRecipientMessageTestCase(realAndFakeAci, auth, story, 200, 4), - - new MultiRecipientMessageTestCase(singleDeviceAci, auth, notStory, 200, 1), - new MultiRecipientMessageTestCase(multiDeviceAci, auth, notStory, 200, 3), - new MultiRecipientMessageTestCase(bothAccountsAci, auth, notStory, 200, 4), - new MultiRecipientMessageTestCase(realAndFakeAci, auth, notStory, 404, 0)) - .argumentsForNextParameter(false, true) // urgent - .argumentsForNextParameter(false, true); // explicitIdentifiers - } - - @CartesianTest - @CartesianTest.MethodFactory("testMultiRecipientMessagePni") - void testMultiRecipientMessagePni(MultiRecipientMessageTestCase testCase, boolean urgent) throws Exception { - testMultiRecipientMessage(testCase.destinations(), testCase.authenticated(), testCase.story(), urgent, true, testCase.expectedStatus(), testCase.expectedSentMessages()); - } - - @SuppressWarnings("unused") - private static ArgumentSets testMultiRecipientMessagePni() { - final Map> targets = multiRecipientTargetMap(); - final Map> singleDevicePni = submap(targets, SINGLE_DEVICE_PNI_ID); - final Map> singleDeviceAciAndPni = submap( - targets, SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_PNI_ID); - final Map> multiDevicePni = submap(targets, MULTI_DEVICE_PNI_ID); - final Map> bothAccountsMixed = - submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_PNI_ID); - final Map> realAndFakeMixed = - submap( - targets, - SINGLE_DEVICE_PNI_ID, - MULTI_DEVICE_ACI_ID, - NONEXISTENT_PNI_ID); - - final boolean auth = true; - final boolean unauth = false; - final boolean story = true; - final boolean notStory = false; - - return ArgumentSets - .argumentsForFirstParameter( - new MultiRecipientMessageTestCase(singleDevicePni, unauth, story, 200, 1), - new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, story, 200, 2), - new MultiRecipientMessageTestCase(multiDevicePni, unauth, story, 200, 3), - new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, story, 200, 4), - new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, story, 200, 4), - - new MultiRecipientMessageTestCase(singleDevicePni, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(singleDeviceAciAndPni, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(multiDevicePni, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(bothAccountsMixed, unauth, notStory, 401, 0), - new MultiRecipientMessageTestCase(realAndFakeMixed, unauth, notStory, 404, 0), - - new MultiRecipientMessageTestCase(singleDevicePni, auth, story, 200, 1), - new MultiRecipientMessageTestCase(singleDeviceAciAndPni, auth, story, 200, 2), - new MultiRecipientMessageTestCase(multiDevicePni, auth, story, 200, 3), - new MultiRecipientMessageTestCase(bothAccountsMixed, auth, story, 200, 4), - new MultiRecipientMessageTestCase(realAndFakeMixed, auth, story, 200, 4), - - new MultiRecipientMessageTestCase(singleDevicePni, auth, notStory, 401, 0), - new MultiRecipientMessageTestCase(singleDeviceAciAndPni, auth, notStory, 401, 0), - new MultiRecipientMessageTestCase(multiDevicePni, auth, notStory, 401, 0), - new MultiRecipientMessageTestCase(bothAccountsMixed, auth, notStory, 401, 0), - new MultiRecipientMessageTestCase(realAndFakeMixed, auth, notStory, 404, 0)) - .argumentsForNextParameter(false, true); // urgent - } - - @Test - void testMultiRecipientMessageWithGroupSendEndorsements() throws Exception { - - when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class))) - .thenReturn(new byte[]{1}); - - final List recipients = List.of( - new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, new byte[48])); - - // initialize our binary payload and create an input stream - byte[] buffer = new byte[2048]; - InputStream stream = initializeMultiPayload(recipients, buffer, true); - - clock.pin(Instant.parse("2024-04-09T12:00:00.00Z")); - - try (final Response response = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("ts", 1663798405641L) - .queryParam("story", false) - .queryParam("urgent", false) - .request() - .header(HttpHeaders.USER_AGENT, "test") - .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader( - serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z"))) - .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) { - - assertThat("Unexpected response", response.getStatus(), is(equalTo(200))); - verify(messageSender, - exactly(4)) - .sendMessage( - any(), - any(), - argThat(env -> !env.hasSourceServiceId() && !env.hasSourceDevice()), - eq(true)); - SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); - assertThat(smrmr.uuids404(), is(empty())); - } - } - - @Test - void testMultiRecipientMessageWithInvalidGroupSendEndorsements() throws Exception { - final List recipients = List.of( - new Recipient(NONEXISTENT_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])); - - // initialize our binary payload and create an input stream - byte[] buffer = new byte[2048]; - InputStream stream = initializeMultiPayload(recipients, buffer, true); - - clock.pin(Instant.parse("2024-04-09T12:00:00.00Z")); - - try (final Response response = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("ts", 1663798405641L) - .queryParam("story", false) - .queryParam("urgent", false) - .request() - .header(HttpHeaders.USER_AGENT, "test") - .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader( - serverSecretParams, List.of(MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z"))) - .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) { - - assertThat("Unexpected response", response.getStatus(), is(equalTo(401))); - verifyNoMoreInteractions(messageSender); - } - } - - @Test - void testMultiRecipientMessageWithExpiredGroupSendEndorsements() throws Exception { - final List recipients = List.of( - new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])); - - // initialize our binary payload and create an input stream - byte[] buffer = new byte[2048]; - InputStream stream = initializeMultiPayload(recipients, buffer, true); - - clock.pin(Instant.parse("2024-04-10T12:00:00.00Z")); - - try (final Response response = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("ts", 1663798405641L) - .queryParam("story", false) - .queryParam("urgent", false) - .request() - .header(HttpHeaders.USER_AGENT, "test") - .header(HeaderUtils.GROUP_SEND_TOKEN, AuthHelper.validGroupSendTokenHeader( - serverSecretParams, List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), Instant.parse("2024-04-10T00:00:00.00Z"))) - .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE))) { - - assertThat("Unexpected response", response.getStatus(), is(equalTo(401))); - verifyNoMoreInteractions(messageSender); - } - } - + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception { - final List recipients = List.of( - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), - new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); - - Response response = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("ts", 1663798405641L) - .queryParam("story", false) - .queryParam("urgent", false) - .request() - .header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot") - .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) - .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE)); + @MethodSource + void sendMultiRecipientMessage(final Map accountsByServiceIdentifier, + final byte[] multiRecipientMessage, + final long timestamp, + final boolean isStory, + final boolean rateLimit, + final Optional maybeAccessKey, + final Optional maybeGroupSendToken, + final int expectedStatus, + final Set expectedResolvedAccounts) { - checkBadMultiRecipientResponse(response, 400); - } + clock.pin(START_OF_DAY); - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void testMultiRecipientSizeLimit() throws Exception { - final List recipients = List.of( - new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); + when(accountsManager.getByServiceIdentifierAsync(any())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); - Response response = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("ts", 1663798405641L) - .queryParam("story", false) - .queryParam("urgent", false) - .request() - .header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot") - .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) - .put(Entity.entity(initializeMultiPayload(recipients, new byte[257<<10], true, 256<<10), MultiRecipientMessageProvider.MEDIA_TYPE)); + accountsByServiceIdentifier.forEach(((serviceIdentifier, account) -> + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))))); - checkBadMultiRecipientResponse(response, 400); - } + final boolean ephemeral = true; + final boolean urgent = false; - @ParameterizedTest - @CsvSource({ - "-1, 400", - "0, 200", - "1, 200", - "8640000000000000, 200", - "8640000000000001, 400", - - // 404 here is a weird quirk of controller pattern matching; this value doesn't get interpreted as `long`, and so - // it doesn't match the "send multi-recipient message" endpoint - "99999999999999999999999999999999999, 404" - }) - void testMultiRecipientExtremeTimestamp(final String timestamp, final int expectedStatus) { - - when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class))) - .thenReturn(new byte[]{1}); - - final int nRecipients = 999; - final int devicesPerRecipient = 5; - final List recipients = new ArrayList<>(); - - for (int i = 0; i < nRecipients; i++) { - final List devices = - IntStream.range(1, devicesPerRecipient + 1) - .mapToObj( - d -> generateTestDevice( - (byte) d, 100 + d, 10 * d, true)) - .collect(Collectors.toList()); - final UUID aci = new UUID(0L, i); - final UUID pni = new UUID(1L, i); - final String e164 = String.format("+1408555%04d", i); - final Account account = AccountsHelper.generateTestAccount(e164, aci, pni, devices, UNIDENTIFIED_ACCESS_BYTES); - - when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(aci))) - .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - - when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(pni))) - .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); - - devices.forEach(d -> recipients.add(new Recipient(new AciServiceIdentifier(aci), d.getId(), d.getRegistrationId(), new byte[48]))); - } - - byte[] buffer = new byte[1048576]; - InputStream stream = initializeMultiPayload(recipients, buffer, true); - Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); - try (final Response response = resources + final Invocation.Builder invocationBuilder = resources .getJerseyTest() .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("story", true) - .queryParam("urgent", false) .queryParam("ts", timestamp) - .request() - .header(HttpHeaders.USER_AGENT, "test") - .put(entity)) { - - assertThat(response.readEntity(String.class), response.getStatus(), is(equalTo(expectedStatus))); - } - } - - @Test - void testSendStoryToUnknownAccount() throws Exception { - String accessBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES); - String json = jsonFixture("fixtures/current_message_single_device.json"); - UUID unknownUUID = UUID.randomUUID(); - IncomingMessageList list = SystemMapper.jsonMapper().readValue(json, IncomingMessageList.class); - - try (final Response response = - resources.getJerseyTest() - .target(String.format("/v1/messages/%s", unknownUUID)) - .queryParam("story", "true") - .request() - .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes) - .put(Entity.entity(list, MediaType.APPLICATION_JSON_TYPE))) { - - assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200))); - } - } + .queryParam("online", ephemeral) + .queryParam("story", isStory) + .queryParam("urgent", urgent) + .request(); - @ParameterizedTest - @MethodSource - void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known, boolean useExplicitIdentifier) { + maybeAccessKey.ifPresent(accessKey -> + invocationBuilder.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessKey)); - when(messagesManager.insertSharedMultiRecipientMessagePayload(any(SealedSenderMultiRecipientMessage.class))) - .thenReturn(new byte[]{1}); + maybeGroupSendToken.ifPresent(groupSendToken -> + invocationBuilder.header(HeaderUtils.GROUP_SEND_TOKEN, groupSendToken)); - final Recipient r1; - if (known) { - r1 = new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); + if (rateLimit) { + when(rateLimiter.validateAsync(any(UUID.class))) + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77)))); } else { - r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), (byte) 99, 999, new byte[48]); + when(rateLimiter.validateAsync(any(UUID.class))) + .thenReturn(CompletableFuture.completedFuture(null)); } - Recipient r2 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]); - Recipient r3 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]); - Recipient r4 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, new byte[48]); + try (final Response response = invocationBuilder + .put(Entity.entity(multiRecipientMessage, MultiRecipientMessageProvider.MEDIA_TYPE))) { - List recipients = List.of(r1, r2, r3, r4); - - byte[] buffer = new byte[2048]; - InputStream stream = initializeMultiPayload(recipients, buffer, useExplicitIdentifier); - // set up the entity to use in our PUT request - Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); - - // This looks weird, but there is a method to the madness. - // new bytes[16] is equivalent to UNIDENTIFIED_ACCESS_BYTES ^ UNIDENTIFIED_ACCESS_BYTES - // (i.e. we need to XOR all the access keys together) - String accessBytes = Base64.getEncoder().encodeToString(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - - // start building the request - Invocation.Builder bldr = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", true) - .queryParam("ts", 1663798405641L) - .queryParam("story", story) - .request() - .header(HttpHeaders.USER_AGENT, "Test User Agent") - .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes); + assertThat(response.getStatus(), is(equalTo(expectedStatus))); - // make the PUT request - try (final Response response = bldr.put(entity)) { - if (story || known) { - // it's a story so we unconditionally get 200 ok - assertEquals(200, response.getStatus()); + if (expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) { + verify(messageSender).sendMultiRecipientMessage(any(), + argThat(resolvedRecipients -> + new HashSet<>(resolvedRecipients.values()).equals(expectedResolvedAccounts)), + anyLong(), + eq(isStory), + eq(ephemeral), + eq(urgent)); } else { - // unknown recipient means 404 not found - assertEquals(404, response.getStatus()); + verify(messageSender, never()).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()); } } } - private static Stream testSendMultiRecipientMessageToUnknownAccounts() { - return Stream.of( - Arguments.of(true, true, false), - Arguments.of(true, false, false), - Arguments.of(false, true, false), - Arguments.of(false, false, false), - - Arguments.of(true, true, true), - Arguments.of(true, false, true), - Arguments.of(false, true, true), - Arguments.of(false, false, true) + private static List sendMultiRecipientMessage() throws Exception { + final UUID singleDeviceAccountAci = UUID.randomUUID(); + final UUID singleDeviceAccountPni = UUID.randomUUID(); + final UUID multiDeviceAccountAci = UUID.randomUUID(); + final UUID multiDeviceAccountPni = UUID.randomUUID(); + + final byte[] singleDeviceAccountUak = TestRandomUtil.nextBytes(UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH); + final byte[] multiDeviceAccountUak = TestRandomUtil.nextBytes(UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH); + + final int singleDevicePrimaryRegistrationId = 1; + final int multiDevicePrimaryRegistrationId = 2; + final int multiDeviceLinkedRegistrationId = 3; + + final Device singleDeviceAccountPrimary = mock(Device.class); + when(singleDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID); + when(singleDeviceAccountPrimary.getRegistrationId()).thenReturn(singleDevicePrimaryRegistrationId); + + final Device multiDeviceAccountPrimary = mock(Device.class); + when(multiDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID); + when(multiDeviceAccountPrimary.getRegistrationId()).thenReturn(multiDevicePrimaryRegistrationId); + + final Device multiDeviceAccountLinked = mock(Device.class); + when(multiDeviceAccountLinked.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1)); + when(multiDeviceAccountLinked.getRegistrationId()).thenReturn(multiDeviceLinkedRegistrationId); + + final Account singleDeviceAccount = mock(Account.class); + when(singleDeviceAccount.getIdentifier(IdentityType.ACI)).thenReturn(singleDeviceAccountAci); + when(singleDeviceAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(singleDeviceAccountUak)); + when(singleDeviceAccount.getDevices()).thenReturn(List.of(singleDeviceAccountPrimary)); + when(singleDeviceAccount.getDevice(anyByte())).thenReturn(Optional.empty()); + when(singleDeviceAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(singleDeviceAccountPrimary)); + + final Account multiDeviceAccount = mock(Account.class); + when(multiDeviceAccount.getIdentifier(IdentityType.ACI)).thenReturn(multiDeviceAccountAci); + when(multiDeviceAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(multiDeviceAccountUak)); + when(multiDeviceAccount.getDevices()).thenReturn(List.of(multiDeviceAccountPrimary, multiDeviceAccountLinked)); + when(multiDeviceAccount.getDevice(anyByte())).thenReturn(Optional.empty()); + when(multiDeviceAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(multiDeviceAccountPrimary)); + when(multiDeviceAccount.getDevice((byte) (Device.PRIMARY_ID + 1))).thenReturn(Optional.of(multiDeviceAccountLinked)); + + final String groupSendEndorsement = AuthHelper.validGroupSendTokenHeader(serverSecretParams, + List.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci)), + START_OF_DAY.plus(Duration.ofDays(1))); + + final Map accountsByServiceIdentifier = Map.of( + new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount, + new AciServiceIdentifier(multiDeviceAccountAci), multiDeviceAccount, + new PniServiceIdentifier(singleDeviceAccountPni), singleDeviceAccount, + new PniServiceIdentifier(multiDeviceAccountPni), multiDeviceAccount); + + final byte[] aciMessage = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new AciServiceIdentifier(singleDeviceAccountAci), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), (byte) (Device.PRIMARY_ID + 1), multiDeviceLinkedRegistrationId, new byte[48]))); + + return List.of( + Arguments.argumentSet("Multi-recipient story", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + true, + false, + Optional.empty(), + Optional.empty(), + 200, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Multi-recipient message with combined UAKs", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + false, + false, + Optional.of(Base64.getEncoder().encodeToString(UnidentifiedAccessUtil.getCombinedUnidentifiedAccessKey(List.of(singleDeviceAccount, multiDeviceAccount)))), + Optional.empty(), + 200, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Multi-recipient message with group send endorsement", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 200, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Incorrect combined UAK", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + false, + false, + Optional.of(Base64.getEncoder().encodeToString(TestRandomUtil.nextBytes(UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH))), + Optional.empty(), + 401, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Incorrect group send endorsement", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(AuthHelper.validGroupSendTokenHeader(serverSecretParams, + List.of(new AciServiceIdentifier(UUID.randomUUID())), + START_OF_DAY.plus(Duration.ofDays(1)))), + 401, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + // Stories don't require credentials of any kind, but for historical reasons, we don't reject a combined UAK if + // provided + Arguments.argumentSet("Story with combined UAKs", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + true, + false, + Optional.of(Base64.getEncoder().encodeToString(UnidentifiedAccessUtil.getCombinedUnidentifiedAccessKey(List.of(singleDeviceAccount, multiDeviceAccount)))), + Optional.empty(), + 200, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Story with group send endorsement", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + true, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 400, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Conflicting credentials", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + false, + false, + Optional.of(Base64.getEncoder().encodeToString(UnidentifiedAccessUtil.getCombinedUnidentifiedAccessKey(List.of(singleDeviceAccount, multiDeviceAccount)))), + Optional.of(groupSendEndorsement), + 400, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("No credentials", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.empty(), + 401, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Oversized payload", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new AciServiceIdentifier(singleDeviceAccountAci), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), (byte) (Device.PRIMARY_ID + 1), multiDeviceLinkedRegistrationId, new byte[48])), + MultiRecipientMessageProvider.MAX_MESSAGE_SIZE), + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 413, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Negative timestamp", + accountsByServiceIdentifier, + aciMessage, + -1, + false, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 400, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Excessive timestamp", + accountsByServiceIdentifier, + aciMessage, + MessageController.MAX_TIMESTAMP + 1, + false, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 400, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Empty recipient list", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of()), + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(AuthHelper.validGroupSendTokenHeader(serverSecretParams, + List.of(), + START_OF_DAY.plus(Duration.ofDays(1)))), + 400, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Story with empty recipient list", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of()), + clock.instant().toEpochMilli(), + true, + false, + Optional.empty(), + Optional.empty(), + 400, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Duplicate recipient", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new AciServiceIdentifier(singleDeviceAccountAci), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(singleDeviceAccountAci), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]))), + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 400, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Missing account", + Map.of(), + aciMessage, + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 404, + Collections.emptySet()), + + Arguments.argumentSet("Missing account for story", + Map.of(), + aciMessage, + clock.instant().toEpochMilli(), + true, + false, + Optional.empty(), + Optional.empty(), + 200, + Collections.emptySet()), + + Arguments.argumentSet("Missing device", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new AciServiceIdentifier(singleDeviceAccountAci), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]))), + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 409, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Extra device", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new AciServiceIdentifier(singleDeviceAccountAci), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), (byte) (Device.PRIMARY_ID + 1), multiDeviceLinkedRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), (byte) (Device.PRIMARY_ID + 2), multiDeviceLinkedRegistrationId + 1, new byte[48]))), + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 409, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Stale registration ID", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new AciServiceIdentifier(singleDeviceAccountAci), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new AciServiceIdentifier(multiDeviceAccountAci), (byte) (Device.PRIMARY_ID + 1), multiDeviceLinkedRegistrationId + 1, new byte[48]))), + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(groupSendEndorsement), + 410, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Rate-limited story", + accountsByServiceIdentifier, + aciMessage, + clock.instant().toEpochMilli(), + true, + true, + Optional.empty(), + Optional.empty(), + 429, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Story to PNI recipients", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new PniServiceIdentifier(singleDeviceAccountPni), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new PniServiceIdentifier(multiDeviceAccountPni), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new PniServiceIdentifier(multiDeviceAccountPni), (byte) (Device.PRIMARY_ID + 1), multiDeviceLinkedRegistrationId, new byte[48]))), + clock.instant().toEpochMilli(), + true, + false, + Optional.empty(), + Optional.empty(), + 200, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Multi-recipient message to PNI recipients with UAK", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new PniServiceIdentifier(singleDeviceAccountPni), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new PniServiceIdentifier(multiDeviceAccountPni), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new PniServiceIdentifier(multiDeviceAccountPni), (byte) (Device.PRIMARY_ID + 1), multiDeviceLinkedRegistrationId, new byte[48]))), + clock.instant().toEpochMilli(), + false, + false, + Optional.of(Base64.getEncoder().encodeToString(UnidentifiedAccessUtil.getCombinedUnidentifiedAccessKey(List.of(singleDeviceAccount, multiDeviceAccount)))), + Optional.empty(), + 401, + Set.of(singleDeviceAccount, multiDeviceAccount)), + + Arguments.argumentSet("Multi-recipient message to PNI recipients with group send endorsement", + accountsByServiceIdentifier, + MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(new PniServiceIdentifier(singleDeviceAccountPni), Device.PRIMARY_ID, singleDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new PniServiceIdentifier(multiDeviceAccountPni), Device.PRIMARY_ID, multiDevicePrimaryRegistrationId, new byte[48]), + new TestRecipient(new PniServiceIdentifier(multiDeviceAccountPni), (byte) (Device.PRIMARY_ID + 1), multiDeviceLinkedRegistrationId, new byte[48]))), + clock.instant().toEpochMilli(), + false, + false, + Optional.empty(), + Optional.of(AuthHelper.validGroupSendTokenHeader(serverSecretParams, + List.of(new PniServiceIdentifier(singleDeviceAccountPni), new PniServiceIdentifier(multiDeviceAccountPni)), + START_OF_DAY.plus(Duration.ofDays(1)))), + 200, + Set.of(singleDeviceAccount, multiDeviceAccount)) ); } - @Test - void sendMultiRecipientMessageMismatchedDevices() throws JsonProcessingException { - - final ServiceIdentifier serviceIdentifier = MULTI_DEVICE_ACI_ID; - - final byte extraDeviceId = MULTI_DEVICE_ID3 + 1; - - final List recipients = List.of( - new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), - new Recipient(serviceIdentifier, MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, new byte[48]), - new Recipient(serviceIdentifier, extraDeviceId, 1234, new byte[48])); - - // initialize our binary payload and create an input stream - final byte[] buffer = new byte[2048]; - final InputStream stream = initializeMultiPayload(recipients, buffer, true); - - // set up the entity to use in our PUT request - final Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); - - // start building the request - final Invocation.Builder invocationBuilder = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", false) - .queryParam("ts", System.currentTimeMillis()) - .queryParam("story", false) - .queryParam("urgent", true) - .request() - .header(HttpHeaders.USER_AGENT, "test") - .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); - - // make the PUT request - try (final Response response = invocationBuilder.put(entity)) { - assertEquals(409, response.getStatus()); - - final List mismatchedDevices = - SystemMapper.jsonMapper().readValue(response.readEntity(String.class), - SystemMapper.jsonMapper().getTypeFactory() - .constructCollectionType(List.class, AccountMismatchedDevices.class)); - - assertEquals(List.of(new AccountMismatchedDevices(serviceIdentifier, - new MismatchedDevices(Collections.emptyList(), List.of(extraDeviceId)))), - mismatchedDevices); - } - } - - @Test - void sendMultiRecipientMessageStaleDevices() throws JsonProcessingException { - final ServiceIdentifier serviceIdentifier = MULTI_DEVICE_ACI_ID; - final List recipients = List.of( - new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1 + 1, new byte[48]), - new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2 + 1, new byte[48]), - new Recipient(serviceIdentifier, MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3 + 1, new byte[48])); - - // initialize our binary payload and create an input stream - byte[] buffer = new byte[2048]; - // InputStream stream = initializeMultiPayload(recipientUUID, buffer); - InputStream stream = initializeMultiPayload(recipients, buffer, true); - - // set up the entity to use in our PUT request - Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); - - // start building the request - final Invocation.Builder invocationBuilder = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", false) - .queryParam("ts", System.currentTimeMillis()) - .queryParam("story", false) - .queryParam("urgent", true) - .request() - .header(HttpHeaders.USER_AGENT, "test") - .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); - - // make the PUT request - try (final Response response = invocationBuilder.put(entity)) { - assertEquals(410, response.getStatus()); - - final List staleDevices = - SystemMapper.jsonMapper().readValue(response.readEntity(String.class), - SystemMapper.jsonMapper().getTypeFactory() - .constructCollectionType(List.class, AccountStaleDevices.class)); - - assertEquals(1, staleDevices.size()); - assertEquals(serviceIdentifier, staleDevices.getFirst().uuid()); - assertEquals(Set.of(MULTI_DEVICE_ID1, MULTI_DEVICE_ID2, MULTI_DEVICE_ID3), - new HashSet<>(staleDevices.getFirst().devices().staleDevices())); - } - } - - @Test - void sendMultiRecipientMessageStoryRateLimited() { - final List recipients = List.of(new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); - // initialize our binary payload and create an input stream - byte[] buffer = new byte[2048]; - // InputStream stream = initializeMultiPayload(recipientUUID, buffer); - InputStream stream = initializeMultiPayload(recipients, buffer, true); - - // set up the entity to use in our PUT request - Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); - - // start building the request - final Invocation.Builder invocationBuilder = resources - .getJerseyTest() - .target("/v1/messages/multi_recipient") - .queryParam("online", false) - .queryParam("ts", System.currentTimeMillis()) - .queryParam("story", true) - .queryParam("urgent", true) - .request() - .header(HttpHeaders.USER_AGENT, "test") - .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); - - when(rateLimiter.validateAsync(any(UUID.class))) - .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77)))); - - try (final Response response = invocationBuilder.put(entity)) { - assertEquals(429, response.getStatus()); - } - } - - @SuppressWarnings("SameParameterValue") - private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { - assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode))); - verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); - } - @SuppressWarnings("SameParameterValue") private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java index a5772bcaf..1544e2635 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java @@ -10,23 +10,25 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyByte; -import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; -import com.google.protobuf.ByteString; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.UUID; -import org.apache.commons.lang3.RandomStringUtils; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junitpioneer.jupiter.cartesian.CartesianTest; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; @@ -49,17 +51,21 @@ void setUp() { @CartesianTest void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent, - @CartesianTest.Values(booleans = {true, false}) final boolean onlineMessage, + @CartesianTest.Values(booleans = {true, false}) final boolean ephemeral, + @CartesianTest.Values(booleans = {true, false}) final boolean urgent, @CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException { - final boolean expectPushNotificationAttempt = !clientPresent && !onlineMessage; + final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral; final UUID accountIdentifier = UUID.randomUUID(); final byte deviceId = Device.PRIMARY_ID; final Account account = mock(Account.class); final Device device = mock(Device.class); - final MessageProtos.Envelope message = generateRandomMessage(); + final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder() + .setEphemeral(ephemeral) + .setUrgent(urgent) + .build(); when(account.getUuid()).thenReturn(accountIdentifier); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); @@ -72,18 +78,61 @@ void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean c .when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean()); } - when(messagesManager.insert(eq(accountIdentifier), eq(deviceId), any())).thenReturn(clientPresent); + when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent)); - assertDoesNotThrow(() -> messageSender.sendMessage(account, device, message, onlineMessage)); + assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message))); - final MessageProtos.Envelope expectedMessage = onlineMessage + final MessageProtos.Envelope expectedMessage = ephemeral ? message.toBuilder().setEphemeral(true).build() : message.toBuilder().build(); - verify(messagesManager).insert(accountIdentifier, deviceId, expectedMessage); + verify(messagesManager).insert(accountIdentifier, Map.of(deviceId, expectedMessage)); if (expectPushNotificationAttempt) { - verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, expectedMessage.getUrgent()); + verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent); + } else { + verifyNoInteractions(pushNotificationManager); + } + } + + @CartesianTest + void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent, + @CartesianTest.Values(booleans = {true, false}) final boolean ephemeral, + @CartesianTest.Values(booleans = {true, false}) final boolean urgent, + @CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException { + + final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral; + + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + + final Account account = mock(Account.class); + final Device device = mock(Device.class); + + when(account.getUuid()).thenReturn(accountIdentifier); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + when(device.getId()).thenReturn(deviceId); + + if (hasPushToken) { + when(device.getApnId()).thenReturn("apns-token"); + } else { + doThrow(NotPushRegisteredException.class) + .when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean()); + } + + when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean())) + .thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent)))); + + assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class), + Collections.emptyMap(), + System.currentTimeMillis(), + false, + ephemeral, + urgent) + .join()); + + if (expectPushNotificationAttempt) { + verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent); } else { verifyNoInteractions(pushNotificationManager); } @@ -123,14 +172,4 @@ private static List getDeliveryChannelName() { return arguments; } - - private MessageProtos.Envelope generateRandomMessage() { - return MessageProtos.Envelope.newBuilder() - .setClientTimestamp(System.currentTimeMillis()) - .setServerTimestamp(System.currentTimeMillis()) - .setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256))) - .setType(MessageProtos.Envelope.Type.CIPHERTEXT) - .setServerGuid(UUID.randomUUID().toString()) - .build(); - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 3a94f104b..14787d02e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -104,7 +105,7 @@ void changeNumberNoMessages() throws Exception { changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); - verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); + verify(messageSender, never()).sendMessages(eq(account), any()); } @Test @@ -118,7 +119,7 @@ void changeNumberSetPrimaryDevicePrekey() throws Exception { changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); - verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); + verify(messageSender, never()).sendMessages(eq(account), any()); } @Test @@ -155,10 +156,15 @@ void changeNumberSetPrimaryDevicePrekeyAndSendMessages() throws Exception { verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds); - final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); - verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = + ArgumentCaptor.forClass(Map.class); - final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); + + assertEquals(1, envelopeCaptor.getValue().size()); + assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); + + final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); @@ -203,10 +209,15 @@ void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds); - final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); - verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = + ArgumentCaptor.forClass(Map.class); + + verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); - final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + assertEquals(1, envelopeCaptor.getValue().size()); + assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); + + final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); @@ -249,10 +260,15 @@ void changeNumberSameNumberSetPrimaryDevicePrekeyAndSendMessages() throws Except verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds); - final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); - verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = + ArgumentCaptor.forClass(Map.class); + + verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); - final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + assertEquals(1, envelopeCaptor.getValue().size()); + assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); + + final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); @@ -291,10 +307,15 @@ void updatePniKeysSetPrimaryDevicePrekeyAndSendMessages() throws Exception { verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds); - final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); - verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = + ArgumentCaptor.forClass(Map.class); + + verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); + + assertEquals(1, envelopeCaptor.getValue().size()); + assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); - final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); @@ -335,10 +356,15 @@ void updatePniKeysSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds); - final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); - verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = + ArgumentCaptor.forClass(Map.class); + + verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); + + assertEquals(1, envelopeCaptor.getValue().size()); + assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); - final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index 610203367..f588d750f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -84,7 +84,7 @@ void setUp() throws Exception { messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC()); messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class), - messageDeletionExecutorService); + messageDeletionExecutorService, Clock.systemUTC()); websocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor(); asyncOperationQueueingExecutor = Executors.newSingleThreadExecutor(); @@ -143,7 +143,7 @@ void testScheduledPersistMessages() { final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp); - messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message); + messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message).join(); expectedMessages.add(message); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 95816a773..58a7e6417 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -358,7 +358,7 @@ private void insertMessages(final UUID accountUuid, final byte deviceId, final i .setServerGuid(messageGuid.toString()) .build(); - messagesCache.insert(messageGuid, accountUuid, deviceId, envelope); + messagesCache.insert(messageGuid, accountUuid, deviceId, envelope).join(); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java index f0fe7d7f5..f13b0644f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheGetItemsScriptTest.java @@ -40,7 +40,7 @@ void testCacheGetItemsScript() throws Exception { .setServerGuid(serverGuid) .build(); - insertScript.execute(destinationUuid, deviceId, envelope1); + insertScript.executeAsync(destinationUuid, deviceId, envelope1); final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript( REDIS_CLUSTER_EXTENSION.getRedisCluster()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java index 753f29f2a..c6bfe2d1f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertScriptTest.java @@ -41,7 +41,7 @@ void testCacheInsertScript() throws Exception { .setServerGuid(UUID.randomUUID().toString()) .build(); - insertScript.execute(destinationUuid, deviceId, envelope1); + insertScript.executeAsync(destinationUuid, deviceId, envelope1); assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId)); @@ -50,11 +50,11 @@ void testCacheInsertScript() throws Exception { .setServerGuid(UUID.randomUUID().toString()) .build(); - insertScript.execute(destinationUuid, deviceId, envelope2); + insertScript.executeAsync(destinationUuid, deviceId, envelope2); assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId)); - insertScript.execute(destinationUuid, deviceId, envelope1); + insertScript.executeAsync(destinationUuid, deviceId, envelope1); assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId), "Messages with same GUID should be deduplicated"); @@ -89,10 +89,10 @@ void returnPresence() throws IOException { final MessagesCacheInsertScript insertScript = new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster()); - assertFalse(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder() + assertFalse(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder() .setServerTimestamp(Instant.now().getEpochSecond()) .setServerGuid(UUID.randomUUID().toString()) - .build())); + .build()).join()); final FaultTolerantPubSubClusterConnection pubSubClusterConnection = REDIS_CLUSTER_EXTENSION.getRedisCluster().createBinaryPubSubConnection(); @@ -100,9 +100,9 @@ void returnPresence() throws IOException { pubSubClusterConnection.usePubSubConnection(connection -> connection.sync().ssubscribe(WebSocketConnectionEventManager.getClientEventChannel(destinationUuid, deviceId))); - assertTrue(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder() + assertTrue(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder() .setServerTimestamp(Instant.now().getEpochSecond()) .setServerGuid(UUID.randomUUID().toString()) - .build())); + .build()).join()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java index fb5a9b8d9..612c1a574 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest.java @@ -6,7 +6,9 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import io.lettuce.core.RedisCommandExecutionException; import java.util.ArrayList; @@ -14,8 +16,10 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.concurrent.CompletionException; import java.util.stream.Collectors; import java.util.stream.IntStream; +import io.lettuce.core.RedisException; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; @@ -39,8 +43,8 @@ void testInsert(final int count, final Map> destin REDIS_CLUSTER_EXTENSION.getRedisCluster()); final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); - insertMrmScript.execute(sharedMrmKey, - MessagesCacheTest.generateRandomMrmMessage(destinations)); + insertMrmScript.executeAsync(sharedMrmKey, + MessagesCacheTest.generateRandomMrmMessage(destinations)).join(); final int totalDevices = destinations.values().stream().mapToInt(List::size).sum(); final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster() @@ -82,15 +86,17 @@ void testInsertDuplicateKey() throws Exception { REDIS_CLUSTER_EXTENSION.getRedisCluster()); final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); - insertMrmScript.execute(sharedMrmKey, - MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID)); + insertMrmScript.executeAsync(sharedMrmKey, + MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID)).join(); - final RedisCommandExecutionException e = assertThrows(RedisCommandExecutionException.class, - () -> insertMrmScript.execute(sharedMrmKey, + final CompletionException completionException = assertThrows(CompletionException.class, + () -> insertMrmScript.executeAsync(sharedMrmKey, MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), - Device.PRIMARY_ID))); + Device.PRIMARY_ID)).join()); - assertEquals(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS, e.getMessage()); + assertInstanceOf(RedisException.class, completionException.getCause()); + assertTrue(completionException.getCause().getMessage() + .contains(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScriptTest.java index 52f3b239f..db9c188c3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveByGuidScriptTest.java @@ -34,7 +34,7 @@ void testCacheRemoveByGuid() throws Exception { .setServerGuid(serverGuid.toString()) .build(); - insertScript.execute(destinationUuid, deviceId, envelope1); + insertScript.executeAsync(destinationUuid, deviceId, envelope1); final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript( REDIS_CLUSTER_EXTENSION.getRedisCluster()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScriptTest.java index b25b3456a..01dde38dd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveQueueScriptTest.java @@ -35,7 +35,7 @@ void testCacheRemoveQueueScript() throws Exception { .setServerGuid(UUID.randomUUID().toString()) .build(); - insertScript.execute(destinationUuid, deviceId, envelope1); + insertScript.executeAsync(destinationUuid, deviceId, envelope1); final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript( REDIS_CLUSTER_EXTENSION.getRedisCluster()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScriptTest.java index af3c823c3..301b869cf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheRemoveRecipientViewFromMrmDataScriptTest.java @@ -41,8 +41,7 @@ void testUpdateSingleKey(final Map> destinations) REDIS_CLUSTER_EXTENSION.getRedisCluster()); final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); - insertMrmScript.execute(sharedMrmKey, - MessagesCacheTest.generateRandomMrmMessage(destinations)); + insertMrmScript.executeAsync(sharedMrmKey, MessagesCacheTest.generateRandomMrmMessage(destinations)).join(); final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript( REDIS_CLUSTER_EXTENSION.getRedisCluster()); @@ -103,8 +102,8 @@ void testUpdateManyKeys(int keyCount) throws Exception { REDIS_CLUSTER_EXTENSION.getRedisCluster()); final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); - insertMrmScript.execute(sharedMrmKey, - MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId)); + insertMrmScript.executeAsync(sharedMrmKey, + MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId)).join(); sharedMrmKeys.add(sharedMrmKey); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 54d9b4a73..a3332e738 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -122,7 +122,7 @@ void tearDown() throws Exception { void testInsert(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); assertDoesNotThrow(() -> messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid, sealedSender))); + generateRandomMessage(messageGuid, sealedSender))).join(); } @Test @@ -130,8 +130,8 @@ void testDoubleInsertGuid() { final UUID duplicateGuid = UUID.randomUUID(); final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false); - messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); - messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); + messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join(); + messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join(); assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10) .count() @@ -149,7 +149,7 @@ void testRemoveByUUID(final boolean sealedSender) throws Exception { final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join(); final Optional maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS); @@ -175,12 +175,12 @@ void testRemoveBatchByUUID(final boolean sealedSender) throws Exception { for (final MessageProtos.Envelope message : messagesToRemove) { messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, - message); + message).join(); } for (final MessageProtos.Envelope message : messagesToPreserve) { messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, - message); + message).join(); } final List removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, @@ -197,7 +197,7 @@ void testHasMessages() { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join(); assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); } @@ -208,7 +208,7 @@ void testHasMessagesAsync() { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join(); assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join()); } @@ -223,7 +223,7 @@ void getOldestTimestamp() { for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, i % 2 == 0); - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join(); assertEquals(expectedOldestTimestamp, messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block()); expectedMessages.add(message); @@ -248,7 +248,7 @@ void testGetMessages(final boolean sealedSender) throws Exception { for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join(); expectedMessages.add(message); } @@ -262,7 +262,7 @@ void testGetMessages(final boolean sealedSender) throws Exception { final UUID message1Guid = UUID.randomUUID(); final MessageProtos.Envelope message1 = generateRandomMessage(message1Guid, sealedSender); - messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1); + messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1).join(); final List get1 = get(DESTINATION_UUID, DESTINATION_DEVICE_ID, 1); assertEquals(List.of(message1), get1); @@ -272,7 +272,7 @@ void testGetMessages(final boolean sealedSender) throws Exception { final UUID message2Guid = UUID.randomUUID(); final MessageProtos.Envelope message2 = generateRandomMessage(message2Guid, sealedSender); - messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2); + messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2).join(); assertEquals(List.of(message2), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, 1)); } @@ -287,7 +287,7 @@ void testGetMessagesPublisher(final boolean expectStale) throws Exception { for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join(); expectedMessages.add(message); } @@ -295,7 +295,7 @@ void testGetMessagesPublisher(final boolean expectStale) throws Exception { final UUID ephemeralMessageGuid = UUID.randomUUID(); final MessageProtos.Envelope ephemeralMessage = generateRandomMessage(ephemeralMessageGuid, true) .toBuilder().setEphemeral(true).build(); - messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage); + messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage).join(); final Clock cacheClock; if (expectStale) { @@ -352,7 +352,7 @@ void testClearQueueForDevice(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message); + messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join(); } } @@ -372,7 +372,7 @@ void testClearQueueForAccount(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message); + messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join(); } } @@ -404,7 +404,7 @@ public void testGetQueuesToPersist(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid, sealedSender)); + generateRandomMessage(messageGuid, sealedSender)).join(); final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID); assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty()); @@ -427,7 +427,7 @@ void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) { final byte[] sharedMrmDataKey; if (sharedMrmKeyPresent) { - sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm); + sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join(); } else { sharedMrmDataKey = "{1}".getBytes(StandardCharsets.UTF_8); } @@ -440,7 +440,7 @@ void testMultiRecipientMessage(final boolean sharedMrmKeyPresent) { .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) .clearContent() .build(); - messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message); + messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message).join(); assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster() .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey))); @@ -487,13 +487,13 @@ void testGetMessagesToPersist(final boolean sharedMrmKeyPresent) { final MessageProtos.Envelope message = generateRandomMessage(messageGuid, new AciServiceIdentifier(destinationUuid), true); - messagesCache.insert(messageGuid, destinationUuid, deviceId, message); + messagesCache.insert(messageGuid, destinationUuid, deviceId, message).join(); final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId); final byte[] sharedMrmDataKey; if (sharedMrmKeyPresent) { - sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm); + sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join(); } else { sharedMrmDataKey = new byte[]{1}; } @@ -505,7 +505,7 @@ void testGetMessagesToPersist(final boolean sharedMrmKeyPresent) { .clearContent() .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) .build(); - messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage); + messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage).join(); final List messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java index a58cd1963..c7691c0c8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java @@ -7,22 +7,42 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.protobuf.ByteString; +import java.nio.charset.StandardCharsets; import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.signal.libsignal.protocol.InvalidMessageException; +import org.signal.libsignal.protocol.InvalidVersionException; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper; +import org.whispersystems.textsecuregcm.tests.util.TestRecipient; +import org.whispersystems.textsecuregcm.util.TestClock; import reactor.core.publisher.Mono; class MessagesManagerTest { @@ -31,8 +51,15 @@ class MessagesManagerTest { private final MessagesCache messagesCache = mock(MessagesCache.class); private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); + private static final TestClock CLOCK = TestClock.pinned(Instant.now()); + private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, - reportMessageManager, Executors.newSingleThreadExecutor()); + reportMessageManager, Executors.newSingleThreadExecutor(), CLOCK); + + @BeforeEach + void setUp() { + when(messagesCache.insert(any(), any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(true)); + } @Test void insert() { @@ -43,7 +70,7 @@ void insert() { final UUID destinationUuid = UUID.randomUUID(); - messagesManager.insert(destinationUuid, Device.PRIMARY_ID, message); + messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, message)); verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class)); @@ -51,11 +78,113 @@ void insert() { .setSourceServiceId(destinationUuid.toString()) .build(); - messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage); + messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, syncMessage)); verifyNoMoreInteractions(reportMessageManager); } + @Test + void insertMultiRecipientMessage() throws InvalidMessageException, InvalidVersionException { + final ServiceIdentifier singleDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + final ServiceIdentifier singleDeviceAccountPniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID()); + final ServiceIdentifier multiDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + final ServiceIdentifier unresolvedAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + final Account singleDeviceAccount = mock(Account.class); + final Account multiDeviceAccount = mock(Account.class); + + when(singleDeviceAccount.getIdentifier(IdentityType.ACI)) + .thenReturn(singleDeviceAccountAciServiceIdentifier.uuid()); + + when(multiDeviceAccount.getIdentifier(IdentityType.ACI)) + .thenReturn(multiDeviceAccountAciServiceIdentifier.uuid()); + + final byte[] multiRecipientMessageBytes = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(singleDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 1, new byte[48]), + new TestRecipient(multiDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 2, new byte[48]), + new TestRecipient(multiDeviceAccountAciServiceIdentifier, (byte) (Device.PRIMARY_ID + 1), 3, new byte[48]), + new TestRecipient(unresolvedAccountAciServiceIdentifier, Device.PRIMARY_ID, 4, new byte[48]), + new TestRecipient(singleDeviceAccountPniServiceIdentifier, Device.PRIMARY_ID, 5, new byte[48]) + )); + + final SealedSenderMultiRecipientMessage multiRecipientMessage = + SealedSenderMultiRecipientMessage.parse(multiRecipientMessageBytes); + + final Map resolvedRecipients = new HashMap<>(); + + multiRecipientMessage.getRecipients().forEach(((serviceId, recipient) -> { + if (serviceId.getRawUUID().equals(singleDeviceAccountAciServiceIdentifier.uuid()) || + serviceId.getRawUUID().equals(singleDeviceAccountPniServiceIdentifier.uuid())) { + resolvedRecipients.put(recipient, singleDeviceAccount); + } else if (serviceId.getRawUUID().equals(multiDeviceAccountAciServiceIdentifier.uuid())) { + resolvedRecipients.put(recipient, multiDeviceAccount); + } + })); + + final Map> expectedPresenceByAccountAndDeviceId = Map.of( + singleDeviceAccount, Map.of(Device.PRIMARY_ID, true), + multiDeviceAccount, Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true) + ); + + final Map> presenceByAccountIdentifierAndDeviceId = Map.of( + singleDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, true), + multiDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true) + ); + + final byte[] sharedMrmKey = "shared-mrm-key".getBytes(StandardCharsets.UTF_8); + + when(messagesCache.insertSharedMultiRecipientMessagePayload(multiRecipientMessage)) + .thenReturn(CompletableFuture.completedFuture(sharedMrmKey)); + + when(messagesCache.insert(any(), any(), anyByte(), any())) + .thenAnswer(invocation -> { + final UUID accountIdentifier = invocation.getArgument(1); + final byte deviceId = invocation.getArgument(2); + + return CompletableFuture.completedFuture( + presenceByAccountIdentifierAndDeviceId.getOrDefault(accountIdentifier, Collections.emptyMap()) + .getOrDefault(deviceId, false)); + }); + + final long clientTimestamp = System.currentTimeMillis(); + final boolean isStory = ThreadLocalRandom.current().nextBoolean(); + final boolean isEphemeral = ThreadLocalRandom.current().nextBoolean(); + final boolean isUrgent = ThreadLocalRandom.current().nextBoolean(); + + final Envelope prototypeExpectedMessage = Envelope.newBuilder() + .setType(Envelope.Type.UNIDENTIFIED_SENDER) + .setClientTimestamp(clientTimestamp) + .setServerTimestamp(CLOCK.millis()) + .setStory(isStory) + .setEphemeral(isEphemeral) + .setUrgent(isUrgent) + .setSharedMrmKey(ByteString.copyFrom(sharedMrmKey)) + .build(); + + assertEquals(expectedPresenceByAccountAndDeviceId, + messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent).join()); + + verify(messagesCache).insert(any(), + eq(singleDeviceAccountAciServiceIdentifier.uuid()), + eq(Device.PRIMARY_ID), + eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build())); + + verify(messagesCache).insert(any(), + eq(singleDeviceAccountAciServiceIdentifier.uuid()), + eq(Device.PRIMARY_ID), + eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountPniServiceIdentifier.toServiceIdentifierString()).build())); + + verify(messagesCache).insert(any(), + eq(multiDeviceAccountAciServiceIdentifier.uuid()), + eq((byte) (Device.PRIMARY_ID + 1)), + eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(multiDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build())); + + verify(messagesCache, never()).insert(any(), + eq(unresolvedAccountAciServiceIdentifier.uuid()), + anyByte(), + any()); + } + @ParameterizedTest @CsvSource({ "false, false, false", diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDbTest.java index 89cd11b02..cc96a5405 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageDynamoDbTest.java @@ -29,6 +29,7 @@ class ReportMessageDynamoDbTest { void setUp() { this.reportMessageDynamoDb = new ReportMessageDynamoDb( DYNAMO_DB_EXTENSION.getDynamoDbClient(), + DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.REPORT_MESSAGES.tableName(), Duration.ofDays(1)); } @@ -44,8 +45,8 @@ void testStore() { () -> assertFalse(reportMessageDynamoDb.remove(hash2)) ); - reportMessageDynamoDb.store(hash1); - reportMessageDynamoDb.store(hash2); + reportMessageDynamoDb.store(hash1).join(); + reportMessageDynamoDb.store(hash2).join(); assertAll("both hashes should be found", () -> assertTrue(reportMessageDynamoDb.remove(hash1)), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java index 2cc6d9adb..df882b16b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java @@ -18,6 +18,7 @@ import java.time.Duration; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; @@ -68,8 +69,8 @@ void testStore() { verify(reportMessageDynamoDb).store(any()); - doThrow(RuntimeException.class) - .when(reportMessageDynamoDb).store(any()); + when(reportMessageDynamoDb.store(any())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())); assertDoesNotThrow(() -> reportMessageManager.store(sourceAci.toString(), messageGuid)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MultiRecipientMessageHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MultiRecipientMessageHelper.java new file mode 100644 index 000000000..691941a01 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MultiRecipientMessageHelper.java @@ -0,0 +1,92 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.tests.util; + +import java.nio.ByteBuffer; +import java.util.List; + +public class MultiRecipientMessageHelper { + + private MultiRecipientMessageHelper() { + } + + public static byte[] generateMultiRecipientMessage(final List recipients) { + return generateMultiRecipientMessage(recipients, 32); + } + + public static byte[] generateMultiRecipientMessage(final List recipients, final int sharedPayloadSize) { + if (sharedPayloadSize < 32) { + throw new IllegalArgumentException("Shared payload size must be at least 32 bytes"); + } + + final ByteBuffer buffer = ByteBuffer.allocate(payloadSize(recipients, sharedPayloadSize)); + + // first write the header + buffer.put((byte) 0x23); // version byte + + // count varint + writeVarint(buffer, recipients.size()); + + recipients.forEach(recipient -> { + buffer.put(recipient.uuid().toFixedWidthByteArray()); + + assert recipient.deviceIds().length == recipient.registrationIds().length; + + for (int i = 0; i < recipient.deviceIds().length; i++) { + final int hasMore = i == recipient.deviceIds().length - 1 ? 0 : 0x8000; + buffer.put(recipient.deviceIds()[i]); // device id (1 byte) + buffer.putShort((short) (recipient.registrationIds()[i] | hasMore)); // registration id (2 bytes) + } + + buffer.put(recipient.perRecipientKeyMaterial()); // key material (48 bytes) + }); + + // now write the actual message body (empty for now) + writeVarint(buffer, sharedPayloadSize); + buffer.put(new byte[sharedPayloadSize]); + + return buffer.array(); + } + + private static void writeVarint(final ByteBuffer buffer, long n) { + if (n < 0) { + throw new IllegalArgumentException(); + } + + while (n >= 0x80) { + buffer.put ((byte) (n & 0x7F | 0x80)); + n >>= 7; + } + buffer.put((byte) (n & 0x7F)); + } + + private static int payloadSize(final List recipients, final int sharedPayloadSize) { + final int fixedBytesPerRecipient = 17 // Service identifier length + + 48; // Per-recipient key material + + final int bytesForDevices = 3 * recipients.stream() + .mapToInt(recipient -> recipient.deviceIds().length) + .sum(); + + return 1 // Version byte + + varintLength(recipients.size()) + + (recipients.size() * fixedBytesPerRecipient) + + bytesForDevices + + varintLength(sharedPayloadSize) + + sharedPayloadSize; + } + + private static int varintLength(long n) { + int length = 0; + + while (n >= 0x80) { + length += 1; + n >>= 7; + } + + return length + 1; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestRecipient.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestRecipient.java new file mode 100644 index 000000000..1e1e7770a --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestRecipient.java @@ -0,0 +1,22 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.tests.util; + +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; + +public record TestRecipient(ServiceIdentifier uuid, + byte[] deviceIds, + int[] registrationIds, + byte[] perRecipientKeyMaterial) { + + public TestRecipient(ServiceIdentifier uuid, + byte deviceId, + int registrationId, + byte[] perRecipientKeyMaterial) { + + this(uuid, new byte[]{deviceId}, new int[]{registrationId}, perRecipientKeyMaterial); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 98b89746d..efd6c4d1c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -132,7 +132,7 @@ void tearDown() throws Exception { void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount) { final WebSocketConnection webSocketConnection = new WebSocketConnection( mock(ReceiptSender.class), - new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), + new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()), new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), @@ -164,7 +164,7 @@ void testProcessStoredMessages(final int persistedMessageCount, final int cached final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); - messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); + messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join(); expectedMessages.add(envelope); } @@ -220,7 +220,7 @@ void testProcessStoredMessages(final int persistedMessageCount, final int cached void testProcessStoredMessagesClientClosed() { final WebSocketConnection webSocketConnection = new WebSocketConnection( mock(ReceiptSender.class), - new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), + new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()), new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), @@ -253,7 +253,7 @@ void testProcessStoredMessagesClientClosed() { for (int i = 0; i < cachedMessageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); - messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); + messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join(); expectedMessages.add(envelope); } @@ -289,7 +289,7 @@ void testProcessStoredMessagesClientClosed() { void testProcessStoredMessagesSendFutureTimeout() { final WebSocketConnection webSocketConnection = new WebSocketConnection( mock(ReceiptSender.class), - new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), + new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()), new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), @@ -323,7 +323,7 @@ void testProcessStoredMessagesSendFutureTimeout() { for (int i = 0; i < cachedMessageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); - messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); + messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join(); expectedMessages.add(envelope); } From 7c17a4067c75bcae8bc4a796067f41b0644b74e5 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 31 Jan 2025 10:34:14 -0500 Subject: [PATCH 05/12] Update to the latest version of the spam filter --- spam-filter | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spam-filter b/spam-filter index 1573cb363..28a2cc0a1 160000 --- a/spam-filter +++ b/spam-filter @@ -1 +1 @@ -Subproject commit 1573cb3636e038c56a5e2a02f45708c9e694a589 +Subproject commit 28a2cc0a1d1f14846548b9872e954290870b135f From 06388b514c304f7a5118683cf10c8d799755e942 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Thu, 30 Jan 2025 17:44:19 -0600 Subject: [PATCH 06/12] Add timeout to GitHub test action --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index eba44112f..69250a14c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,6 +9,7 @@ jobs: build: runs-on: ubuntu-latest container: ubuntu:22.04 + timeout-minutes: 20 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 From 09eb42e5c6c7d030f33068eaef5ac696449ee988 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Tue, 28 Jan 2025 17:51:17 -0600 Subject: [PATCH 07/12] Add tag for requests made with libsignal --- .../metrics/MetricsRequestEventListener.java | 2 +- .../metrics/UserAgentTagUtil.java | 26 +++++++++++++++---- .../MetricsRequestEventListenerTest.java | 11 +++++--- 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java index 8b94ebe78..e567a6c6b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListener.java @@ -81,7 +81,7 @@ public void onEvent(final RequestEvent event) { userAgent = userAgentValues != null && !userAgentValues.isEmpty() ? userAgentValues.get(0) : null; } - tags.add(UserAgentTagUtil.getPlatformTag(userAgent)); + tags.addAll(UserAgentTagUtil.getLibsignalAndPlatformTags(userAgent)); meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/UserAgentTagUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/UserAgentTagUtil.java index 7ca8ccfc1..6a4d09ddd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/UserAgentTagUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/UserAgentTagUtil.java @@ -5,14 +5,10 @@ package org.whispersystems.textsecuregcm.metrics; -import com.vdurmont.semver4j.Semver; import io.micrometer.core.instrument.Tag; -import java.util.Collections; -import java.util.Map; +import java.util.List; import java.util.Optional; -import java.util.Set; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; -import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgent; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; @@ -24,6 +20,7 @@ public class UserAgentTagUtil { public static final String PLATFORM_TAG = "platform"; public static final String VERSION_TAG = "clientVersion"; + public static final String LIBSIGNAL_TAG = "libsignal"; private UserAgentTagUtil() { } @@ -52,4 +49,23 @@ public static Optional getClientVersionTag(final String userAgentString, fi return Optional.empty(); } + + public static List getLibsignalAndPlatformTags(final String userAgentString) { + String platform; + boolean libsignal; + + try { + final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); + platform = userAgent.getPlatform().name().toLowerCase(); + libsignal = userAgent.getAdditionalSpecifiers() + .map(additionalSpecifiers -> additionalSpecifiers.contains("libsignal")) + .orElse(false); + } catch (final UnrecognizedUserAgentException e) { + platform = "unrecognized"; + libsignal = false; + } + + return List.of(Tag.of(PLATFORM_TAG, platform), Tag.of(LIBSIGNAL_TAG, String.valueOf(libsignal))); + } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java index 9ba5a76bb..232dbdd6c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java @@ -90,7 +90,7 @@ void testOnEvent() { final ContainerRequest request = mock(ContainerRequest.class); when(request.getMethod()).thenReturn(method); when(request.getRequestHeader(HttpHeaders.USER_AGENT)).thenReturn( - Collections.singletonList("Signal-Android/4.53.7 (Android 8.1)")); + Collections.singletonList("Signal-Android/7.6.2 Android/34 libsignal/0.46.0")); final ContainerResponse response = mock(ContainerResponse.class); when(response.getStatus()).thenReturn(statusCode); @@ -116,12 +116,13 @@ void testOnEvent() { tags.add(tag); } - assertEquals(5, tags.size()); + assertEquals(6, tags.size()); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.PATH_TAG, path))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.METHOD_TAG, method))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.STATUS_CODE_TAG, String.valueOf(statusCode)))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "true"))); } @Test @@ -178,12 +179,13 @@ void testActualRouteMessageSuccess() throws IOException { tags.add(tag); } - assertEquals(5, tags.size()); + assertEquals(6, tags.size()); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.PATH_TAG, "/v1/test/hello"))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.METHOD_TAG, "GET"))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.STATUS_CODE_TAG, String.valueOf(200)))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); } @Test @@ -238,12 +240,13 @@ void testActualRouteMessageSuccessNoUserAgent() throws IOException { tags.add(tag); } - assertEquals(5, tags.size()); + assertEquals(6, tags.size()); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.PATH_TAG, "/v1/test/hello"))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.METHOD_TAG, "GET"))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.STATUS_CODE_TAG, String.valueOf(200)))); assertTrue(tags.contains(Tag.of(MetricsRequestEventListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized"))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); } private static SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor responseCaptor) From c84d96abeeef752afdaccdf19669a0fd9a88a8c8 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Fri, 20 Dec 2024 11:59:42 -0600 Subject: [PATCH 08/12] Remove deprecated svr3Credentials field --- .../auth/RegistrationLockVerificationManager.java | 3 +-- .../textsecuregcm/entities/RegistrationLockFailure.java | 5 +---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java index 425f64ef8..ab0d8f9a4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java @@ -169,8 +169,7 @@ public void verifyRegistrationLock(final Account account, @Nullable final String throw new WebApplicationException(Response.status(FAILURE_HTTP_STATUS) .entity(new RegistrationLockFailure( existingRegistrationLock.getTimeRemaining().toMillis(), - svr2FailureCredentials(existingRegistrationLock, updatedAccount), - null)) + svr2FailureCredentials(existingRegistrationLock, updatedAccount))) .build()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationLockFailure.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationLockFailure.java index 80b51202e..13d3345d1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationLockFailure.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationLockFailure.java @@ -18,8 +18,5 @@ public record RegistrationLockFailure( long timeRemaining, @Schema(description = "Credentials that can be used with SVR2") @Nullable - ExternalServiceCredentials svr2Credentials, - @Deprecated - @Nullable - ExternalServiceCredentials svr3Credentials) { + ExternalServiceCredentials svr2Credentials) { } From 70ce6eff9e9bbd1a715192b00d358f6e9e2a5fcd Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 31 Jan 2025 12:50:14 -0500 Subject: [PATCH 09/12] Include `ephemeral` flag in individual messages --- .../textsecuregcm/controllers/MessageController.java | 1 + .../textsecuregcm/entities/IncomingMessage.java | 2 ++ .../textsecuregcm/entities/OutgoingMessageEntityTest.java | 7 ++++--- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 02a57caa1..445127b27 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -403,6 +403,7 @@ public Response sendMessage(@ReadOnly @Auth final Optional source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null), messages.timestamp() == 0 ? System.currentTimeMillis() : messages.timestamp(), isStory, + messages.online(), messages.urgent(), reportSpamToken.orElse(null)); } catch (final IllegalArgumentException e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index 37cd1d886..3422d44a5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -19,6 +19,7 @@ public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIden @Nullable Byte sourceDeviceId, final long timestamp, final boolean story, + final boolean ephemeral, final boolean urgent, @Nullable byte[] reportSpamToken) { @@ -35,6 +36,7 @@ public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIden .setServerTimestamp(System.currentTimeMillis()) .setDestinationServiceId(destinationIdentifier.toServiceIdentifierString()) .setStory(story) + .setEphemeral(ephemeral) .setUrgent(urgent); if (sourceAccount != null && sourceDeviceId != null) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java index 6f2da5044..4544acd15 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java @@ -64,8 +64,6 @@ static ArgumentSets roundTripThroughEnvelope() { @Test void entityPreservesEnvelope() { - final Random random = new Random(); - final byte[] reportSpamToken = TestRandomUtil.nextBytes(8); final Account account = new Account(); @@ -79,11 +77,14 @@ void entityPreservesEnvelope() { (byte) 123, System.currentTimeMillis(), false, + false, true, reportSpamToken); MessageProtos.Envelope envelope = baseEnvelope.toBuilder().setServerGuid(UUID.randomUUID().toString()).build(); - assertEquals(envelope, OutgoingMessageEntity.fromEnvelope(envelope).toEnvelope()); + // Note that outgoing message entities don't have an "ephemeral"/"online" flag + assertEquals(envelope.toBuilder().clearEphemeral().build(), + OutgoingMessageEntity.fromEnvelope(envelope).toEnvelope()); } } From 6545bb9edb65afdaf7dec4c5da06ba1cdf5aef80 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 31 Jan 2025 12:58:16 -0500 Subject: [PATCH 10/12] Update to the latest version of the spam filter --- spam-filter | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spam-filter b/spam-filter index 28a2cc0a1..0e075eac1 160000 --- a/spam-filter +++ b/spam-filter @@ -1 +1 @@ -Subproject commit 28a2cc0a1d1f14846548b9872e954290870b135f +Subproject commit 0e075eac1589e59df962fc4bc4d2fc212342ba42 From e4b0f3ced5867a76425b7dbb6523db3e12cf7c94 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Wed, 5 Feb 2025 13:48:07 -0600 Subject: [PATCH 11/12] Use HTTP status code if FCM error code is unavailable --- .../java/org/whispersystems/textsecuregcm/push/FcmSender.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/FcmSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/FcmSender.java index 8074803ba..53b7496ef 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/FcmSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/FcmSender.java @@ -104,6 +104,8 @@ public CompletableFuture sendNotification(PushNotifi if (firebaseMessagingException.getMessagingErrorCode() != null) { errorCode = firebaseMessagingException.getMessagingErrorCode().name(); + } else if (firebaseMessagingException.getHttpResponse() != null) { + errorCode = "http" + firebaseMessagingException.getHttpResponse().getStatusCode(); } else { logger.warn("Received an FCM exception with no error code", firebaseMessagingException); errorCode = "unknown"; From 5d062285c20a52f8e48c7c92c412aedecaf50073 Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer <125505367+jkt-signal@users.noreply.github.com> Date: Wed, 5 Feb 2025 12:26:47 -0800 Subject: [PATCH 12/12] Filter to block old REST API for specified client versions --- .../textsecuregcm/WhisperServerService.java | 6 + .../dynamic/DynamicConfiguration.java | 10 ++ .../filters/RestDeprecationFilter.java | 88 ++++++++++++++ .../filters/RestDeprecationFilterTest.java | 111 ++++++++++++++++++ .../util/FakeDynamicConfigurationManager.java | 24 ++++ 5 files changed, 239 insertions(+) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilter.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilterTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/util/FakeDynamicConfigurationManager.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index f96108adf..18dec2bf8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -141,6 +141,7 @@ import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; +import org.whispersystems.textsecuregcm.filters.RestDeprecationFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; import org.whispersystems.textsecuregcm.geo.MaxMindDatabaseManager; import org.whispersystems.textsecuregcm.grpc.AccountsAnonymousGrpcService; @@ -1001,7 +1002,12 @@ protected void configureServer(final ServerBuilder serverBuilder) { metricsHttpChannelListener.configure(environment); final MessageMetrics messageMetrics = new MessageMetrics(); + // BufferingInterceptor is needed on the base environment but not the WebSocketEnvironment, + // because we handle serialization of http responses on the websocket on our own and can + // compute content lengths without it environment.jersey().register(new BufferingInterceptor()); + environment.jersey().register(new RestDeprecationFilter(dynamicConfigurationManager, experimentEnrollmentManager)); + environment.jersey().register(new VirtualExecutorServiceProvider("managed-async-virtual-thread-")); environment.jersey().register(new RateLimitByIpFilter(rateLimiters)); environment.jersey().register(new RequestStatisticsFilter(TrafficSource.HTTP)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java index d6ec9f970..25f4e1411 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.configuration.dynamic; import com.fasterxml.jackson.annotation.JsonProperty; +import com.vdurmont.semver4j.Semver; import jakarta.validation.Valid; import java.util.Collections; import java.util.HashMap; @@ -13,6 +14,7 @@ import java.util.Map; import java.util.Optional; import org.whispersystems.textsecuregcm.limits.RateLimiterConfig; +import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; public class DynamicConfiguration { @@ -72,6 +74,10 @@ public class DynamicConfiguration { @Valid List svrStatusCodesToIgnoreForAccountDeletion = Collections.emptyList(); + @JsonProperty + @Valid + Map minimumRestFreeVersion = Map.of(); + public Optional getExperimentEnrollmentConfiguration( final String experimentName) { return Optional.ofNullable(experiments.get(experimentName)); @@ -130,4 +136,8 @@ public List getSvrStatusCodesToIgnoreForAccountDeletion() { return svrStatusCodesToIgnoreForAccountDeletion; } + public Map minimumRestFreeVersion() { + return minimumRestFreeVersion; + } + } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilter.java new file mode 100644 index 000000000..cf363960d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilter.java @@ -0,0 +1,88 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.filters; + +import com.google.common.net.HttpHeaders; +import com.vdurmont.semver4j.Semver; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tags; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerRequestFilter; +import jakarta.ws.rs.core.SecurityContext; +import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; +import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; +import org.whispersystems.textsecuregcm.util.ua.UserAgent; +import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; + +public class RestDeprecationFilter implements ContainerRequestFilter { + + private static final String EXPERIMENT_NAME = "restDeprecation"; + private static final String DEPRECATED_REST_COUNTER_NAME = MetricsUtil.name(RestDeprecationFilter.class, "blockedRestRequest"); + + private static final Logger log = LoggerFactory.getLogger(RestDeprecationFilter.class); + + final DynamicConfigurationManager dynamicConfigurationManager; + final ExperimentEnrollmentManager experimentEnrollmentManager; + + public RestDeprecationFilter( + final DynamicConfigurationManager dynamicConfigurationManager, + final ExperimentEnrollmentManager experimentEnrollmentManager) { + this.dynamicConfigurationManager = dynamicConfigurationManager; + this.experimentEnrollmentManager = experimentEnrollmentManager; + } + + @Override + public void filter(final ContainerRequestContext requestContext) throws IOException { + + final SecurityContext securityContext = requestContext.getSecurityContext(); + + if (securityContext == null || securityContext.getUserPrincipal() == null) { + // We can't check if an unauthenticated request is in the experiment + return; + } + + if (securityContext.getUserPrincipal() instanceof AuthenticatedDevice ad) { + if (!experimentEnrollmentManager.isEnrolled(ad.getAccount().getUuid(), EXPERIMENT_NAME)) { + return; + } + } else { + log.error("Security context was not null but user principal was of type {}", securityContext.getUserPrincipal().getClass().getName()); + return; + } + + final Map minimumRestFreeVersion = dynamicConfigurationManager.getConfiguration().minimumRestFreeVersion(); + final String userAgentString = requestContext.getHeaderString(HttpHeaders.USER_AGENT); + + try { + final UserAgent userAgent = UserAgentUtil.parseUserAgentString(userAgentString); + final ClientPlatform platform = userAgent.getPlatform(); + final Semver version = userAgent.getVersion(); + if (!minimumRestFreeVersion.containsKey(platform)) { + return; + } + if (version.isGreaterThanOrEqualTo(minimumRestFreeVersion.get(platform))) { + Metrics.counter( + DEPRECATED_REST_COUNTER_NAME, Tags.of("platform", platform.name().toLowerCase(), "version", version.toString())) + .increment(); + throw new WebApplicationException("use websockets", 498); + } + } catch (final UnrecognizedUserAgentException e) { + return; // at present we're only interested in experimenting on known clients + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilterTest.java new file mode 100644 index 000000000..707d21a69 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilterTest.java @@ -0,0 +1,111 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.filters; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.net.HttpHeaders; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.SecurityContext; +import java.net.URI; +import java.util.UUID; +import org.glassfish.jersey.server.ContainerRequest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.tests.util.FakeDynamicConfigurationManager; +import org.whispersystems.textsecuregcm.util.SystemMapper; + +class RestDeprecationFilterTest { + + @Test + void testNoConfig() throws Exception { + final DynamicConfigurationManager dynamicConfigurationManager = + new FakeDynamicConfigurationManager<>(new DynamicConfiguration()); + final ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager); + + final RestDeprecationFilter filter = new RestDeprecationFilter(dynamicConfigurationManager, experimentEnrollmentManager); + + final Account account = new Account(); + account.setUuid(UUID.randomUUID()); + final SecurityContext securityContext = mock(SecurityContext.class); + when(securityContext.getUserPrincipal()).thenReturn(new AuthenticatedDevice(account, new Device())); + final ContainerRequest req = new ContainerRequest(null, new URI("/some/uri"), "GET", securityContext, null, null); + req.getHeaders().add(HttpHeaders.USER_AGENT, "Signal-Android/100.0.0"); + + filter.filter(req); + } + + @Test + void testOldClient() throws Exception { + final DynamicConfiguration config = SystemMapper.yamlMapper().readValue( + """ + minimumRestFreeVersion: + ANDROID: 200.0.0 + experiments: + restDeprecation: + uuidEnrollmentPercentage: 100 + """, + DynamicConfiguration.class); + final DynamicConfigurationManager dynamicConfigurationManager = new FakeDynamicConfigurationManager<>(config); + final ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager); + + final RestDeprecationFilter filter = new RestDeprecationFilter(dynamicConfigurationManager, experimentEnrollmentManager); + + final Account account = new Account(); + account.setUuid(UUID.randomUUID()); + final SecurityContext securityContext = mock(SecurityContext.class); + when(securityContext.getUserPrincipal()).thenReturn(new AuthenticatedDevice(account, new Device())); + final ContainerRequest req = new ContainerRequest(null, new URI("/some/uri"), "GET", securityContext, null, null); + req.getHeaders().add(HttpHeaders.USER_AGENT, "Signal-Android/100.0.0"); + + filter.filter(req); + } + + @Test + void testBlocking() throws Exception { + final DynamicConfiguration config = SystemMapper.yamlMapper().readValue( + """ + minimumRestFreeVersion: + ANDROID: 10.10.10 + experiments: + restDeprecation: + enrollmentPercentage: 100 + """, + DynamicConfiguration.class); + final DynamicConfigurationManager dynamicConfigurationManager = new FakeDynamicConfigurationManager<>(config); + final ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager); + + final RestDeprecationFilter filter = new RestDeprecationFilter(dynamicConfigurationManager, experimentEnrollmentManager); + + final Account account = new Account(); + account.setUuid(UUID.randomUUID()); + final SecurityContext securityContext = mock(SecurityContext.class); + when(securityContext.getUserPrincipal()).thenReturn(new AuthenticatedDevice(account, new Device())); + final ContainerRequest req = new ContainerRequest(null, new URI("/some/path"), "GET", securityContext, null, null); + + req.getHeaders().putSingle(HttpHeaders.USER_AGENT, "Signal-Android/10.9.15"); + filter.filter(req); + + req.getHeaders().putSingle(HttpHeaders.USER_AGENT, "Signal-Android/10.10.9"); + filter.filter(req); + + req.getHeaders().putSingle(HttpHeaders.USER_AGENT, "Signal-Android/10.10.10"); + assertThrows(WebApplicationException.class, () -> filter.filter(req)); + + req.getHeaders().putSingle(HttpHeaders.USER_AGENT, "Signal-Android/100.0.0"); + assertThrows(WebApplicationException.class, () -> filter.filter(req)); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/FakeDynamicConfigurationManager.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/FakeDynamicConfigurationManager.java new file mode 100644 index 000000000..94e0ae7eb --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/FakeDynamicConfigurationManager.java @@ -0,0 +1,24 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.tests.util; + +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; + +public class FakeDynamicConfigurationManager extends DynamicConfigurationManager { + + T staticConfiguration; + + public FakeDynamicConfigurationManager(T staticConfiguration) { + super(null, (Class) staticConfiguration.getClass()); + this.staticConfiguration = staticConfiguration; + } + + @Override + public T getConfiguration() { + return staticConfiguration; + } + +}