From bd0ec14fc20115e8f1aa715cda565764f067f4fa Mon Sep 17 00:00:00 2001 From: Oleg Morozenkov Date: Mon, 23 Sep 2024 15:05:28 +0300 Subject: [PATCH 1/2] Fix remarshal of downstream tarantool --- result.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/result.go b/result.go index a696cb8..e91f12c 100644 --- a/result.go +++ b/result.go @@ -9,6 +9,9 @@ type Result struct { ErrorCode uint Error error Data [][]interface{} + DataBytes []byte + + marshaller msgp.Marshaler } func (r *Result) GetCommandID() uint { @@ -20,6 +23,22 @@ func (r *Result) GetCommandID() uint { // MarshalMsg implements msgp.Marshaler func (r *Result) MarshalMsg(b []byte) (o []byte, err error) { + if r.marshaller == nil { + r.marshaller = defaultResultMarshaller{Result: r} + } + return r.marshaller.MarshalMsg(b) +} + +func (r *Result) WithBytesMarshaller() *Result { + r.marshaller = bytesResultMarshaller{Result: r} + return r +} + +type defaultResultMarshaller struct { + *Result +} + +func (r defaultResultMarshaller) MarshalMsg(b []byte) (o []byte, err error) { o = b if r.Error != nil { o = msgp.AppendMapHeader(o, 1) @@ -40,6 +59,29 @@ func (r *Result) MarshalMsg(b []byte) (o []byte, err error) { return o, nil } +type bytesResultMarshaller struct { + *Result +} + +func (r bytesResultMarshaller) MarshalMsg(b []byte) (o []byte, err error) { + o = b + if r.Error != nil { + o = msgp.AppendMapHeader(o, 1) + o = msgp.AppendUint(o, KeyError) + o = msgp.AppendString(o, r.Error.Error()) + } else { + o = msgp.AppendMapHeader(o, 1) + o = msgp.AppendUint(o, KeyData) + if len(r.DataBytes) != 0 { + o = append(o, r.DataBytes...) + } else { + o = msgp.AppendArrayHeader(o, 0) + } + } + + return o, nil +} + // UnmarshalMsg implements msgp.Unmarshaler func (r *Result) UnmarshalMsg(data []byte) (buf []byte, err error) { var l uint32 @@ -70,6 +112,8 @@ func (r *Result) UnmarshalMsg(data []byte) (buf []byte, err error) { case KeyData: var i, j uint32 + bufData := buf + if dl, buf, err = msgp.ReadArrayHeaderBytes(buf); err != nil { return } @@ -96,18 +140,26 @@ func (r *Result) UnmarshalMsg(data []byte) (buf []byte, err error) { } } } + + bufRead := len(bufData) - len(buf) + bufData = bufData[:bufRead] + r.DataBytes = make([]byte, len(bufData)) + copy(r.DataBytes, bufData) + case KeyError: errorMessage, buf, err = msgp.ReadStringBytes(buf) if err != nil { return } r.Error = NewQueryError(r.ErrorCode, errorMessage) + default: if buf, err = msgp.Skip(buf); err != nil { return } } } + return } From 6a8a6b4b644854ce7eb45ffeeff3442e819d871c Mon Sep 17 00:00:00 2001 From: speshal71 Date: Thu, 14 Nov 2024 02:37:27 +0300 Subject: [PATCH 2/2] Fixed nil query for server for unknown request + improved raw marshaling of result --- binpacket.go | 1 + packet.go | 17 +++++++++++++++-- query.go | 2 +- result.go | 47 +++++++++++++++++++++++------------------------ result_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ server.go | 8 ++++++++ unknown_query.go | 36 ++++++++++++++++++++++++++++++++++++ 7 files changed, 131 insertions(+), 27 deletions(-) create mode 100644 result_test.go create mode 100644 unknown_query.go diff --git a/binpacket.go b/binpacket.go index 272ef69..fcbebfc 100644 --- a/binpacket.go +++ b/binpacket.go @@ -58,6 +58,7 @@ func (pp *BinaryPacket) Reset() { pp.packet.SchemaID = 0 pp.packet.requestID = 0 pp.packet.Result = nil + pp.packet.opts = PacketOpts{} pp.body = pp.body[:0] } diff --git a/packet.go b/packet.go index 7302ddb..f8ac80b 100644 --- a/packet.go +++ b/packet.go @@ -7,6 +7,10 @@ import ( "github.com/tinylib/msgp/msgp" ) +type PacketOpts struct { + asQuery bool +} + type Packet struct { Cmd uint LSN uint64 @@ -16,6 +20,15 @@ type Packet struct { Timestamp time.Time Request Query Result *Result + + opts PacketOpts +} + +// AsQuery forces packet to be unmarshaled as query even if it's not supported. +func (pack *Packet) AsQuery() *Packet { + pack.opts.asQuery = true + + return pack } func (pack *Packet) String() string { @@ -114,7 +127,7 @@ func (pack *Packet) UnmarshalBinaryBody(data []byte) (buf []byte, err error) { return unpackr(pack.Cmd^ErrorFlag, data) } - if q := NewQuery(pack.Cmd); q != nil { + if q := NewQuery(pack.Cmd); IsKnownQuery(q) || pack.opts.asQuery { return unpackq(q, data) } return unpackr(OKCommand, data) @@ -128,7 +141,7 @@ func (pack *Packet) UnmarshalBinary(data []byte) error { // UnmarshalMsg implements msgp.Unmarshaler func (pack *Packet) UnmarshalMsg(data []byte) (buf []byte, err error) { - *pack = Packet{} + *pack = Packet{opts: pack.opts} buf = data diff --git a/query.go b/query.go index ead1759..3beb6ba 100644 --- a/query.go +++ b/query.go @@ -33,6 +33,6 @@ func NewQuery(cmd uint) Query { case EvalCommand: return &Eval{} default: - return nil + return NewUnknownQuery(cmd) } } diff --git a/result.go b/result.go index e91f12c..efd9566 100644 --- a/result.go +++ b/result.go @@ -2,14 +2,16 @@ package tarantool import ( "fmt" + "github.com/tinylib/msgp/msgp" ) type Result struct { + RawBytes []byte + ErrorCode uint Error error Data [][]interface{} - DataBytes []byte marshaller msgp.Marshaler } @@ -29,6 +31,16 @@ func (r *Result) MarshalMsg(b []byte) (o []byte, err error) { return r.marshaller.MarshalMsg(b) } +// WithBytesMarshaller changes the marshaller for result serialization. +// +// Current implementation of unmarshaller may change structure of the result (e.g. in call17) +// if it's not array of tuples in which case it's forcefully wrapped. It also skips +// unknown keys. Therefore serialized sequence of bytes produced by the default marshaller +// is different from the incoming. +// +// Bytes marshaller on the other hand returns exactly the same array +// the result was successfully unmarshalled from (preserving all the keys of the body +// including unknown ones). But it won't reflect any manual changes of unmarshalled data. func (r *Result) WithBytesMarshaller() *Result { r.marshaller = bytesResultMarshaller{Result: r} return r @@ -64,22 +76,7 @@ type bytesResultMarshaller struct { } func (r bytesResultMarshaller) MarshalMsg(b []byte) (o []byte, err error) { - o = b - if r.Error != nil { - o = msgp.AppendMapHeader(o, 1) - o = msgp.AppendUint(o, KeyError) - o = msgp.AppendString(o, r.Error.Error()) - } else { - o = msgp.AppendMapHeader(o, 1) - o = msgp.AppendUint(o, KeyData) - if len(r.DataBytes) != 0 { - o = append(o, r.DataBytes...) - } else { - o = msgp.AppendArrayHeader(o, 0) - } - } - - return o, nil + return append(b, r.RawBytes...), nil } // UnmarshalMsg implements msgp.Unmarshaler @@ -95,6 +92,15 @@ func (r *Result) UnmarshalMsg(data []byte) (buf []byte, err error) { if len(buf) == 0 && r.ErrorCode == OKCommand { return buf, nil } + + defer func() { + if err == nil { + rawPacketLength := len(data) - len(buf) + r.RawBytes = make([]byte, rawPacketLength) + copy(r.RawBytes, data[:rawPacketLength]) + } + }() + l, buf, err = msgp.ReadMapHeaderBytes(buf) if err != nil { @@ -112,8 +118,6 @@ func (r *Result) UnmarshalMsg(data []byte) (buf []byte, err error) { case KeyData: var i, j uint32 - bufData := buf - if dl, buf, err = msgp.ReadArrayHeaderBytes(buf); err != nil { return } @@ -141,11 +145,6 @@ func (r *Result) UnmarshalMsg(data []byte) (buf []byte, err error) { } } - bufRead := len(bufData) - len(buf) - bufData = bufData[:bufRead] - r.DataBytes = make([]byte, len(bufData)) - copy(r.DataBytes, bufData) - case KeyError: errorMessage, buf, err = msgp.ReadStringBytes(buf) if err != nil { diff --git a/result_test.go b/result_test.go new file mode 100644 index 0000000..3fc839d --- /dev/null +++ b/result_test.go @@ -0,0 +1,47 @@ +package tarantool + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResultMarshaling(t *testing.T) { + var result Result + + // The result of a call17 to: + // function a() + // return "a" + // end + tntBodyBytes := []byte{ + 0x81, // MP_MAP + 0x30, // key IPROTO_DATA + 0xdd, 0x0, 0x0, 0x0, 0x1, // MP_ARRAY + 0xa1, 0x61, // string value "a" + } + + expectedDefaultMarshalerBytes := []byte{ + 0x81, // MP_MAP + 0x30, // key IPROTO_DATA + 0x91, // MP_ARRAY + 0x91, // MP_ARRAY + 0xa1, 0x61, // string value "a" + } + + buf, err := result.UnmarshalMsg(tntBodyBytes) + assert.NoError(t, err, "error unmarshaling result") + assert.Empty(t, buf, "unmarshaling residual buffer is not empty") + + defaultMarshalerRes, err := result.MarshalMsg(nil) + assert.NoError(t, err, "error marshaling by default marshaller") + assert.Equal( + t, + expectedDefaultMarshalerBytes, + defaultMarshalerRes, + ) + + result.WithBytesMarshaller() + bytesMarshalerRes, err := result.MarshalMsg(nil) + assert.NoError(t, err, "error marshaling by bytes marshaller") + assert.Equal(t, tntBodyBytes, bytesMarshalerRes) +} diff --git a/server.go b/server.go index eb58ec3..970b0e2 100644 --- a/server.go +++ b/server.go @@ -34,11 +34,15 @@ type IprotoServer struct { schemaID uint64 wg sync.WaitGroup getPingStatus func(*IprotoServer) uint + + // asQueryServer forces incoming requests to be parsed as queries + asQueryServer bool } type IprotoServerOptions struct { Perf PerfCount GetPingStatus func(*IprotoServer) uint + AsQueryServer bool } func NewIprotoServer(uuid string, handler QueryHandler, onShutdown OnShutdownCallback) *IprotoServer { @@ -58,6 +62,7 @@ func (s *IprotoServer) WithOptions(opts *IprotoServerOptions) *IprotoServer { opts = &IprotoServerOptions{} } s.perf = opts.Perf + s.asQueryServer = opts.AsQueryServer s.getPingStatus = opts.GetPingStatus if s.getPingStatus == nil { s.getPingStatus = func(*IprotoServer) uint { return 0 } @@ -216,6 +221,9 @@ READER_LOOP: wg.Add(1) go func(pp *BinaryPacket) { packet := &pp.packet + if s.asQueryServer { + packet.AsQuery() + } defer wg.Done() err := packet.UnmarshalBinary(pp.body) diff --git a/unknown_query.go b/unknown_query.go new file mode 100644 index 0000000..7ce748e --- /dev/null +++ b/unknown_query.go @@ -0,0 +1,36 @@ +package tarantool + +type UnknownQuery struct { + cmd uint + data []byte +} + +var _ Query = (*UnknownQuery)(nil) + +func NewUnknownQuery(cmd uint) *UnknownQuery { + return &UnknownQuery{cmd: cmd} +} + +func (q *UnknownQuery) GetCommandID() uint { + return q.cmd +} + +func (q *UnknownQuery) MarshalMsg(b []byte) ([]byte, error) { + return append(b, q.data...), nil +} + +// UnmarshalMsg saves all of the data into the query. +// So make sure it doesn't contain part of another packet. +func (q *UnknownQuery) UnmarshalMsg(data []byte) (buf []byte, err error) { + q.data = make([]byte, len(data)) + copy(q.data, data) + + return nil, nil +} + +// IsKnownQuery returns true if passed query is known and supported. +func IsKnownQuery(q Query) bool { + _, unknown := q.(*UnknownQuery) + + return q != nil && !unknown +}