Skip to content

Commit

Permalink
Pass participant kind in the grant. (#571)
Browse files Browse the repository at this point in the history
  • Loading branch information
dennwc authored Jan 16, 2024
1 parent 8ad4c79 commit 2e48332
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 0 deletions.
7 changes: 7 additions & 0 deletions auth/accesstoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (

"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"

"github.com/livekit/protocol/livekit"
)

const (
Expand Down Expand Up @@ -55,6 +57,11 @@ func (t *AccessToken) SetName(name string) *AccessToken {
return t
}

func (t *AccessToken) SetKind(kind livekit.ParticipantInfo_Kind) *AccessToken {
t.grant.SetParticipantKind(kind)
return t
}

func (t *AccessToken) AddGrant(grant *VideoGrant) *AccessToken {
t.grant.Video = grant
return t
Expand Down
20 changes: 20 additions & 0 deletions auth/accesstoken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/go-jose/go-jose/v3/jwt"
"github.com/stretchr/testify/require"

"github.com/livekit/protocol/livekit"
"github.com/livekit/protocol/utils"
)

Expand All @@ -40,6 +41,7 @@ func TestAccessToken(t *testing.T) {
at := NewAccessToken(apiKey, secret).
AddGrant(videoGrant).
SetValidFor(time.Minute * 5).
SetKind(livekit.ParticipantInfo_AGENT).
SetIdentity("user")
value, err := at.ToJWT()
//fmt.Println(raw)
Expand All @@ -55,9 +57,27 @@ func TestAccessToken(t *testing.T) {
err = token.UnsafeClaimsWithoutVerification(&decodedGrant)
require.NoError(t, err)

require.EqualValues(t, livekit.ParticipantInfo_AGENT, decodedGrant.GetParticipantKind())
require.EqualValues(t, videoGrant, decodedGrant.Video)
})

t.Run("missing kind should be interpreted as standard", func(t *testing.T) {
apiKey, secret := apiKeypair()
value, err := NewAccessToken(apiKey, secret).
AddGrant(&VideoGrant{RoomJoin: true, Room: "myroom"}).
ToJWT()
require.NoError(t, err)
token, err := jwt.ParseSigned(value)
require.NoError(t, err)

decodedGrant := ClaimGrants{}
err = token.UnsafeClaimsWithoutVerification(&decodedGrant)
require.NoError(t, err)

// default validity
require.EqualValues(t, livekit.ParticipantInfo_STANDARD, decodedGrant.GetParticipantKind())
})

t.Run("default validity should be more than a minute", func(t *testing.T) {
apiKey, secret := apiKeypair()
videoGrant := &VideoGrant{RoomJoin: true, Room: "myroom"}
Expand Down
30 changes: 30 additions & 0 deletions auth/grants.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,21 @@ type VideoGrant struct {
type ClaimGrants struct {
Identity string `json:"-"`
Name string `json:"name,omitempty"`
Kind string `json:"kind,omitempty"`
Video *VideoGrant `json:"video,omitempty"`
// for verifying integrity of the message body
Sha256 string `json:"sha256,omitempty"`
Metadata string `json:"metadata,omitempty"`
}

func (c *ClaimGrants) SetParticipantKind(kind livekit.ParticipantInfo_Kind) {
c.Kind = kindFromProto(kind)
}

func (c *ClaimGrants) GetParticipantKind() livekit.ParticipantInfo_Kind {
return kindToProto(c.Kind)
}

func (c *ClaimGrants) Clone() *ClaimGrants {
if c == nil {
return nil
Expand Down Expand Up @@ -271,3 +280,24 @@ func sourceToProto(sourceStr string) livekit.TrackSource {
return livekit.TrackSource_UNKNOWN
}
}

func kindFromProto(source livekit.ParticipantInfo_Kind) string {
return strings.ToLower(source.String())
}

func kindToProto(sourceStr string) livekit.ParticipantInfo_Kind {
switch sourceStr {
case "", "standard":
return livekit.ParticipantInfo_STANDARD
case "ingress":
return livekit.ParticipantInfo_INGRESS
case "egress":
return livekit.ParticipantInfo_EGRESS
case "sip":
return livekit.ParticipantInfo_SIP
case "agent":
return livekit.ParticipantInfo_AGENT
default:
return livekit.ParticipantInfo_STANDARD
}
}
18 changes: 18 additions & 0 deletions auth/grants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ package auth

import (
"reflect"
"strconv"
"testing"

"github.com/stretchr/testify/require"

"github.com/livekit/protocol/livekit"
)

func TestGrants(t *testing.T) {
Expand Down Expand Up @@ -66,6 +69,7 @@ func TestGrants(t *testing.T) {
grants := &ClaimGrants{
Identity: "identity",
Name: "name",
Kind: "kind",
Video: video,
Sha256: "sha256",
Metadata: "metadata",
Expand All @@ -80,3 +84,17 @@ func TestGrants(t *testing.T) {
require.True(t, reflect.DeepEqual(grants.Video, clone.Video))
})
}

func TestParticipantKind(t *testing.T) {
const kindMin, kindMax = livekit.ParticipantInfo_STANDARD, livekit.ParticipantInfo_AGENT
for k := kindMin; k <= kindMax; k++ {
k := k
t.Run(k.String(), func(t *testing.T) {
require.Equal(t, k, kindToProto(kindFromProto(k)))
})
}
const kindNext = kindMax + 1
if _, err := strconv.Atoi(kindNext.String()); err != nil {
t.Errorf("Please update kindMax to match protobuf. Missing value: %s", kindNext)
}
}
2 changes: 2 additions & 0 deletions egress/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"github.com/livekit/protocol/auth"
"github.com/livekit/protocol/livekit"
)

func BuildEgressToken(egressID, apiKey, secret, roomName string) (string, error) {
Expand All @@ -36,6 +37,7 @@ func BuildEgressToken(egressID, apiKey, secret, roomName string) (string, error)
at := auth.NewAccessToken(apiKey, secret).
AddGrant(grant).
SetIdentity(egressID).
SetKind(livekit.ParticipantInfo_EGRESS).
SetValidFor(24 * time.Hour)

return at.ToJWT()
Expand Down
2 changes: 2 additions & 0 deletions ingress/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"github.com/livekit/protocol/auth"
"github.com/livekit/protocol/livekit"
)

func BuildIngressToken(apiKey, secret, roomName, participantIdentity, participantName string) (string, error) {
Expand All @@ -34,6 +35,7 @@ func BuildIngressToken(apiKey, secret, roomName, participantIdentity, participan
AddGrant(grant).
SetIdentity(participantIdentity).
SetName(participantName).
SetKind(livekit.ParticipantInfo_INGRESS).
SetValidFor(24 * time.Hour)

return at.ToJWT()
Expand Down
39 changes: 39 additions & 0 deletions sip/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright 2023 LiveKit, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sip

import (
"time"

"github.com/livekit/protocol/auth"
"github.com/livekit/protocol/livekit"
)

func BuildSIPToken(apiKey, secret, roomName, participantIdentity, participantName string) (string, error) {
t := true
at := auth.NewAccessToken(apiKey, secret).
AddGrant(&auth.VideoGrant{
RoomJoin: true,
Room: roomName,
CanSubscribe: &t,
CanPublish: &t,
}).
SetIdentity(participantIdentity).
SetName(participantName).
SetKind(livekit.ParticipantInfo_SIP).
SetValidFor(24 * time.Hour)

return at.ToJWT()
}

0 comments on commit 2e48332

Please sign in to comment.