diff --git a/plugin/build.gradle b/plugin/build.gradle index 8ea9e9b..54255d4 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -21,6 +21,7 @@ dependencies { implementation platform('com.google.cloud:libraries-bom:26.22.0') implementation 'com.google.cloud:google-cloud-pubsub' implementation 'io.airlift:stats:235' + implementation 'io.airlift:aircompressor:0.25' testImplementation 'org.hamcrest:hamcrest-core:2.2' } diff --git a/plugin/src/main/java/dev/regadas/trino/pubsub/listener/CompressingMessageEncoder.java b/plugin/src/main/java/dev/regadas/trino/pubsub/listener/CompressingMessageEncoder.java new file mode 100644 index 0000000..5c0b89e --- /dev/null +++ b/plugin/src/main/java/dev/regadas/trino/pubsub/listener/CompressingMessageEncoder.java @@ -0,0 +1,31 @@ +package dev.regadas.trino.pubsub.listener; + +import com.google.protobuf.Message; +import dev.regadas.trino.pubsub.listener.Encoder.MessageEncoder; +import io.airlift.compress.zstd.ZstdOutputStream; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.Objects; + +public class CompressingMessageEncoder implements MessageEncoder { + + private final MessageEncoder delegate; + + public CompressingMessageEncoder(MessageEncoder delegate) { + this.delegate = Objects.requireNonNull(delegate); + } + + @Override + public byte[] encode(Message value) throws Exception { + var uncompressedBytes = delegate.encode(value); + try (var in = new ByteArrayInputStream(uncompressedBytes); + var bao = new ByteArrayOutputStream()) { + // ZstdOutputStream compress and flushes on close, + // so we wrap it on its own try with resources + try (var zout = new ZstdOutputStream(bao)) { + in.transferTo(zout); + } + return bao.toByteArray(); + } + } +} diff --git a/plugin/src/main/java/dev/regadas/trino/pubsub/listener/pubsub/PubSubPublisher.java b/plugin/src/main/java/dev/regadas/trino/pubsub/listener/pubsub/PubSubPublisher.java index 8efc9ed..e30e638 100644 --- a/plugin/src/main/java/dev/regadas/trino/pubsub/listener/pubsub/PubSubPublisher.java +++ b/plugin/src/main/java/dev/regadas/trino/pubsub/listener/pubsub/PubSubPublisher.java @@ -13,6 +13,7 @@ import com.google.protobuf.Message; import com.google.pubsub.v1.PubsubMessage; import com.google.pubsub.v1.TopicName; +import dev.regadas.trino.pubsub.listener.CompressingMessageEncoder; import dev.regadas.trino.pubsub.listener.Encoder; import dev.regadas.trino.pubsub.listener.Encoder.Encoding; import dev.regadas.trino.pubsub.listener.Encoder.MessageEncoder; @@ -50,7 +51,7 @@ public static PubSubPublisher create( .setBatchingSettings(batchingSettings) .build(); - var encoder = MessageEncoder.create(encoding); + var encoder = new CompressingMessageEncoder(MessageEncoder.create(encoding)); return new PubSubPublisher(publisher, encoder); } diff --git a/plugin/src/test/java/dev/regadas/trino/pubsub/listener/CompressingMessageEncoderTest.java b/plugin/src/test/java/dev/regadas/trino/pubsub/listener/CompressingMessageEncoderTest.java new file mode 100644 index 0000000..8c15e07 --- /dev/null +++ b/plugin/src/test/java/dev/regadas/trino/pubsub/listener/CompressingMessageEncoderTest.java @@ -0,0 +1,63 @@ +package dev.regadas.trino.pubsub.listener; + +import static java.util.stream.Collectors.joining; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThan; + +import com.google.protobuf.Message; +import dev.regadas.trino.pubsub.listener.Encoder.MessageEncoder; +import dev.regadas.trino.pubsub.listener.proto.Test.TestMessage; +import io.airlift.compress.zstd.ZstdInputStream; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class CompressingMessageEncoderTest { + + private static final String TEXT = Stream.generate(() -> "a").limit(1000).collect(joining()); + private static final TestMessage MESSAGE = TestMessage.newBuilder().setText(TEXT).build(); + private static final ProtoMessageEncoder DELEGATE = new ProtoMessageEncoder(); + private CompressingMessageEncoder encoder; + + @BeforeEach + void setUp() { + encoder = new CompressingMessageEncoder(DELEGATE); + } + + @Test + void testEncodeActuallyCompress() throws Exception { + byte[] uncompressed = DELEGATE.encode(MESSAGE); + + byte[] compressed = encoder.encode(MESSAGE); + + assertThat(compressed.length, lessThan(uncompressed.length)); + } + + @Test + void testEncodeCompressionRoundTrip() throws Exception { + byte[] compressed = encoder.encode(MESSAGE); + + byte[] decompressed = decompress(compressed); + assertThat(TestMessage.parseFrom(decompressed), equalTo(MESSAGE)); + } + + public static byte[] decompress(byte[] uncompressedBytes) throws IOException { + try (var zin = new ZstdInputStream(new ByteArrayInputStream(uncompressedBytes)); + var bao = new ByteArrayOutputStream()) { + zin.transferTo(bao); + return bao.toByteArray(); + } + } + + static class ProtoMessageEncoder implements MessageEncoder { + + @Override + public byte[] encode(Message value) { + return value.toByteArray(); + } + } +} diff --git a/plugin/src/test/proto/test.proto b/plugin/src/test/proto/test.proto new file mode 100644 index 0000000..d8bbf52 --- /dev/null +++ b/plugin/src/test/proto/test.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; +package dev.regadas.trino.pubsub.listener.proto; + +message TestMessage { + string text = 1; +}