Skip to content

Commit

Permalink
Pass UnionShape for union type discrimination (#3984)
Browse files Browse the repository at this point in the history
UnionShape needs to be passed so that the customization code can detect
which Union type is being generated.

Co-authored-by: Fahad Zubair <fahadzub@amazon.com>
  • Loading branch information
drganjoo and Fahad Zubair authored Jan 29, 2025
1 parent 3d801c4 commit 0a63d5b
Showing 1 changed file with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ abstract class CborParserCustomization : NamedCustomization<CborParserSection>()
* @param defaultContext The default discrimination context containing decoder symbol and discriminator method.
* @return UnionVariantDiscriminatorContext that defines how to discriminate union variants.
*/
open fun getUnionVariantDiscriminator(defaultContext: CborParserGenerator.UnionVariantDiscriminatorContext) =
defaultContext
open fun getUnionVariantDiscriminator(
unionShape: UnionShape,
defaultContext: CborParserGenerator.UnionVariantDiscriminatorContext,
) = defaultContext
}

class CborParserGenerator(
Expand Down Expand Up @@ -330,7 +332,7 @@ class CborParserGenerator(
val returnSymbolToParse = returnSymbolToParse(shape)
// Get actual decoder type to use and the discriminating function to call to extract
// the variant of the union that has been encoded in the data.
val discriminatorContext = getUnionDiscriminatorContext("Decoder", "decoder.str()?.as_ref()")
val discriminatorContext = getUnionDiscriminatorContext(shape, "Decoder", "decoder.str()?.as_ref()")

rustBlockTemplate(
"""
Expand Down Expand Up @@ -394,6 +396,7 @@ class CborParserGenerator(
}

private fun getUnionDiscriminatorContext(
unionShape: UnionShape,
decoderType: String,
callMethod: String,
): UnionVariantDiscriminatorContext {
Expand All @@ -403,7 +406,7 @@ class CborParserGenerator(
writable { rustTemplate(callMethod) },
)
return customizations.fold(defaultUnionPairContext) { context, customization ->
customization.getUnionVariantDiscriminator(context)
customization.getUnionVariantDiscriminator(unionShape, context)
}
}

Expand Down

0 comments on commit 0a63d5b

Please sign in to comment.