From 655c5eac448fe14be440b48696f2d4d443af2da8 Mon Sep 17 00:00:00 2001 From: SpeedReach <37238439+SpeedReach@users.noreply.github.com> Date: Fri, 24 May 2024 19:12:28 +0800 Subject: [PATCH] added force_rebuild for make proto --- internal/services/group/create.go | 2 +- .../services/group/generate_invite_code.go | 81 +++++++++++++++++++ internal/services/group/invite_code_test.go | 45 +++++++++++ internal/services/group/list_joined.go | 2 +- internal/test/group_test.go | 2 +- 5 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 internal/services/group/invite_code_test.go diff --git a/internal/services/group/create.go b/internal/services/group/create.go index bda2b8d..3d51b5f 100644 --- a/internal/services/group/create.go +++ b/internal/services/group/create.go @@ -18,7 +18,7 @@ func createGroup(ctx context.Context, db *sql.DB, name string) (uuid.UUID, error return groupId, err } -func (g Service) CreateGroup(ctx context.Context, req *monify.CreateGroupRequest) (*monify.CreateGroupResponse, error) { +func (s Service) CreateGroup(ctx context.Context, req *monify.CreateGroupRequest) (*monify.CreateGroupResponse, error) { userId := ctx.Value(middlewares.UserIdContextKey{}) if userId == nil { return nil, status.Error(codes.Unauthenticated, "Unauthorized.") diff --git a/internal/services/group/generate_invite_code.go b/internal/services/group/generate_invite_code.go index 654521e..e91e460 100644 --- a/internal/services/group/generate_invite_code.go +++ b/internal/services/group/generate_invite_code.go @@ -1 +1,82 @@ package group + +import ( + "context" + "database/sql" + "github.com/google/uuid" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "math" + "math/rand" + "monify/internal/middlewares" + monify "monify/protobuf" + "time" +) + +const ( + inviteCodeChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + timeDeterLength = 4 + randomLength = 2 + inviteCodeLength = timeDeterLength + randomLength + expiresInterval = int64(time.Minute * 10) +) + +func checkPermission(ctx context.Context, db *sql.DB, groupId uuid.UUID, userId uuid.UUID) (bool, error) { + var count int + err := db.QueryRow(` + SELECT COUNT(*) FROM group_member WHERE group_id = $1 AND user_id = $2 + `, groupId, userId).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + +func (s Service) GenerateInviteCode(ctx context.Context, req *monify.GenerateInviteCodeRequest) (*monify.GenerateInviteCodeResponse, error) { + userId, ok := ctx.Value(middlewares.UserIdContextKey{}).(uuid.UUID) + db := ctx.Value(middlewares.StorageContextKey{}).(*sql.DB) + logger := ctx.Value(middlewares.LoggerContextKey{}).(*zap.Logger) + if !ok { + return nil, status.Error(codes.Unauthenticated, "Unauthorized.") + } + groupId, err := uuid.Parse(req.GroupId) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "Invalid group ID") + } + hasPerm, err := checkPermission(ctx, db, groupId, userId) + if err != nil { + logger.Error("Failed to check permission", zap.Error(err)) + return nil, status.Error(codes.Internal, "Internal") + } + if !hasPerm { + return nil, status.Error(codes.PermissionDenied, "Permission denied") + } + inviteCode := generateInviteCode() + _, err = db.Exec(` + INSERT INTO invite_code (group_id, code) VALUES ($1, $2) + `, groupId, inviteCode) + return nil, err +} + +func indexToChar(index int) byte { + return inviteCodeChars[index] +} +func generateInviteCode() string { + charsCount := len(inviteCodeChars) + seed := time.Now().UnixMilli() % expiresInterval + inviteCodeRange := int(math.Pow(float64(charsCount), timeDeterLength)) + code := int(seed) % inviteCodeRange + inviteCode := "" + + for i := 0; i < timeDeterLength; i++ { + index := code % charsCount + code /= charsCount + inviteCode += string(indexToChar(index)) + } + + for i := 0; i < randomLength; i++ { + inviteCode += string(indexToChar(rand.Int() % charsCount)) + } + return inviteCode +} diff --git a/internal/services/group/invite_code_test.go b/internal/services/group/invite_code_test.go new file mode 100644 index 0000000..051eb95 --- /dev/null +++ b/internal/services/group/invite_code_test.go @@ -0,0 +1,45 @@ +package group + +import ( + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestIndexToChar(t *testing.T) { + assert.Equal(t, byte('0'), indexToChar(0), "Expected '0' for index 0") + assert.Equal(t, byte('1'), indexToChar(1), "Expected '1' for index 1") + assert.Equal(t, byte('2'), indexToChar(2), "Expected '2' for index 2") + assert.Equal(t, byte('8'), indexToChar(8), "Expected '8' for index 8") + assert.Equal(t, byte('a'), indexToChar(10), "Expected 'a' for index 10") + assert.Equal(t, byte('b'), indexToChar(11), "Expected 'b' for index 11") + assert.Equal(t, byte('z'), indexToChar(35), "Expected 'z' for index 35") + assert.Equal(t, byte('A'), indexToChar(36), "Expected 'A' for index 36") + assert.Equal(t, byte('B'), indexToChar(37), "Expected 'B' for index 37") + assert.Equal(t, byte('Z'), indexToChar(61), "Expected 'Z' for index 61") +} + +func TestGenerateInviteCodeUniqueness(t *testing.T) { + inviteCode := generateInviteCode() + println(inviteCode) + assert.Len(t, inviteCode, inviteCodeLength, "Expected invite code to be 6 characters long") + for _, char := range inviteCode { + assert.True(t, char >= '0' && char <= '9' || char >= 'a' && char <= 'z' || char >= 'A' && char <= 'Z', "Expected invite code to contain only alphanumeric characters") + } + time.Sleep(time.Millisecond) + inviteCode2 := generateInviteCode() + assert.NotEqual(t, inviteCode, inviteCode2, "Expected two different invite codes") +} + +func TestGenerateInviteCodeRandomness(t *testing.T) { + //try to generate 1000 invite codes and check if 990 of them are unique + failCount := 0 + tries := 1000 + for i := 0; i < tries; i++ { + if generateInviteCode() == generateInviteCode() { + failCount++ + } + } + println(failCount) + assert.LessOrEqual(t, failCount, 10, "Expected 990 out of 1000 invite codes to be unique") +} diff --git a/internal/services/group/list_joined.go b/internal/services/group/list_joined.go index 228430a..4dae115 100644 --- a/internal/services/group/list_joined.go +++ b/internal/services/group/list_joined.go @@ -5,7 +5,7 @@ import ( monify "monify/protobuf" ) -func (g Service) ListJoinedGroups(context.Context, *monify.Empty) (*monify.ListJoinedGroupsResponse, error) { +func (s Service) ListJoinedGroups(context.Context, *monify.Empty) (*monify.ListJoinedGroupsResponse, error) { panic("") } diff --git a/internal/test/group_test.go b/internal/test/group_test.go index c4bbf32..eb823b5 100644 --- a/internal/test/group_test.go +++ b/internal/test/group_test.go @@ -12,5 +12,5 @@ func TestCreateGroup(t *testing.T) { _ = client.CreateTestUser() group, err := client.CreateGroup(context.TODO(), &monify.CreateGroupRequest{Name: "test"}) assert.NoError(t, err) - assert.NotEmpty(t, group.GroupId) + assert.NotEmpty(t, group) }