Skip to content

Commit

Permalink
[Java] Add a test for rcv-hwm position update when a packet contain…
Browse files Browse the repository at this point in the history
…s trailing padding frame + rename method that computes the target position offset.
  • Loading branch information
vyazelenko committed Jan 27, 2025
1 parent f33e0ca commit a6c9d7d
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 6 deletions.
14 changes: 8 additions & 6 deletions aeron-driver/src/main/java/io/aeron/driver/PublicationImage.java
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ int insertPacket(
}
else
{
proposedPosition = packetPosition + computeFullPacketLength(buffer, length);
proposedPosition = packetPosition + computeActualFrameLength(buffer, length);
}

if (!isFlowControlOverRun(proposedPosition))
Expand All @@ -642,12 +642,12 @@ int insertPacket(
{
final long nowNs = cachedNanoClock.nanoTime();
timeOfLastPacketNs = nowNs;
trackConnection(transportIndex, srcAddress, nowNs);
final ImageConnection imageConnection = trackConnection(transportIndex, srcAddress, nowNs);

if (isEndOfStream)
{
imageConnections[transportIndex].eosPosition = packetPosition;
imageConnections[transportIndex].isEos = true;
imageConnection.eosPosition = packetPosition;
imageConnection.isEos = true;

if (!this.isEndOfStream && isAllConnectedEos())
{
Expand Down Expand Up @@ -963,7 +963,7 @@ void stopStatusMessagesIfNotActive()
}
}

private static int computeFullPacketLength(final UnsafeBuffer buffer, final int packetLength)
private static int computeActualFrameLength(final UnsafeBuffer buffer, final int packetLength)
{
int offset = 0;
while (offset < packetLength)
Expand Down Expand Up @@ -1043,7 +1043,8 @@ private void cleanBufferTo(final long position)
}
}

private void trackConnection(final int transportIndex, final InetSocketAddress srcAddress, final long nowNs)
private ImageConnection trackConnection(
final int transportIndex, final InetSocketAddress srcAddress, final long nowNs)
{
imageConnections = ArrayUtil.ensureCapacity(imageConnections, transportIndex + 1);
ImageConnection imageConnection = imageConnections[transportIndex];
Expand All @@ -1056,6 +1057,7 @@ private void trackConnection(final int transportIndex, final InetSocketAddress s

imageConnection.timeOfLastActivityNs = nowNs;
imageConnection.timeOfLastFrameNs = nowNs;
return imageConnection;
}

private boolean isAllConnectedEos()
Expand Down
249 changes: 249 additions & 0 deletions aeron-driver/src/test/java/io/aeron/driver/PublicationImageTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
/*
* Copyright 2014-2025 Real Logic Limited.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.aeron.driver;

import io.aeron.ChannelUri;
import io.aeron.driver.buffer.RawLog;
import io.aeron.driver.media.ReceiveChannelEndpoint;
import io.aeron.driver.media.UdpChannel;
import io.aeron.driver.status.ReceiverHwm;
import io.aeron.driver.status.ReceiverPos;
import io.aeron.driver.status.SystemCounterDescriptor;
import io.aeron.driver.status.SystemCounters;
import io.aeron.logbuffer.FrameDescriptor;
import io.aeron.protocol.DataHeaderFlyweight;
import org.agrona.BitUtil;
import org.agrona.ExpandableArrayBuffer;
import org.agrona.concurrent.CachedEpochClock;
import org.agrona.concurrent.CachedNanoClock;
import org.agrona.concurrent.UnsafeBuffer;
import org.agrona.concurrent.status.AtomicCounter;
import org.agrona.concurrent.status.CountersManager;
import org.agrona.concurrent.status.Position;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import static io.aeron.logbuffer.LogBufferDescriptor.*;
import static io.aeron.protocol.DataHeaderFlyweight.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class PublicationImageTest
{
private static final int TERM_LENGTH = 64 * 1024;
private static final int INITIAL_WINDOW_LENGTH = 128 * 1024;
private static final int MAX_WINDOW_LENGHT = 1024 * 1024;
private static final long CORRELATION_ID = 42;
private static final int TRANSPORT_INDEX = 3;
private static final int SESSION_ID = 888;
private static final int STREAM_ID = 101010;
private static final int INITIAL_TERM_ID = -444666;
private static final int ACTIVE_TERM_ID = INITIAL_TERM_ID + 111;
private static final int TERM_OFFSET = TERM_LENGTH - TERM_LENGTH / 4;
private static final short FLAGS = FrameDescriptor.UNFRAGMENTED;
private static final String SOURCE_IDENTITY = "aeron:udp?endpoint=localhost:5555";
private final MediaDriver.Context ctx = new MediaDriver.Context();
private final ReceiveChannelEndpoint receiveChannelEndpoint = mock(ReceiveChannelEndpoint.class);
private final InetSocketAddress controlAddress = mock(InetSocketAddress.class);
private final RawLog rawLog = mock(RawLog.class);
private final FeedbackDelayGenerator feedbackDelayGenerator = mock(FeedbackDelayGenerator.class);
private final CongestionControl congestionControl = mock(CongestionControl.class);
private final CachedEpochClock epochClock = new CachedEpochClock();
private final CachedNanoClock nanoClock = new CachedNanoClock();
private final UnsafeBuffer buffer = new UnsafeBuffer(new byte[1024]);
private final CountersManager countersManager = new CountersManager(
new UnsafeBuffer(ByteBuffer.allocateDirect(256 * 1024)),
new UnsafeBuffer(ByteBuffer.allocateDirect(64 * 1024)),
StandardCharsets.US_ASCII);
private final DataHeaderFlyweight headerFlyweight = new DataHeaderFlyweight();
private Position hwmPosition;
private Position rcvPosition;
private PublicationImage image;

@BeforeEach
void before()
{
epochClock.update(TimeUnit.HOURS.toMillis(1));
nanoClock.update(TimeUnit.HOURS.toNanos(1));
ctx
.receiverCachedNanoClock(nanoClock)
.nanoClock(nanoClock)
.epochClock(epochClock)
.imageLivenessTimeoutNs(TimeUnit.SECONDS.toNanos(10))
.untetheredWindowLimitTimeoutNs(TimeUnit.SECONDS.toNanos(1))
.untetheredRestingTimeoutNs(TimeUnit.SECONDS.toNanos(1))
.statusMessageTimeoutNs(TimeUnit.MILLISECONDS.toNanos(150))
.systemCounters(new SystemCounters(countersManager));

final String channel = "aeron:udp?endpoint=localhost:5555";
final ChannelUri channelUri = ChannelUri.parse(channel);
final UdpChannel udpChannel = mock(UdpChannel.class);
when(udpChannel.channelUri()).thenReturn(channelUri);
when(receiveChannelEndpoint.subscriptionUdpChannel()).thenReturn(udpChannel);

final SubscriptionLink subscriptionLink1 = mock(SubscriptionLink.class);
when(subscriptionLink1.isReliable()).thenReturn(true);
when(subscriptionLink1.isTether()).thenReturn(true);
final SubscriberPosition subscriberPosition1 = mock(SubscriberPosition.class);
when(subscriberPosition1.subscription()).thenReturn(subscriptionLink1);
final SubscriptionLink subscriptionLink2 = mock(SubscriptionLink.class);
when(subscriptionLink1.isReliable()).thenReturn(false);
when(subscriptionLink1.isTether()).thenReturn(false);
final SubscriberPosition subscriberPosition2 = mock(SubscriberPosition.class);
when(subscriberPosition2.subscription()).thenReturn(subscriptionLink2);
final ArrayList<SubscriberPosition> subscriberPositions = new ArrayList<>();
subscriberPositions.add(subscriberPosition1);
subscriberPositions.add(subscriberPosition2);

final UnsafeBuffer[] termBuffers = new UnsafeBuffer[PARTITION_COUNT];
for (int i = 0; i < termBuffers.length; i++)
{
termBuffers[i] = new UnsafeBuffer(new byte[TERM_LENGTH]);
}
when(rawLog.termBuffers()).thenReturn(termBuffers);
when(rawLog.metaData()).thenReturn(new UnsafeBuffer(new byte[LOG_META_DATA_LENGTH]));
when(rawLog.termLength()).thenReturn(TERM_LENGTH);

when(congestionControl.initialWindowLength()).thenReturn(INITIAL_WINDOW_LENGTH);
when(congestionControl.maxWindowLength()).thenReturn(MAX_WINDOW_LENGHT);

final long registrationId = 73249234983274L;
final ExpandableArrayBuffer tempBuffer = new ExpandableArrayBuffer();
hwmPosition = ReceiverHwm.allocate(tempBuffer, countersManager, registrationId, SESSION_ID, STREAM_ID, channel);
rcvPosition = ReceiverPos.allocate(
tempBuffer, countersManager, registrationId, SESSION_ID, STREAM_ID, channel);

image = new PublicationImage(
CORRELATION_ID,
ctx,
receiveChannelEndpoint,
TRANSPORT_INDEX,
controlAddress,
SESSION_ID,
STREAM_ID,
INITIAL_TERM_ID,
ACTIVE_TERM_ID,
TERM_OFFSET,
FLAGS,
rawLog,
feedbackDelayGenerator,
subscriberPositions,
hwmPosition,
rcvPosition,
SOURCE_IDENTITY,
congestionControl);

final long position = computePosition(
ACTIVE_TERM_ID, TERM_OFFSET, positionBitsToShift(TERM_LENGTH), INITIAL_TERM_ID);
assertEquals(position, hwmPosition.get());
assertEquals(position, rcvPosition.get());

ThreadLocalRandom.current().nextBytes(buffer.byteArray());
}

@Test
void shouldTakeIntoAccountTrailingPaddingFrameWhenIncrementingHighWaterMarkPosition()
{
final int totalLength = 512;
final int packetLength = 288;
final int termId = ACTIVE_TERM_ID;
final int termOffset = TERM_LENGTH - totalLength;
int offset = 0;
offset += writeFrame(offset, termOffset, termId, 65, BEGIN_AND_END_FLAGS, HDR_TYPE_DATA, 65);
offset += writeFrame(offset, termOffset + offset, termId, 96, BEGIN_AND_END_FLAGS, HDR_TYPE_DATA, 96);
offset += writeFrame(offset, termOffset + offset, termId, 224, BEGIN_AND_END_FLAGS, HDR_TYPE_PAD, 0x888AA888);
assertEquals(totalLength, offset);
final InetSocketAddress srcAddress = mock(InetSocketAddress.class);

final int bytes = image.insertPacket(termId, termOffset, buffer, packetLength, TRANSPORT_INDEX, srcAddress);

assertEquals(packetLength, bytes);
final int positionBitsToShift = positionBitsToShift(TERM_LENGTH);
final long packetPosition = computePosition(termId, termOffset, positionBitsToShift, INITIAL_TERM_ID);
assertEquals(packetPosition + totalLength, hwmPosition.get());
final UnsafeBuffer activeTermBuffer =
rawLog.termBuffers()[indexByPosition(packetPosition, positionBitsToShift)];
for (int i = 0; i < packetLength; i++)
{
assertEquals(buffer.getByte(i), activeTermBuffer.getByte(termOffset + i));
}
for (int i = packetLength; i < totalLength; i++)
{
assertEquals(0, activeTermBuffer.getByte(termOffset + i));
}
}

@Test
void shouldAdvanceHighWaterMarkPositionOnHeartbeat()
{
final int termId = ACTIVE_TERM_ID;
final int termOffset = TERM_OFFSET + 1024;
writeFrame(0, termOffset, termId, 0, BEGIN_AND_END_FLAGS, HDR_TYPE_DATA, -1);
FrameDescriptor.frameLengthOrdered(buffer, 0, 0);
final InetSocketAddress srcAddress = mock(InetSocketAddress.class);
final int packetLength = HEADER_LENGTH;
final AtomicCounter heartBeatsCounter = ctx.systemCounters().get(SystemCounterDescriptor.HEARTBEATS_RECEIVED);
final long oldHeartBeatCount = heartBeatsCounter.getWeak();

final int bytes = image.insertPacket(termId, termOffset, buffer, packetLength, TRANSPORT_INDEX, srcAddress);

assertEquals(packetLength, bytes);
final int positionBitsToShift = positionBitsToShift(TERM_LENGTH);
final long packetPosition = computePosition(termId, termOffset, positionBitsToShift, INITIAL_TERM_ID);
assertEquals(packetPosition, hwmPosition.get());
assertEquals(oldHeartBeatCount + 1, heartBeatsCounter.getWeak());
final UnsafeBuffer activeTermBuffer =
rawLog.termBuffers()[indexByPosition(packetPosition, positionBitsToShift)];
for (int i = 0; i < packetLength; i++)
{
assertEquals(0, activeTermBuffer.getByte(termOffset + i));
}
}

private int writeFrame(
final int offset,
final int termOffset,
final int termId,
final int length,
final short flags,
final int type,
final int reservedValue)
{
final int frameLength = length + HEADER_LENGTH;
headerFlyweight.wrap(buffer, offset, frameLength);
headerFlyweight
.frameLength(frameLength)
.version(CURRENT_VERSION)
.flags(flags)
.headerType(type);
headerFlyweight
.termOffset(termOffset)
.sessionId(SESSION_ID)
.streamId(STREAM_ID)
.termId(termId)
.reservedValue(reservedValue);

return BitUtil.align(frameLength, FrameDescriptor.FRAME_ALIGNMENT);
}
}

0 comments on commit a6c9d7d

Please sign in to comment.