Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate code gen to use protocol list #3320

Merged
merged 1 commit into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class C2jMetadata {
private String signingName;
private String targetPrefix;
private String protocol;
private List<String> protocols;
private String clientProjectName;
private String clientClassNamePrefix;
private String uid;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,25 @@

package com.amazonaws.util.awsclientgenerator.domainmodels.codegeneration;

import com.google.common.collect.ImmutableList;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

@Data
public class Metadata {
private static List<String> supportedProtocols = ImmutableList.of(
"json",
"rest-json",
"rest-xml",
"query",
"ec2"
);

private String apiVersion;
private String concatAPIVersion;
private String endpointPrefix;
Expand All @@ -23,6 +35,7 @@ public class Metadata {
private String signingName;
private String targetPrefix;
private String protocol;
private List<String> protocols;
private String projectName;
private String classNamePrefix;
private String acceptHeader;
Expand All @@ -47,4 +60,19 @@ public class Metadata {

// Priority-ordered list of auth types present on the service model
private List<String> auth;

public String findFirstSupportedProtocol() {
// we use application-json for api-gateway
if ("application-json".equals(protocol)) {
return protocol;
}

if (Objects.isNull(protocols) || protocols.isEmpty()) {
return protocol;
}

return protocols.stream().filter(supportedProtocols::contains)
.min(Comparator.comparingInt(protocolName -> supportedProtocols.indexOf(protocolName)))
.orElseThrow(() -> new RuntimeException(String.format("No supported protocol found for %s", serviceFullName)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ public static String computeXmlConversionMethodName(Shape shape) {
}

public static String computeRequestContentType(Metadata metadata) {
String protocolAndVersion = metadata.getProtocol();
String protocolAndVersion = metadata.findFirstSupportedProtocol();

if(metadata.getJsonVersion() != null) {
protocolAndVersion += metadata.getJsonVersion();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public ByteArrayOutputStream generateSourceFromC2jModel(C2jServiceModel c2jModel

spec.setVersion(serviceModel.getMetadata().getApiVersion());

String protocol = serviceModel.getMetadata().getProtocol();
String protocol = serviceModel.getMetadata().findFirstSupportedProtocol();
ClientGenerator clientGenerator = ServiceGeneratorConfig.findGenerator(spec, protocol);

//use serviceName and version to convert the json over.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ public ServiceModel convert() {
serviceModel.getMetadata().setEndpointOperationName(endpointOperationName);

// add protocol check. only for json, query protocols
if (serviceModel.getMetadata().getProtocol().equals("json")) {
final String protocol = serviceModel.getMetadata().findFirstSupportedProtocol();

if ("json".equals(protocol)) {
serviceModel.getMetadata().setAwsQueryCompatible(
c2jServiceModel.getMetadata().getAwsQueryCompatible() != null);
} else {
Expand Down Expand Up @@ -258,6 +260,7 @@ Metadata convertMetadata() {
metadata.setEndpointPrefix(c2jMetadata.getEndpointPrefix());
metadata.setProtocol(c2jMetadata.getProtocol());
}
metadata.setProtocols(c2jMetadata.getProtocols());
metadata.setNamespace(c2jMetadata.getServiceAbbreviation());
metadata.setServiceFullName(c2jMetadata.getServiceFullName());
metadata.setSignatureVersion(c2jMetadata.getSignatureVersion());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#set($metadata = $serviceModel.metadata)
#set($rootNamespace = $serviceModel.namespace)
#set($serviceNamespace = $metadata.namespace)
#set($payloadType = ${CppViewHelper.computeServicePayloadType($metadata.protocol)})
#set($payloadType = ${CppViewHelper.computeServicePayloadType($metadata.findFirstSupportedProtocol())})
#set($nonCoreServiceErrors = $serviceModel.getNonCoreServiceErrors())
\#include <aws/core/client/AWSError.h>
\#include <aws/core/utils/HashingUtils.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/json/ModelClassMembersJsonizeSource.vm")
return payload.View().WriteReadable();
## for json protocol
#elseif($metadata.protocol.equals("json"))
#elseif($metadata.findFirstSupportedProtocol().equals("json"))
return "{}";
## for rest-json protocol
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#set($awsProjectProtocolTestSrc = "AWS_" + ${projectNameCaps.replace("-", "_")} +"_SRC")
#set($awsProjectProtocolTestSrcVar = "${" + $awsProjectProtocolTestSrc +"}")
add_project($projectName
"Tests for the protocol $serviceModel.metadata.protocol of AWS C++ SDK"
"Tests for the protocol $serviceModel.metadata.findFirstSupportedProtocol() of AWS C++ SDK"
testing-resources
aws-cpp-sdk-${serviceModel.metadata.projectName}
aws-cpp-sdk-core)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#set($spaces = " ")
#end
#if($member.value.shape.list)
#if($metadata.protocol != "ec2")
#if($metadata.findFirstSupportedProtocol() != "ec2")
#set($spaces = " ")
if (${memberVarName}.empty())
{
Expand All @@ -64,7 +64,7 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#set($location = $member.value.shape.listMember.queryName)
#elseif($member.value.shape.listMember.locationName)
#set($location = $member.value.shape.listMember.locationName)
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $CppViewHelper.capitalizeFirstChar($location))
#end
#else
Expand All @@ -75,10 +75,10 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#set($location = $member.value.queryName)
#elseif($member.value.locationName)
#set($location = $member.value.locationName)
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $CppViewHelper.capitalizeFirstChar($location))
#end
#elseif($metadata.protocol == "ec2")
#elseif($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $member.key)
#elseif($member.value.shape.listMember.locationName)
#set($location = $member.key + "." + $member.value.shape.listMember.locationName)
Expand Down Expand Up @@ -106,7 +106,7 @@ Aws::String ${typeInfo.className}::SerializePayload() const
#end
${spaces} ${varName}Count++;
${spaces}}
#if($metadata.protocol != "ec2")
#if($metadata.findFirstSupportedProtocol() != "ec2")
${spaces}}
#end
#elseif($member.value.shape.map)##--#if($member.value.shape.list)
Expand Down Expand Up @@ -171,7 +171,7 @@ ${spaces}}
#set($location = $member.value.queryName)
#elseif($member.value.locationName)
#set($location = $member.value.locationName)
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $CppViewHelper.capitalizeFirstChar($location))
#end
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ ${typeInfo.className}& ${typeInfo.className}::operator =(const Aws::AmazonWebSer
}

if (!rootNode.IsNull()) {
#if ($metadata.protocol == "ec2" )
#if ($metadata.findFirstSupportedProtocol() == "ec2" )
XmlNode requestIdNode = rootNode.FirstChild("requestId");
if (!requestIdNode.IsNull())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void ${typeInfo.className}::OutputToStream(Aws::OStream& oStream, const char* lo
${spaces}unsigned ${lowerCaseVarName}Idx = 1;
${spaces}for(auto& item : ${memberVarName})
${spaces}{
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#if($member.queryName)
#set($location = $member.queryName)
#elseif($member.locationName)
Expand Down Expand Up @@ -246,10 +246,10 @@ void ${typeInfo.className}::OutputToStream(Aws::OStream& oStream, const char* lo
#set($location = $member.shape.listMember.queryName)
#elseif($member.shape.listMember.locationName)
#set($location = $member.shape.listMember.locationName)
#if($metadata.protocol == "ec2")
#if($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $CppViewHelper.capitalizeFirstChar($location))
#end
#elseif($metadata.protocol == "ec2")
#elseif($metadata.findFirstSupportedProtocol() == "ec2")
#set($location = $memberName)
#else
#set($location = $memberName + ".member")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace ${rootNamespace}
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientHeaderConfigTypeDeclarations.vm")
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientHeaderConstructors.vm")

#if($metadata.protocol == "query")
#if($metadata.findFirstSupportedProtocol() == "query")

/**
* Converts any request object to a presigned URL with the GET method, using region for the signer and a timeout of 15 minutes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#set($rootNamespace = $serviceModel.namespace)
#set($serviceNamespace = $metadata.namespace)
#set($className = "${metadata.classNamePrefix}Client")
#if($serviceModel.metadata.protocol == "json" || $serviceModel.metadata.protocol == "rest-json" || $serviceModel.metadata.protocol == "application-json")
#if($serviceModel.metadata.findFirstSupportedProtocol() == "json" || $serviceModel.metadata.findFirstSupportedProtocol() == "rest-json" || $serviceModel.metadata.findFirstSupportedProtocol() == "application-json")
#set($serializer = "JsonOutcomeSerializer")
#set($serializerOutcome = "JsonOutcome")
#elseif($serviceModel.metadata.protocol == "rest-xml" || $serviceModel.metadata.protocol == "query")
#elseif($serviceModel.metadata.findFirstSupportedProtocol() == "rest-xml" || $serviceModel.metadata.findFirstSupportedProtocol() == "query")
#set($serializer = "XmlOutcomeSerializer")
#set($serializerOutcome = "XmlOutcome")
#end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#set($rootNamespace = $serviceModel.namespace)
#set($serviceNamespace = $metadata.namespace)
#set($className = "${metadata.classNamePrefix}Client")
#if($serviceModel.metadata.protocol == "json" || $serviceModel.metadata.protocol == "rest-json" || $serviceModel.metadata.protocol == "application-json")
#if($serviceModel.metadata.findFirstSupportedProtocol() == "json" || $serviceModel.metadata.findFirstSupportedProtocol() == "rest-json" || $serviceModel.metadata.findFirstSupportedProtocol() == "application-json")
#set($serializer = "JsonOutcomeSerializer")
#set($serializerOutcome = "JsonOutcome")
#elseif($serviceModel.metadata.protocol == "rest-xml" || $serviceModel.metadata.protocol == "query")
#elseif($serviceModel.metadata.findFirstSupportedProtocol() == "rest-xml" || $serviceModel.metadata.findFirstSupportedProtocol() == "query")
#set($serializer = "XmlOutcomeSerializer")
#set($serializerOutcome = "XmlOutcome")
#end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ const char* ${className}::GetAllocationTag() {return ALLOCATION_TAG;}
#if(!${onlyGeneratedOperations})
#parseOverrideOrDefault( "ServiceClientSourceInit_template" "com/amazonaws/util/awsclientgenerator/velocity/cpp/smithy/SmithyClientSourceInit.vm")

#if($metadata.protocol == "query")
#if($metadata.findFirstSupportedProtocol() == "query")
Aws::String ${className}::ConvertRequestToPresignedUrl(const AmazonSerializableWebServiceRequest& requestToConvert, const char* region) const
{
if (!m_endpointProvider)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace ${serviceNamespace}
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientHeaderConfigTypeDeclarations.vm")
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientHeaderConstructors.vm")

#if($metadata.protocol == "query")
#if($metadata.findFirstSupportedProtocol() == "query")

/**
* Converts any request object to a presigned URL with the GET method, using region for the signer and a timeout of 15 minutes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ const char* ${className}::GetAllocationTag() {return ALLOCATION_TAG;}
#if(!${onlyGeneratedOperations})
#parseOverrideOrDefault( "ServiceClientSourceInit_template" "com/amazonaws/util/awsclientgenerator/velocity/cpp/ServiceClientSourceInit.vm")

#if($metadata.protocol == "query")
#if($metadata.findFirstSupportedProtocol() == "query")
Aws::String ${className}::ConvertRequestToPresignedUrl(const AmazonSerializableWebServiceRequest& requestToConvert, const char* region) const
{
if (!m_endpointProvider)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.amazonaws.util.awsclientgenerator.domainmodels.codegeneration;

import com.google.common.collect.ImmutableList;
import org.junit.Assert;
import org.junit.Test;

public class MetadataTest {
@Test
public void shouldPreferProtocolOverProtocolsForApiGateway() {
Metadata metadata = new Metadata();
metadata.setProtocol("application-json");
metadata.setProtocols(ImmutableList.of("json"));
Assert.assertEquals("application-json", metadata.findFirstSupportedProtocol());
}

@Test
public void shouldPreferProtocolsOverProtocol() {
Metadata metadata = new Metadata();
metadata.setProtocol("rest-json");
metadata.setProtocols(ImmutableList.of("json"));
Assert.assertEquals("json", metadata.findFirstSupportedProtocol());
}

@Test
public void shouldPreferBestFitProtocol() {
Metadata metadata = new Metadata();
metadata.setProtocol("rest-json");
metadata.setProtocols(ImmutableList.of("rest-json", "json"));
Assert.assertEquals("json", metadata.findFirstSupportedProtocol());
}

@Test
public void shouldUseProtocolWhenProtocolsIsMissing() {
Metadata metadata = new Metadata();
metadata.setProtocol("rest-json");
metadata.setProtocols(ImmutableList.of());
Assert.assertEquals("rest-json", metadata.findFirstSupportedProtocol());
}

@Test
public void shouldThrowExeceptionWhenUnsupportedProtocol() {
Metadata metadata = new Metadata();
metadata.setServiceFullName("ServiceName");
metadata.setProtocols(ImmutableList.of("grpc"));
Assert.assertThrows("No supported protocol found for ServiceName", RuntimeException.class, metadata::findFirstSupportedProtocol);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public void testMetadataConversion() {
assertEquals(c2jMetadata.getApiVersion(), metadata.getApiVersion());
assertEquals(c2jMetadata.getEndpointPrefix(), metadata.getEndpointPrefix());
assertEquals(c2jMetadata.getJsonVersion(), metadata.getJsonVersion());
assertEquals(c2jMetadata.getProtocol(), metadata.getProtocol());
assertEquals(c2jMetadata.getProtocol(), metadata.findFirstSupportedProtocol());
assertEquals("ServiceAbbr", metadata.getNamespace());
assertEquals(c2jMetadata.getServiceFullName(), metadata.getServiceFullName());
assertEquals(c2jMetadata.getSignatureVersion(), metadata.getSignatureVersion());
Expand All @@ -57,7 +57,7 @@ public void testMetadataConversion() {
assertEquals(c2jMetadata.getApiVersion(), metadata.getApiVersion());
assertEquals(c2jMetadata.getEndpointPrefix(), metadata.getEndpointPrefix());
assertEquals(c2jMetadata.getJsonVersion(), metadata.getJsonVersion());
assertEquals(c2jMetadata.getProtocol(), metadata.getProtocol());
assertEquals(c2jMetadata.getProtocol(), metadata.findFirstSupportedProtocol());
assertEquals("ServiceAbbr", metadata.getNamespace());
assertEquals(c2jMetadata.getServiceFullName(), metadata.getServiceFullName());
assertEquals(c2jMetadata.getSignatureVersion(), metadata.getSignatureVersion());
Expand All @@ -84,7 +84,7 @@ public void testMetadataConversationWithStandalonePackages() {
assertTrue(metadata.isStandalone());
assertTrue(metadata.isApigateway());
assertEquals("service-abbr.execute-api", metadata.getEndpointPrefix());
assertEquals("application-json", metadata.getProtocol());
assertEquals("application-json", metadata.findFirstSupportedProtocol());
}

@Test
Expand Down