diff --git a/bin.go b/bin.go index a69b03e..693efcf 100644 --- a/bin.go +++ b/bin.go @@ -29,7 +29,7 @@ var ( Invalid = errors.New("invalid value") CantSet = errors.New("can't set") TypeMustBeComparable = errors.New("type must be comparable") - unexpectedBehaviour = errors.New("this is a very unexpected behaviour") + unexpectedBehavior = errors.New("this is a very unexpected behavior") ) func Value(v interface{}) reflect.Value { diff --git a/decoder.go b/decoder.go index febbdcc..1b626cd 100644 --- a/decoder.go +++ b/decoder.go @@ -21,17 +21,11 @@ package bin import ( "io" - "math" "reflect" ) type Decoder struct { - readByte func() (byte, error) - reader io.Reader -} - -func (decoder *Decoder) ReadByte() (byte, error) { - return decoder.readByte() + reader io.Reader } func (decoder *Decoder) Decode(v interface{}) error { @@ -57,12 +51,18 @@ func (decoder *Decoder) Decode(v interface{}) error { value.SetZero() return nil case reflect.Bool: - b, err := decoder.ReadByte() + b := [1]byte{} + + n, err := decoder.reader.Read(b[:]) if err != nil { return err } - if b == 255 { + if n != 1 { + return io.EOF + } + + if b[0] == 255 { value.Set(reflect.ValueOf(true)) return nil } @@ -70,7 +70,7 @@ func (decoder *Decoder) Decode(v interface{}) error { value.Set(reflect.ValueOf(false)) return nil case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - n, err := VarIntOut[int64](decoder) + n, err := VarIntOut[int64](decoder.reader) if err != nil { return err } @@ -78,7 +78,7 @@ func (decoder *Decoder) Decode(v interface{}) error { value.SetInt(n) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - n, err := VarIntOut[uint64](decoder) + n, err := VarIntOut[uint64](decoder.reader) if err != nil { return err } @@ -86,50 +86,50 @@ func (decoder *Decoder) Decode(v interface{}) error { value.SetUint(n) return nil case reflect.Float32: - n, err := VarIntOut[uint32](decoder) + n, err := VarIntOut[uint32](decoder.reader) if err != nil { return err } - value.SetFloat(float64(math.Float32frombits(n))) + value.SetFloat(floatFromBits(n)) return nil case reflect.Float64: - n, err := VarIntOut[uint64](decoder) + n, err := VarIntOut[uint64](decoder.reader) if err != nil { return err } - value.SetFloat(math.Float64frombits(n)) + value.SetFloat(floatFromBits(n)) return nil case reflect.Complex64: - r, err := VarIntOut[uint32](decoder) + r, err := VarIntOut[uint32](decoder.reader) if err != nil { return err } - i, err := VarIntOut[uint32](decoder) + i, err := VarIntOut[uint32](decoder.reader) if err != nil { return err } - value.SetComplex(complex128(complex(math.Float32frombits(r), math.Float32frombits(i)))) + value.SetComplex(complex(floatFromBits(r), floatFromBits(i))) return nil case reflect.Complex128: - r, err := VarIntOut[uint64](decoder) + r, err := VarIntOut[uint64](decoder.reader) if err != nil { return err } - i, err := VarIntOut[uint64](decoder) + i, err := VarIntOut[uint64](decoder.reader) if err != nil { return err } - value.SetComplex(complex(math.Float64frombits(r), math.Float64frombits(i))) + value.SetComplex(complex(floatFromBits(r), floatFromBits(i))) return nil case reflect.Array: for i := 0; i < value.Len(); i++ { - if err := decoder.Decode(value.Index(i)); err != nil { + if err := decoder.Decode(value.Index(i)); err != nil && err != io.EOF { return err } } @@ -148,6 +148,16 @@ func (decoder *Decoder) Decode(v interface{}) error { } if t == nil { + b := [1]byte{} + + if _, err := decoder.reader.Read(b[:]); err != nil && err != io.EOF { + return err + } + + if b[0] != 0 { + return unexpectedBehavior + } + return nil } @@ -164,7 +174,7 @@ func (decoder *Decoder) Decode(v interface{}) error { value.Set(ptr) return nil case reflect.Map: - size, err := VarIntOut[int](decoder) + size, err := VarIntOut[int](decoder.reader) if err != nil { return err } @@ -197,7 +207,7 @@ func (decoder *Decoder) Decode(v interface{}) error { return decoder.Decode(value) case reflect.Slice: - size, err := VarIntOut[int](decoder) + size, err := VarIntOut[int](decoder.reader) if err != nil { return err } @@ -212,7 +222,7 @@ func (decoder *Decoder) Decode(v interface{}) error { return nil case reflect.String: - size, err := VarIntOut[int](decoder) + size, err := VarIntOut[int](decoder.reader) if err != nil { return err } @@ -229,7 +239,7 @@ func (decoder *Decoder) Decode(v interface{}) error { fields := (&Struct{}).fields(value) for i := 0; i < len(fields); i++ { - tag, err := VarIntOut[int](decoder) + tag, err := VarIntOut[int](decoder.reader) if err != nil { return err } @@ -252,7 +262,7 @@ func (decoder *Decoder) Decode(v interface{}) error { } func (decoder *Decoder) structs(value reflect.Value) error { - size, err := VarIntOut[int](decoder) + size, err := VarIntOut[int](decoder.reader) if err != nil { return err } @@ -264,7 +274,7 @@ func (decoder *Decoder) structs(value reflect.Value) error { value.Set(reflect.ValueOf(s)) for i := 0; i < size; i++ { - tag, err := VarIntOut[int](decoder) + tag, err := VarIntOut[int](decoder.reader) if err != nil { return err } @@ -295,35 +305,20 @@ func (decoder *Decoder) structs(value reflect.Value) error { } func NewDecoder(reader io.Reader) *Decoder { - var byteReader func() (byte, error) - - v, ok := reader.(io.ByteReader) - if ok { - byteReader = v.ReadByte - } else { - byteReader = func() (byte, error) { - data := make([]byte, 1) - _, err := reader.Read(data) - - return data[0], err - } - } - return &Decoder{ - readByte: byteReader, - reader: reader, + reader: reader, } } func (decoder *Decoder) getType() (reflect.Type, error) { - kind, err := decoder.ReadByte() + kind, err := VarIntOut[int](decoder.reader) if err != nil { return nil, err } switch reflect.Kind(kind) { case reflect.Invalid: - return reflect.TypeOf(nil), nil + return nil, nil case reflect.Bool: return reflect.TypeFor[bool](), nil case reflect.Int: @@ -361,7 +356,7 @@ func (decoder *Decoder) getType() (reflect.Type, error) { case reflect.String: return reflect.TypeFor[string](), nil case reflect.Array: - d, err := VarIntOut[int](decoder) + d, err := VarIntOut[int](decoder.reader) if err != nil { return nil, err } @@ -373,7 +368,7 @@ func (decoder *Decoder) getType() (reflect.Type, error) { var di []int for i := 0; i < d; i++ { - n, err := VarIntOut[int](decoder) + n, err := VarIntOut[int](decoder.reader) if err != nil { return nil, err } @@ -388,7 +383,7 @@ func (decoder *Decoder) getType() (reflect.Type, error) { return fromDepth(t, d, di), nil case reflect.Slice: - d, err := VarIntOut[int](decoder) + d, err := VarIntOut[int](decoder.reader) if err != nil { return nil, err } @@ -402,7 +397,7 @@ func (decoder *Decoder) getType() (reflect.Type, error) { if mixed { for i := 0; i < d; i++ { - n, err := VarIntOut[int](decoder) + n, err := VarIntOut[int](decoder.reader) if err != nil { return nil, err } diff --git a/decoder_test.go b/decoder_test.go index 0ca97ae..c6f2e99 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -1,6 +1,6 @@ /* * A tiny binary format - * Copyright (C) 2024 Dviih + * Copyright (C) 2025 Dviih * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published diff --git a/depth.go b/depth.go index 57788f6..44e150d 100644 --- a/depth.go +++ b/depth.go @@ -1,6 +1,6 @@ /* * A tiny binary format - * Copyright (C) 2024 Dviih + * Copyright (C) 2025 Dviih * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published @@ -21,7 +21,6 @@ package bin import ( "reflect" - "slices" ) func depth(value reflect.Value) (reflect.Type, int, bool, []int) { @@ -76,15 +75,13 @@ func isMixed(t reflect.Type) bool { } func fromDepth(t reflect.Type, d int, di []int) reflect.Type { - slices.Reverse(di) - - for i := 0; i < d; i++ { - if di == nil || di[i] == 0 { + for ; d > 0; d-- { + if di == nil || di[d-1] == 0 { t = reflect.SliceOf(t) continue } - t = reflect.ArrayOf(di[i], t) + t = reflect.ArrayOf(di[d-1], t) } return t diff --git a/encoder.go b/encoder.go index 012a2be..d5eaf3e 100644 --- a/encoder.go +++ b/encoder.go @@ -20,9 +20,7 @@ package bin import ( - "bytes" "io" - "math" "reflect" "strconv" ) @@ -60,25 +58,25 @@ func (encoder *Encoder) Encode(v interface{}) error { return nil case reflect.Float32: - return encoder.Encode(math.Float32bits(float32(value.Float()))) + return encoder.Encode(floatToBits(float32(value.Float()))) case reflect.Float64: - return encoder.Encode(math.Float64bits(value.Float())) + return encoder.Encode(floatToBits(value.Float())) case reflect.Complex64: c := complex64(value.Complex()) - if err := encoder.Encode(math.Float32bits(real(c))); err != nil { + if err := encoder.Encode(floatToBits(real(c))); err != nil { return err } - return encoder.Encode(math.Float32bits(imag(c))) + return encoder.Encode(floatToBits(imag(c))) case reflect.Complex128: c := value.Complex() - if err := encoder.Encode(math.Float64bits(real(c))); err != nil { + if err := encoder.Encode(floatToBits(real(c))); err != nil { return err } - return encoder.Encode(math.Float64bits(imag(c))) + return encoder.Encode(floatToBits(imag(c))) case reflect.Array: for i := 0; i < value.Len(); i++ { if err := encoder.Encode(value.Index(i)); err != nil { @@ -219,7 +217,7 @@ func (encoder *Encoder) Encode(v interface{}) error { return err } - if _, err := io.Copy(encoder.writer, bytes.NewBufferString(value.String())); err != nil { + if _, err := encoder.writer.Write([]byte(value.String())); err != nil { return err } case reflect.Struct: diff --git a/encoder_test.go b/encoder_test.go index ac6592c..08ef663 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -1,6 +1,6 @@ /* * A tiny binary format - * Copyright (C) 2024 Dviih + * Copyright (C) 2025 Dviih * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published diff --git a/varint.go b/varint.go index 7788c82..bd7e992 100644 --- a/varint.go +++ b/varint.go @@ -1,6 +1,6 @@ /* * A tiny binary format - * Copyright (C) 2024 Dviih + * Copyright (C) 2025 Dviih * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published @@ -20,8 +20,8 @@ package bin import ( - "encoding/binary" "io" + "unsafe" ) type Integer = interface { @@ -29,18 +29,88 @@ type Integer = interface { } func VarIntIn[T Integer](writer io.Writer, t T) error { - if _, err := writer.Write(binary.AppendUvarint(nil, uint64(t))); err != nil { + b := make([]byte, 0) + + for int(t) >= 0x80 { + b = append(b, byte(t)|0x80) + t >>= 7 + } + + b = append(b, byte(t)) + + if _, err := writer.Write(b); err != nil { return err } return nil } -func VarIntOut[T Integer](reader io.ByteReader) (T, error) { - t, err := binary.ReadUvarint(reader) - if err != nil { - return 0, err +func VarIntOut[T Integer](reader io.Reader) (T, error) { + var br func() (byte, error) + + if rbr, ok := reader.(io.ByteReader); ok { + br = rbr.ReadByte + } else { + br = func() (byte, error) { + b := [1]byte{} + + n, err := reader.Read(b[:]) + if err != nil { + return 0, err + } + + if n != 1 { + return 0, io.EOF + } + + return b[0], nil + } + } + + var t T + var p uint64 + + for i := 0; i < 10; i++ { + b, err := br() + if err != nil { + return 0, err + } + + if b < 0x80 { + if i == 9 && b > 1 { + return 0, io.EOF + } + + return t | T(b)<