Skip to content

Commit

Permalink
Merge branch 'signalapp:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
offsoc authored Feb 8, 2025
2 parents 2f579cf + 794e254 commit d4f672b
Show file tree
Hide file tree
Showing 15 changed files with 207 additions and 386 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,9 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private KeyTransparencyServiceConfiguration keyTransparencyService;

@JsonProperty
private boolean logMessageDeliveryLoops;

public TlsKeyStoreConfiguration getTlsKeyStoreConfiguration() {
return tlsKeyStore;
}
Expand Down Expand Up @@ -558,4 +561,9 @@ public ExternalRequestFilterConfiguration getExternalRequestFilterConfiguration(
public KeyTransparencyServiceConfiguration getKeyTransparencyServiceConfiguration() {
return keyTransparencyService;
}

public boolean logMessageDeliveryLoops() {
return logMessageDeliveryLoops;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
import org.whispersystems.textsecuregcm.controllers.ArchiveController;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV4;
import org.whispersystems.textsecuregcm.controllers.CallLinkController;
import org.whispersystems.textsecuregcm.controllers.CallRoutingController;
import org.whispersystems.textsecuregcm.controllers.CallRoutingControllerV2;
import org.whispersystems.textsecuregcm.controllers.CertificateController;
import org.whispersystems.textsecuregcm.controllers.ChallengeController;
Expand Down Expand Up @@ -165,10 +164,12 @@
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.CardinalityEstimator;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.limits.NoopMessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitByIpFilter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.limits.RedisMessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper;
Expand Down Expand Up @@ -656,7 +657,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
Subscriptions subscriptions = new Subscriptions(
config.getDynamoDbTables().getSubscriptions().getTableName(), dynamoDbAsyncClient);
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor =
new MessageDeliveryLoopMonitor(rateLimitersCluster);
config.logMessageDeliveryLoops() ? new RedisMessageDeliveryLoopMonitor(rateLimitersCluster) : new NoopMessageDeliveryLoopMonitor();

disconnectionRequestManager.addListener(webSocketConnectionEventManager);

Expand Down Expand Up @@ -1116,7 +1117,6 @@ protected void configureServer(final ServerBuilder<?> serverBuilder) {
new AttachmentControllerV4(rateLimiters, gcsAttachmentGenerator, tusAttachmentGenerator,
experimentEnrollmentManager),
new ArchiveController(backupAuthManager, backupManager),
new CallRoutingController(rateLimiters, callRouter, turnTokenGenerator, experimentEnrollmentManager, cloudflareTurnCredentialsManager),
new CallRoutingControllerV2(rateLimiters, callRouter, turnTokenGenerator, experimentEnrollmentManager, cloudflareTurnCredentialsManager),
new CallLinkController(rateLimiters, callingGenericZkSecretParams),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(),
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
public class CallRoutingControllerV2 {

private static final Counter INVALID_IP_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "invalidIP"));
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER = Metrics.counter(name(CallRoutingController.class, "cloudflareTurnError"));
private static final Counter CLOUDFLARE_TURN_ERROR_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "cloudflareTurnError"));
private final RateLimiters rateLimiters;
private final TurnCallRouter turnCallRouter;
private final TurnTokenGenerator tokenGenerator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,28 @@

package org.whispersystems.textsecuregcm.entities;

import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;

