From 74537b7fbe826e701f35e6e2601cba9ea7f27e88 Mon Sep 17 00:00:00 2001 From: Kyle Date: Sun, 28 Apr 2024 12:45:35 -0700 Subject: [PATCH] simplify marshal/ummarshal logic, make parser stricter, update tests, fix readme (#5) --- duration.go | 27 ++++++++++++++++----------- duration_test.go | 10 ++++++++-- readme.md | 2 +- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/duration.go b/duration.go index 1723553..27210f5 100644 --- a/duration.go +++ b/duration.go @@ -1,10 +1,12 @@ package duration import ( + "encoding/json" "errors" "fmt" "math" "strconv" + "strings" "time" "unicode" ) @@ -47,6 +49,10 @@ var ( // Parse attempts to parse the given duration string into a *Duration, // if parsing fails an error is returned instead func Parse(d string) (*Duration, error) { + if !strings.Contains(d, "P") { + return nil, ErrUnexpectedInput + } + state := parsingPeriod duration := &Duration{} num := "" @@ -277,25 +283,24 @@ func (duration *Duration) String() string { return d } +// MarshalJSON satisfies the Marshaler interface by return a valid JSON string representation of the duration func (duration Duration) MarshalJSON() ([]byte, error) { - return []byte("\"" + duration.String() + "\""), nil + return json.Marshal(duration.String()) } +// UnmarshalJSON satisfies the Unmarshaler interface by return a valid JSON string representation of the duration func (duration *Duration) UnmarshalJSON(source []byte) error { - strVal := string(source) - if len(strVal) < 2 { - return fmt.Errorf("invalid ISO 8601 duration: %s", strVal) - } - strVal = strVal[1 : len(strVal)-1] - - if strVal == "null" { - return nil + durationString := "" + err := json.Unmarshal(source, &durationString) + if err != nil { + return err } - parsed, err := Parse(strVal) + parsed, err := Parse(durationString) if err != nil { - return fmt.Errorf("invalid ISO 8601 duration: %s", strVal) + return fmt.Errorf("failed to parse duration: %w", err) } + *duration = *parsed return nil } diff --git a/duration_test.go b/duration_test.go index c2b6e0f..d938ec0 100644 --- a/duration_test.go +++ b/duration_test.go @@ -17,6 +17,12 @@ func TestParse(t *testing.T) { want *Duration wantErr bool }{ + { + name: "invalid-duration-1", + args: args{d: "T0S"}, + want: nil, + wantErr: true, + }, { name: "period-only", args: args{d: "P4Y"}, @@ -27,7 +33,7 @@ func TestParse(t *testing.T) { }, { name: "time-only-decimal", - args: args{d: "T2.5S"}, + args: args{d: "PT2.5S"}, want: &Duration{ Seconds: 2.5, }, @@ -245,7 +251,7 @@ func TestDuration_String(t *testing.T) { t.Errorf("expected: %s, got: %s", "P3Y6M4DT12H30M33.3333S", duration.String()) } - smallDuration, err := Parse("T0.0000000000001S") + smallDuration, err := Parse("PT0.0000000000001S") if err != nil { t.Fatal(err) } diff --git a/readme.md b/readme.md index ec987dc..a1d9f37 100644 --- a/readme.md +++ b/readme.md @@ -36,7 +36,7 @@ func main() { fmt.Println(d.Minutes) // 30 fmt.Println(d.Seconds) // 5.5 - d, err = duration.Parse("T33.3S") + d, err = duration.Parse("PT33.3S") if err != nil { panic(err) }