Skip to content

Commit

Permalink
Request encoding user info (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler authored Aug 7, 2023
1 parent 80ae30b commit b56a64f
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 129 deletions.
151 changes: 61 additions & 90 deletions Sources/SotoCodeGeneratorLib/AwsService+shapes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,26 @@ extension AwsService {
}
}
}
if serviceProtocolTrait is AwsProtocolsRestXmlTrait {
xmlNamespace = shape.trait(type: XmlNamespaceTrait.self)?.uri
}
var xmlRootNodeName: String?
// check streaming traits
if let payloadMember = payloadMember, let payload = model.shape(for: payloadMember.value.target) {
if serviceProtocolTrait is AwsProtocolsRestXmlTrait {
if let namespace = payload.trait(type: XmlNamespaceTrait.self)?.uri {
xmlNamespace = namespace
}
}
if isOutput {
if payload is BlobShape || payload.hasTrait(type: StreamingTrait.self) {
shapeOptions.append("rawPayload")
}
} else if isInput {
// set XML root node name.
if serviceProtocolTrait is AwsProtocolsRestXmlTrait {
xmlRootNodeName = payloadMember.key
}
// currently only support request streaming of blobs
if payload is BlobShape,
payload.hasTrait(type: StreamingTrait.self)
Expand All @@ -176,9 +189,6 @@ extension AwsService {
}
}
}
if serviceProtocolTrait is AwsProtocolsRestXmlTrait {
xmlNamespace = shape.trait(type: XmlNamespaceTrait.self)?.uri
}
let recursive = doesShapeHaveRecursiveOwnReference(shape, shapeId: shapeId)
let initParameters = contexts.members.compactMap {
!$0.deprecated ? InitParamContext(parameter: $0.parameter, type: $0.type, default: $0.default) : nil
Expand All @@ -189,28 +199,38 @@ extension AwsService {
} else {
object = recursive ? "final class" : "struct"
}
var decodeContext: DecodeContext?
if isOutput {
let isResponse = shape.hasTrait(type: SotoResponseShapeTrait.self)
let hasCustomDecode = contexts.members.first { $0.decoding.fromCodable == nil } != nil
let hasNonDecodableElements = contexts.members.first {
$0.decoding.fromHeader != nil || $0.decoding.fromStatusCode != nil || $0.decoding.fromRawPayload == true || $0.decoding.fromEventStream == true
} != nil
decodeContext = .init(
requiresResponse: hasNonDecodableElements && isResponse,
requiresEvent: hasNonDecodableElements && !isResponse,
requiresDecodeInit: hasCustomDecode

var codingContext: ShapeCodingContext?
let isRootShape = shape.hasTrait(type: SotoResponseShapeTrait.self) || shape.hasTrait(type: SotoRequestShapeTrait.self)
// Has elements that require some form of custom encoding/decoding
let hasCustomCodableElements = contexts.members.contains { $0.memberCoding.isCodable == false }
if hasCustomCodableElements || typeIsUnion {
// Has elements that require a custom container
let hasNonDecodableElements = contexts.members.contains {
$0.memberCoding.isCodable == false && $0.memberCoding.isPayload == false
}
let singleValueContainer = contexts.members.contains {
$0.memberCoding.isPayload == true
}
// when setting values here. I am assuming that non root shapes must be events and require
// an event container instead of a request/response container. Also I am not outputting
// an encode for events as I don't support encoding tham at the moment
codingContext = ShapeCodingContext(
requiresResponse: isRootShape && hasNonDecodableElements,
requiresEvent: !isRootShape && hasNonDecodableElements,
requiresDecodeInit: (hasCustomCodableElements || typeIsUnion) && isOutput,
requiresEncode: ((hasCustomCodableElements && isRootShape) || typeIsUnion) && isInput,
singleValueContainer: singleValueContainer
)
}
return StructureContext(
object: object,
name: shapeName.toSwiftClassCase(),
shapeProtocol: shapeProtocol,
payload: isInput ? payloadMember?.key.toSwiftLabelCase() : nil,
options: shapeOptions.count > 0 ? shapeOptions.map { ".\($0)" }.joined(separator: ", ") : nil,
namespace: xmlNamespace,
isEncodable: isInput,
decode: decodeContext,
xmlRootNodeName: xmlRootNodeName,
shapeCoding: codingContext,
encoding: contexts.encoding,
members: contexts.members,
initParameters: initParameters,
Expand Down Expand Up @@ -258,16 +278,6 @@ extension AwsService {
) {
contexts.codingKeys.append(codingKeyContext)
}
// member encoding context. We don't need this for response objects as a custom init(from:) as setup for these
if !shape.hasTrait(type: SotoResponseShapeTrait.self) {
let memberEncodingContext = self.generateMemberEncodingContext(
member.value,
name: member.key,
isOutputShape: isOutputShape,
isPropertyWrapper: memberContext.propertyWrapper != nil && isInputShape
)
contexts.awsShapeMembers += memberEncodingContext
}
// validation context
if isInputShape {
if let validationContext = generateValidationContext(member.value, name: member.key) {
Expand Down Expand Up @@ -344,28 +354,33 @@ extension AwsService {
let propertyWrapper = self.generatePropertyWrapper(member, name: name, optional: optional)
let type = member.output(model)

let memberDecodeContext: MemberDecodeContext
var memberCodableContext: MemberCodableContext
if let headerTrait = member.trait(type: HttpHeaderTrait.self) {
memberDecodeContext = .init(fromHeader: headerTrait.value, decodeType: type)
memberCodableContext = .init(inHeader: headerTrait.value, codableType: type)
} else if let headerTrait = member.trait(type: HttpPrefixHeadersTrait.self) {
memberDecodeContext = .init(fromHeader: headerTrait.value, decodeType: type)
} else if member.hasTrait(type: HttpResponseCodeTrait.self) {
memberDecodeContext = .init(fromStatusCode: true, decodeType: type)
} else if targetShape.hasTrait(type: StreamingTrait.self) {
if targetShape is BlobShape {
memberDecodeContext = .init(fromRawPayload: true, decodeType: type)
} else {
memberDecodeContext = .init(fromEventStream: true, decodeType: type)
}
} else if member.hasTrait(type: HttpPayloadTrait.self) || member.hasTrait(type: EventPayloadTrait.self) {
if targetShape is BlobShape {
memberDecodeContext = .init(fromRawPayload: true, decodeType: type)
} else {
memberDecodeContext = .init(fromPayload: true, decodeType: type)
}
memberCodableContext = .init(inHeader: headerTrait.value, codableType: type)
} else if member.hasTrait(type: HttpResponseCodeTrait.self), isOutputShape {
memberCodableContext = .init(isStatusCode: true, codableType: type)
} else if let queryTrait = member.trait(type: HttpQueryTrait.self), !isOutputShape {
memberCodableContext = .init(inQuery: queryTrait.value, codableType: type)
} else if member.hasTrait(type: HttpQueryParamsTrait.self), !isOutputShape {
memberCodableContext = .init(areQueryParams: true, codableType: type)
} else if member.hasTrait(type: HttpLabelTrait.self), !isOutputShape {
let aliasTrait = member.trait(named: serviceProtocolTrait.nameTrait.staticName) as? AliasTrait
memberCodableContext = .init(inURI: aliasTrait?.alias ?? name, codableType: type)
} else if member.hasTrait(type: HttpPayloadTrait.self) ||
member.hasTrait(type: EventPayloadTrait.self) ||
targetShape.hasTrait(type: StreamingTrait.self)
{
memberCodableContext = .init(isPayload: true, codableType: type)
} else {
// Codable needs to decode property wrapper if it exists
memberDecodeContext = .init(fromCodable: true, decodeType: propertyWrapper ?? type)
memberCodableContext = .init(isCodable: true, codableType: propertyWrapper ?? type)
}
if member.hasTrait(type: HostLabelTrait.self), !isOutputShape {
let aliasTrait = member.trait(named: serviceProtocolTrait.nameTrait.staticName) as? AliasTrait
memberCodableContext.inHostPrefix = aliasTrait?.alias ?? name
memberCodableContext.isCodable = false
}
return MemberContext(
variable: name.toSwiftVariableCase(),
Expand All @@ -377,54 +392,10 @@ extension AwsService {
comment: processMemberDocs(from: member),
deprecated: deprecated,
duplicate: false, // TODO: NEED to catch this
decoding: memberDecodeContext
memberCoding: memberCodableContext
)
}

func generateMemberEncodingContext(_ member: MemberShape, name: String, isOutputShape: Bool, isPropertyWrapper: Bool) -> [MemberEncodingContext] {
var memberEncoding: [MemberEncodingContext] = []
// if header
if let headerTrait = member.trait(type: HttpHeaderTrait.self) {
let name = isPropertyWrapper ? "_\(name.toSwiftLabelCase())" : name.toSwiftLabelCase()
memberEncoding.append(.init(name: name, location: ".header(\"\(headerTrait.value)\")"))
// if prefix header
} else if let headerPrefixTrait = member.trait(type: HttpPrefixHeadersTrait.self) {
let name = isPropertyWrapper ? "_\(name.toSwiftLabelCase())" : name.toSwiftLabelCase()
memberEncoding.append(.init(name: name, location: ".headerPrefix(\"\(headerPrefixTrait.value)\")"))
// if query string
} else if let queryTrait = member.trait(type: HttpQueryTrait.self) {
let name = isPropertyWrapper ? "_\(name.toSwiftLabelCase())" : name.toSwiftLabelCase()
memberEncoding.append(.init(name: name, location: ".querystring(\"\(queryTrait.value)\")"))
// if part of URL
} else if member.hasTrait(type: HttpLabelTrait.self) {
let labelName = isPropertyWrapper ? "_\(name.toSwiftLabelCase())" : name.toSwiftLabelCase()
let aliasTrait = member.trait(named: serviceProtocolTrait.nameTrait.staticName) as? AliasTrait
memberEncoding.append(.init(name: labelName, location: ".uri(\"\(aliasTrait?.alias ?? name)\")"))
// if response status code
} else if member.hasTrait(type: HttpResponseCodeTrait.self) {
let name = isPropertyWrapper ? "_\(name.toSwiftLabelCase())" : name.toSwiftLabelCase()
memberEncoding.append(.init(name: name, location: ".statusCode"))
// if payload and not a blob or shape is an output shape
} else if member.hasTrait(type: HttpPayloadTrait.self) || member.hasTrait(type: EventPayloadTrait.self),
!(model.shape(for: member.target) is BlobShape) || isOutputShape
{
let aliasTrait = member.traits?.first(where: { $0 is AliasTrait }) as? AliasTrait
let payloadName = aliasTrait?.alias ?? name
let swiftLabelName = name.toSwiftLabelCase()
if swiftLabelName != payloadName {
let name = isPropertyWrapper ? "_\(name.toSwiftLabelCase())" : name.toSwiftLabelCase()
memberEncoding.append(.init(name: name, location: ".body(\"\(payloadName)\")"))
}
}

if member.hasTrait(type: HostLabelTrait.self) {
let labelName = isPropertyWrapper ? "_\(name.toSwiftLabelCase())" : name.toSwiftLabelCase()
let aliasTrait = member.trait(named: serviceProtocolTrait.nameTrait.staticName) as? AliasTrait
memberEncoding.append(.init(name: labelName, location: ".hostname(\"\(aliasTrait?.alias ?? name)\")"))
}
return memberEncoding
}

func generateCodingKeyContext(
_ member: MemberShape,
targetShape: Shape,
Expand Down
55 changes: 39 additions & 16 deletions Sources/SotoCodeGeneratorLib/AwsService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,6 @@ struct AwsService {
if usedInOutput {
shapeProtocol += " & AWSDecodableShape"
}
if hasPayload {
shapeProtocol += " & AWSShapeWithPayload"
}
} else if usedInOutput {
shapeProtocol = "AWSDecodableShape"
} else {
Expand All @@ -618,6 +615,7 @@ struct AwsService {
return !(member.hasTrait(type: HttpHeaderTrait.self) ||
member.hasTrait(type: HttpPrefixHeadersTrait.self) ||
(member.hasTrait(type: HttpQueryTrait.self) && !isOutputShape) ||
member.hasTrait(type: HttpQueryParamsTrait.self) ||
member.hasTrait(type: HttpLabelTrait.self) ||
member.hasTrait(type: HttpResponseCodeTrait.self))
}
Expand Down Expand Up @@ -694,14 +692,38 @@ extension AwsService {
let value: String
}

struct MemberDecodeContext {
var fromHeader: String?
var fromPayload: Bool?
var fromRawPayload: Bool?
var fromEventStream: Bool?
var fromCodable: Bool?
var fromStatusCode: Bool?
var decodeType: String
struct MemberCodableContext {
internal init(
inHeader: String? = nil,
inQuery: String? = nil,
inURI: String? = nil,
inHostPrefix: String? = nil,
areQueryParams: Bool = false,
isPayload: Bool = false,
isCodable: Bool = false,
isStatusCode: Bool = false,
codableType: String
) {
self.inHeader = inHeader
self.inQuery = inQuery
self.inURI = inURI
self.inHostPrefix = inHostPrefix
self.areQueryParams = areQueryParams
self.isPayload = isPayload
self.isCodable = isCodable
self.isStatusCode = isStatusCode
self.codableType = codableType
}

var inHeader: String?
var inQuery: String?
var inURI: String?
var inHostPrefix: String?
var areQueryParams: Bool
var isPayload: Bool
var isCodable: Bool
var isStatusCode: Bool
var codableType: String
}

struct MemberContext {
Expand All @@ -714,7 +736,7 @@ extension AwsService {
let comment: [String.SubSequence]
let deprecated: Bool
var duplicate: Bool
var decoding: MemberDecodeContext
var memberCoding: MemberCodableContext
}

struct InitParamContext {
Expand Down Expand Up @@ -779,21 +801,22 @@ extension AwsService {
var endpoints: [(region: String, hostname: String)] = []
}

struct DecodeContext {
struct ShapeCodingContext {
let requiresResponse: Bool
let requiresEvent: Bool
let requiresDecodeInit: Bool
let requiresEncode: Bool
let singleValueContainer: Bool
}

struct StructureContext {
let object: String
let name: String
let shapeProtocol: String
let payload: String?
var options: String?
let namespace: String?
let isEncodable: Bool
let decode: DecodeContext?
let xmlRootNodeName: String?
let shapeCoding: ShapeCodingContext?
let encoding: [EncodingPropertiesContext]
let members: [MemberContext]
let initParameters: [InitParamContext]
Expand Down
2 changes: 1 addition & 1 deletion Sources/SotoCodeGeneratorLib/Smithy+CodeGeneration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ extension MemberShape {
return "String"
} else if memberShape is BlobShape {
if self.hasTrait(type: HttpPayloadTrait.self) { return "AWSHTTPBody" }
else if self.hasTrait(type: EventPayloadTrait.self) { return "ByteBuffer" }
else if self.hasTrait(type: EventPayloadTrait.self) { return "AWSEventPayload" }
return "AWSBase64Data"
} else if memberShape is CollectionShape {
if memberShape.hasTrait(type: StreamingTrait.self) {
Expand Down
8 changes: 4 additions & 4 deletions Sources/SotoCodeGeneratorLib/Templates/enumWithValues.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ extension Templates {
{{/comment}}
case {{variable}}({{type}})
{{/members}}
{{#decode}}
{{#shapeCoding.requiresDecodeInit}}
{{scope}} init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
Expand All @@ -55,8 +55,8 @@ extension Templates {
{{/members}}
}
}
{{/decode}}
{{#isEncodable}}
{{/shapeCoding.requiresDecodeInit}}
{{#shapeCoding.requiresEncode}}
{{scope}} func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
Expand All @@ -67,7 +67,7 @@ extension Templates {
{{/members}}
}
}
{{/isEncodable}}
{{/shapeCoding.requiresEncode}}
{{! validate() function }}
{{#first(validation)}}
Expand Down
Loading

0 comments on commit b56a64f

Please sign in to comment.