@Schema(description = """
Represents a request from a new device to restore account data by some method.
""")
public record RestoreAccountRequest(
@NotNull
@Schema(description = "The method by which the new device has requested account data restoration")
Method method) {
Method method,

@Schema(description = "Additional data to use to bootstrap a connection between devices, in standard unpadded base64.",
implementation = String.class)
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
@Size(max = 4096)
@Nullable byte[] deviceTransferBootstrap) {

public enum Method {
@Schema(description = "Restore account data from a remote message history backup")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,8 @@
package org.whispersystems.textsecuregcm.limits;

import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;

public class MessageDeliveryLoopMonitor {

private final ClusterLuaScript getDeliveryAttemptsScript;

private static final Duration DELIVERY_ATTEMPTS_COUNTER_TTL = Duration.ofHours(1);
private static final int DELIVERY_LOOP_THRESHOLD = 5;

private static final Logger logger = LoggerFactory.getLogger(MessageDeliveryLoopMonitor.class);

public MessageDeliveryLoopMonitor(final FaultTolerantRedisClusterClient rateLimitCluster) {
try {
getDeliveryAttemptsScript =
ClusterLuaScript.fromResource(rateLimitCluster, "lua/get_delivery_attempt_count.lua", ScriptOutputType.INTEGER);
} catch (final IOException e) {
throw new UncheckedIOException("Failed to load 'get delivery attempt count' script", e);
}
}

public interface MessageDeliveryLoopMonitor {
/**
* Records an attempt to deliver a message with the given GUID to the given account/device pair and returns the number
* of consecutive attempts to deliver the same message and logs a warning if the message appears to be in a delivery
Expand All @@ -44,29 +16,5 @@ public MessageDeliveryLoopMonitor(final FaultTolerantRedisClusterClient rateLimi
* @param userAgent the User-Agent header supplied by the caller
* @param context a human-readable string identifying the mechanism of message delivery (e.g. "rest" or "websocket")
*/
public void recordDeliveryAttempt(final UUID accountIdentifier,
final byte deviceId,
final UUID messageGuid,
final String userAgent,
final String context) {

incrementDeliveryAttemptCount(accountIdentifier, deviceId, messageGuid)
.thenAccept(deliveryAttemptCount -> {
if (deliveryAttemptCount == DELIVERY_LOOP_THRESHOLD) {
logger.warn("Detected loop delivering message {} via {} to {}:{} ({})",
messageGuid, context, accountIdentifier, deviceId, userAgent);
}
});
}

@VisibleForTesting
CompletableFuture<Long> incrementDeliveryAttemptCount(final UUID accountIdentifier, final byte deviceId, final UUID messageGuid) {
final String firstMessageGuidKey = "firstMessageGuid::{" + accountIdentifier + ":" + deviceId + "}";
final String deliveryAttemptsKey = "firstMessageDeliveryAttempts::{" + accountIdentifier + ":" + deviceId + "}";

return getDeliveryAttemptsScript.executeAsync(
List.of(firstMessageGuidKey, deliveryAttemptsKey),
List.of(messageGuid.toString(), String.valueOf(DELIVERY_ATTEMPTS_COUNTER_TTL.toSeconds())))
.thenApply(result -> (long) result);
}
void recordDeliveryAttempt(UUID accountIdentifier, byte deviceId, UUID messageGuid, String userAgent, String context);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.whispersystems.textsecuregcm.limits;

import java.util.UUID;

public class NoopMessageDeliveryLoopMonitor implements MessageDeliveryLoopMonitor {

public NoopMessageDeliveryLoopMonitor() {
}

public void recordDeliveryAttempt(final UUID accountIdentifier, final byte deviceId, final UUID messageGuid, final String userAgent, final String context) {
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package org.whispersystems.textsecuregcm.limits;

import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;

public class RedisMessageDeliveryLoopMonitor implements MessageDeliveryLoopMonitor {

private final ClusterLuaScript getDeliveryAttemptsScript;

private static final Duration DELIVERY_ATTEMPTS_COUNTER_TTL = Duration.ofHours(1);
private static final int DELIVERY_LOOP_THRESHOLD = 5;

private static final Logger logger = LoggerFactory.getLogger(MessageDeliveryLoopMonitor.class);

public RedisMessageDeliveryLoopMonitor(final FaultTolerantRedisClusterClient rateLimitCluster) {
try {
getDeliveryAttemptsScript =
ClusterLuaScript.fromResource(rateLimitCluster, "lua/get_delivery_attempt_count.lua", ScriptOutputType.INTEGER);
} catch (final IOException e) {
throw new UncheckedIOException("Failed to load 'get delivery attempt count' script", e);
}
}

/**
* Records an attempt to deliver a message with the given GUID to the given account/device pair and returns the number
* of consecutive attempts to deliver the same message and logs a warning if the message appears to be in a delivery
* loop. This method is intended to detect cases where a message remains at the head of a device's queue after
* repeated attempts to deliver the message, and so the given message GUID should be the first message of a "page"
* sent to clients.
*
* @param accountIdentifier the identifier of the destination account
* @param deviceId the destination device's ID within the given account
* @param messageGuid the GUID of the message
* @param userAgent the User-Agent header supplied by the caller
* @param context a human-readable string identifying the mechanism of message delivery (e.g. "rest" or "websocket")
*/
public void recordDeliveryAttempt(final UUID accountIdentifier,
final byte deviceId,
final UUID messageGuid,
final String userAgent,
final String context) {

incrementDeliveryAttemptCount(accountIdentifier, deviceId, messageGuid)
.thenAccept(deliveryAttemptCount -> {
if (deliveryAttemptCount == DELIVERY_LOOP_THRESHOLD) {
logger.warn("Detected loop delivering message {} via {} to {}:{} ({})",
messageGuid, context, accountIdentifier, deviceId, userAgent);
}
});
}

@VisibleForTesting
CompletableFuture<Long> incrementDeliveryAttemptCount(final UUID accountIdentifier, final byte deviceId, final UUID messageGuid) {
final String firstMessageGuidKey = "firstMessageGuid::{" + accountIdentifier + ":" + deviceId + "}";
final String deliveryAttemptsKey = "firstMessageDeliveryAttempts::{" + accountIdentifier + ":" + deviceId + "}";

return getDeliveryAttemptsScript.executeAsync(
List.of(firstMessageGuidKey, deliveryAttemptsKey),
List.of(messageGuid.toString(), String.valueOf(DELIVERY_ATTEMPTS_COUNTER_TTL.toSeconds())))
.thenApply(result -> (long) result);
}

}
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
package org.whispersystems.textsecuregcm.metrics;

import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;

import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.whispersystems.textsecuregcm.util.EnumMapUtil;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

public class OpenWebSocketCounter {

private static final String WEBSOCKET_CLOSED_COUNTER_NAME = name(OpenWebSocketCounter.class, "websocketClosed");

private final Map<ClientPlatform, AtomicInteger> openWebsocketsByClientPlatform;
private final AtomicInteger openWebsocketsFromUnknownPlatforms;

Expand Down Expand Up @@ -81,6 +85,9 @@ public void countOpenWebSocket(final WebSocketSessionContext context) {
context.addWebsocketClosedListener((context1, statusCode, reason) -> {
sample.stop(durationTimer);
openWebSocketCounter.decrementAndGet();

Metrics.counter(WEBSOCKET_CLOSED_COUNTER_NAME, "status", String.valueOf(statusCode))
.increment();
});
}
}
Loading

0 comments on commit d4f672b

Please sign in to comment.