Skip to content

Commit

Permalink
migrate code gen to use protocol list
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiscigl committed Feb 28, 2025
1 parent 44dacf9 commit c9c6c26
Show file tree
Hide file tree
Showing 18 changed files with 104 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,24 @@

package com.amazonaws.util.awsclientgenerator.domainmodels.codegeneration;

import com.google.common.collect.ImmutableList;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;

import java.util.Comparator;
import java.util.List;
import java.util.Map;

@Data
public class Metadata {
private static List<String> supportedProtocols = ImmutableList.of(
"json",
"rest-json",
"rest-xml",
"query",
"ec2"
);

private String apiVersion;
private String concatAPIVersion;
private String endpointPrefix;
Expand All @@ -23,6 +34,7 @@ public class Metadata {
private String signingName;
private String targetPrefix;
private String protocol;
private List<String> protocols;
private String projectName;
private String classNamePrefix;
private String acceptHeader;
Expand All @@ -47,4 +59,18 @@ public class Metadata {

// Priority-ordered list of auth types present on the service model
private List<String> auth;

public String findFirstSupportedProtocol() {
if ("api-gateway".equals(protocol)) {
return protocol;
}

if (protocols.isEmpty()) {
return protocol;
}

return protocols.stream().filter(supportedProtocols::contains)
.min(Comparator.comparingInt(protocolName -> supportedProtocols.indexOf(protocolName)))
.orElseThrow(() -> new RuntimeException(String.format("No supported protocol found for %s", serviceFullName)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ public static String computeXmlConversionMethodName(Shape shape) {
}

public static String computeRequestContentType(Metadata metadata) {
String protocolAndVersion = metadata.getProtocol();
String protocolAndVersion = metadata.findFirstSupportedProtocol();

if(metadata.getJsonVersion() != null) {
protocolAndVersion += metadata.getJsonVersion();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public ByteArrayOutputStream generateSourceFromC2jModel(C2jServiceModel c2jModel

spec.setVersion(serviceModel.getMetadata().getApiVersion());

String protocol = serviceModel.getMetadata().getProtocol();
String protocol = serviceModel.getMetadata().findFirstSupportedProtocol();
ClientGenerator clientGenerator = ServiceGeneratorConfig.findGenerator(spec, protocol);

//use serviceName and version to convert the json over.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ public ServiceModel convert() {
serviceModel.getMetadata().setEndpointOperationName(endpointOperationName);

// add protocol check. only for json, query protocols
if (serviceModel.getMetadata().getProtocol().equals("json")) {
final String protocol = serviceModel.getMetadata().findFirstSupportedProtocol();

if ("json".equals(protocol)) {
serviceModel.getMetadata().setAwsQueryCompatible(
c2jServiceModel.getMetadata().getAwsQueryCompatible() != null);
} else {
Expand Down Expand Up @@ -252,11 +254,11 @@ Metadata convertMetadata() {
metadata.setJsonVersion(c2jMetadata.getJsonVersion());
if("api-gateway".equalsIgnoreCase(c2jMetadata.getProtocol())) {
metadata.setEndpointPrefix(c2jMetadata.getEndpointPrefix() + ".execute-api");
metadata.setProtocol("application-json");
metadata.setProtocols(ImmutableList.of("application-json"));
metadata.setApigateway(true);
} else {
metadata.setEndpointPrefix(c2jMetadata.getEndpointPrefix());
metadata.setProtocol(c2jMetadata.getProtocol());
metadata.setProtocols(ImmutableList.of(c2jMetadata.getProtocol()));
}
metadata.setNamespace(c2jMetadata.getServiceAbbreviation());
metadata.setServiceFullName(c2jMetadata.getServiceFullName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#set($metadata = $serviceModel.metadata)
#set($rootNamespace = $serviceModel.namespace)
#set($serviceNamespace = $metadata.namespace)
#set($payloadType = ${CppViewHelper.computeServicePayloadType($metadata.protocol)})
#set($payloadType = ${CppViewHelper.computeServicePayloadType($metadata.findFirstSupportedProtocol())})
#set($nonCoreServiceErrors = $serviceModel.getNonCoreServiceErrors())
\#include <aws/core/client/AWSError.h>
\#include <aws/core/utils/HashingUtils.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/json/ModelClassMembersJsonizeSource.vm")
return payload.View().WriteReadable();
## for json protocol
#elseif($metadata.protocol.equals("json"))
#elseif($metadata.findFirstSupportedProtocol().equals("json"))
return "{}";
## for rest-json protocol
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#set($awsProjectProtocolTestSrc = "AWS_" + ${projectNameCaps.replace("-", "_")} +"_SRC")
#set($awsProjectProtocolTestSrcVar = "${" + $awsProjectProtocolTestSrc +"}")
add_project($projectName
"Tests for the protocol $serviceModel.metadata.protocol of AWS C++ SDK"
"Tests for the protocol $serviceModel.metadata.findFirstSupportedProtocol() of AWS C++ SDK"
testing-resources
aws-cpp-sdk-${serviceModel.metadata.projectName}
aws-cpp-sdk-core)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#set($spaces = " ")
#end
#if($member.value.shape.list)
#if($metadata.protocol != "ec2")
#if($metadata.findFirstSupportedProtocol() != "ec2")
#set($spaces = " ")
if (${memberVarName}.empty())
{
Expand All @@ -64,7 +64,7 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#set($location = $member.value.shape.listMember.queryName)
#elseif($member.value.shape.listMember.locationName)
#set($location = $member.value.shape.listMember.locationName)
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $CppViewHelper.capitalizeFirstChar($location))
#end
#else
Expand All @@ -75,10 +75,10 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#set($location = $member.value.queryName)
#elseif($member.value.locationName)
#set($location = $member.value.locationName)
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $CppViewHelper.capitalizeFirstChar($location))
#end
#elseif($metadata.protocol == "ec2")
#elseif($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $member.key)
#elseif($member.value.shape.listMember.locationName)
#set($location = $member.key + "." + $member.value.shape.listMember.locationName)
Expand Down Expand Up @@ -106,7 +106,7 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#end
${spaces} ${varName}Count++;
${spaces}}
#if($metadata.protocol != "ec2")
#if($metadata.findFirstSupportedProtocol() != "ec2")
${spaces}}
#end
#elseif($member.value.shape.map)##--#if($member.value.shape.list)
Expand Down Expand Up @@ -171,7 +171,7 @@ ${spaces}}
#set($location = $member.value.queryName)
#elseif($member.value.locationName)
#set($location = $member.value.locationName)
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $CppViewHelper.capitalizeFirstChar($location))
#end
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ ${typeInfo.className}& ${typeInfo.className}::operator =(const Aws::AmazonWebSer
}

if (!rootNode.IsNull()) {
#if ($metadata.protocol == "ec2" )
#if ($metadata.findFirstSupportedProtocol() == "ec2" )
XmlNode requestIdNode = rootNode.FirstChild("requestId");
if (!requestIdNode.IsNull())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void ${typeInfo.className}::OutputToStream(Aws::OStream& oStream, const char* lo
${spaces}unsigned ${lowerCaseVarName}Idx = 1;
${spaces}for(auto& item : ${memberVarName})
${spaces}{
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#if($member.queryName)
#set($location = $member.queryName)
#elseif($member.locationName)
Expand Down Expand Up @@ -246,10 +246,10 @@ void ${typeInfo.className}::OutputToStream(Aws::OStream& oStream, const char* lo
#set($location = $member.shape.listMember.queryName)
#elseif($member.shape.listMember.locationName)
#set($location = $member.shape.listMember.locationName)
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $CppViewHelper.capitalizeFirstChar($location))
#end
#elseif($metadata.protocol == "ec2")
#elseif($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $memberName)
#else
#set($location = $memberName + ".member")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace ${rootNamespace}
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientHeaderConfigTypeDeclarations.vm")
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientHeaderConstructors.vm")

#if($metadata.protocol == "query")
#if($metadata.findFirstSupportedProtocol() == "query")

/**
* Converts any request object to a presigned URL with the GET method, using region for the signer and a timeout of 15 minutes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#set($rootNamespace = $serviceModel.namespace)
#set($serviceNamespace = $metadata.namespace)
#set($className = "${metadata.classNamePrefix}Client")
#if($serviceModel.metadata.protocol == "json" || $serviceModel.metadata.protocol == "rest-json" || $serviceModel.metadata.protocol == "application-json")
#if($serviceModel.metadata.findFirstSupportedProtocol() == "json" || $serviceModel.metadata.findFirstSupportedProtocol() == "rest-json" || $serviceModel.metadata.findFirstSupportedProtocol() == "application-json")
#set($serializer = "JsonOutcomeSerializer")
#set($serializerOutcome = "JsonOutcome")
#elseif($serviceModel.metadata.protocol == "rest-xml" || $serviceModel.metadata.protocol == "query")
#elseif($serviceModel.metadata.findFirstSupportedProtocol() == "rest-xml" || $serviceModel.metadata.findFirstSupportedProtocol() == "query")
#set($serializer = "XmlOutcomeSerializer")
#set($serializerOutcome = "XmlOutcome")
#end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#set($rootNamespace = $serviceModel.namespace)
#set($serviceNamespace = $metadata.namespace)
#set($className = "${metadata.classNamePrefix}Client")
#if($serviceModel.metadata.protocol == "json" || $serviceModel.metadata.protocol == "rest-json" || $serviceModel.metadata.protocol == "application-json")
#if($serviceModel.metadata.findFirstSupportedProtocol() == "json" || $serviceModel.metadata.findFirstSupportedProtocol() == "rest-json" || $serviceModel.metadata.findFirstSupportedProtocol() == "application-json")
#set($serializer = "JsonOutcomeSerializer")
#set($serializerOutcome = "JsonOutcome")
#elseif($serviceModel.metadata.protocol == "rest-xml" || $serviceModel.metadata.protocol == "query")
#elseif($serviceModel.metadata.findFirstSupportedProtocol() == "rest-xml" || $serviceModel.metadata.findFirstSupportedProtocol() == "query")
#set($serializer = "XmlOutcomeSerializer")
#set($serializerOutcome = "XmlOutcome")
#end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ const char* ${className}::GetAllocationTag() {return ALLOCATION_TAG;}
#if(!${onlyGeneratedOperations})
#parseOverrideOrDefault( "ServiceClientSourceInit_template" "com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyClientSourceInit.vm")

#if($metadata.protocol == "query")
#if($metadata.findFirstSupportedProtocol() == "query")
Aws::String ${className}::ConvertRequestToPresignedUrl(const AmazonSerializableWebServiceRequest& requestToConvert, const char* region) const
{
if (!m_endpointProvider)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace ${serviceNamespace}
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientHeaderConfigTypeDeclarations.vm")
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientHeaderConstructors.vm")

#if($metadata.protocol == "query")
#if($metadata.findFirstSupportedProtocol() == "query")

/**
* Converts any request object to a presigned URL with the GET method, using region for the signer and a timeout of 15 minutes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ const char* ${className}::GetAllocationTag() {return ALLOCATION_TAG;}
#if(!${onlyGeneratedOperations})
#parseOverrideOrDefault( "ServiceClientSourceInit_template" "com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientSourceInit.vm")

#if($metadata.protocol == "query")
#if($metadata.findFirstSupportedProtocol() == "query")
Aws::String ${className}::ConvertRequestToPresignedUrl(const AmazonSerializableWebServiceRequest& requestToConvert, const char* region) const
{
if (!m_endpointProvider)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.amazonaws.util.awsclientgenerator.domainmodels.codegeneration;

import com.google.common.collect.ImmutableList;
import org.junit.Assert;
import org.junit.Test;

public class MetadataTest {
@Test
public void shouldPreferProtocolOverProtocolsForApiGateway() {
Metadata metadata = new Metadata();
metadata.setProtocol("api-gateway");
metadata.setProtocols(ImmutableList.of("json"));
Assert.assertEquals("api-gateway", metadata.findFirstSupportedProtocol());
}

@Test
public void shouldPreferProtocolsOverProtocol() {
Metadata metadata = new Metadata();
metadata.setProtocol("rest-json");
metadata.setProtocols(ImmutableList.of("json"));
Assert.assertEquals("json", metadata.findFirstSupportedProtocol());
}

@Test
public void shouldPreferBestFitProtocol() {
Metadata metadata = new Metadata();
metadata.setProtocol("rest-json");
metadata.setProtocols(ImmutableList.of("rest-json", "json"));
Assert.assertEquals("json", metadata.findFirstSupportedProtocol());
}

@Test
public void shouldUseProtocolWhenProtocolsIsMissing() {
Metadata metadata = new Metadata();
metadata.setProtocol("rest-json");
metadata.setProtocols(ImmutableList.of());
Assert.assertEquals("rest-json", metadata.findFirstSupportedProtocol());
}

@Test
public void shouldThrowExeceptionWhenUnsupportedProtocol() {
Metadata metadata = new Metadata();
metadata.setServiceFullName("ServiceName");
metadata.setProtocols(ImmutableList.of("grpc"));
Assert.assertThrows("No supported protocol found for ServiceName", RuntimeException.class, metadata::findFirstSupportedProtocol);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public void testMetadataConversion() {
assertEquals(c2jMetadata.getApiVersion(), metadata.getApiVersion());
assertEquals(c2jMetadata.getEndpointPrefix(), metadata.getEndpointPrefix());
assertEquals(c2jMetadata.getJsonVersion(), metadata.getJsonVersion());
assertEquals(c2jMetadata.getProtocol(), metadata.getProtocol());
assertEquals(c2jMetadata.getProtocol(), metadata.findFirstSupportedProtocol());
assertEquals("ServiceAbbr", metadata.getNamespace());
assertEquals(c2jMetadata.getServiceFullName(), metadata.getServiceFullName());
assertEquals(c2jMetadata.getSignatureVersion(), metadata.getSignatureVersion());
Expand All @@ -57,7 +57,7 @@ public void testMetadataConversion() {
assertEquals(c2jMetadata.getApiVersion(), metadata.getApiVersion());
assertEquals(c2jMetadata.getEndpointPrefix(), metadata.getEndpointPrefix());
assertEquals(c2jMetadata.getJsonVersion(), metadata.getJsonVersion());
assertEquals(c2jMetadata.getProtocol(), metadata.getProtocol());
assertEquals(c2jMetadata.getProtocol(), metadata.findFirstSupportedProtocol());
assertEquals("ServiceAbbr", metadata.getNamespace());
assertEquals(c2jMetadata.getServiceFullName(), metadata.getServiceFullName());
assertEquals(c2jMetadata.getSignatureVersion(), metadata.getSignatureVersion());
Expand All @@ -84,7 +84,7 @@ public void testMetadataConversationWithStandalonePackages() {
assertTrue(metadata.isStandalone());
assertTrue(metadata.isApigateway());
assertEquals("service-abbr.execute-api", metadata.getEndpointPrefix());
assertEquals("application-json", metadata.getProtocol());
assertEquals("application-json", metadata.findFirstSupportedProtocol());
}

@Test
Expand Down

0 comments on commit c9c6c26

Please sign in to comment.