Skip to content

Commit

Permalink
Add check by user ID endpoint
Browse files Browse the repository at this point in the history
This commit adds a new endpoint to check spam status by user ID. It includes:

- New FindByUserID method in DetectedSpam interface
- Implementation of FindByUserID in storage package
- New checkIDHandler for GET /check/{user_id} route
- Unit tests for new functionality
- Updated README with API documentation
  • Loading branch information
umputun committed Jan 9, 2025
1 parent 8ac0c1a commit 35b6288
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 12 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,24 @@ It is truly a **bad idea** to run the server without basic auth protection, as i
- `user_id` - user id
- `user_name` - username

- `GET /check/{user_id}` - returns status and optional details about detected spammer by user ID.
- Response format:
```json
{
  "status": "ham" or "spam",
  "info": { // optional, present only if status is "spam"
    "user_name": "spam_user",
    "message": "spam text",
    "timestamp": "2025-01-09T12:00:00Z",
    "checks": [{"name": "check name is here", "spam": true, "details": "detected because of something"}]
  }
}
```
- Status codes:
- `200` - successful response with status and optional details
- `400` - invalid user_id format
- `500` - internal server error during check

- `POST /update/spam` - update spam samples with the message passed in the body. The body should be a json object with the following fields:
- `msg` - spam text

Expand Down
27 changes: 27 additions & 0 deletions app/storage/detected_spam.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package storage

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
Expand Down Expand Up @@ -115,6 +117,31 @@ func (ds *DetectedSpam) Read(ctx context.Context) ([]DetectedSpamInfo, error) {
return entries, nil
}

// FindByUserID looks up the most recent detected spam record for userID,
// restricted to the gid reported by the underlying db.
// A missing record is not an error: the method returns (nil, nil) in that case.
// On success the stored checks JSON is decoded into Checks and Timestamp is converted to local time.
func (ds *DetectedSpam) FindByUserID(ctx context.Context, userID int64) (*DetectedSpamInfo, error) {
	ds.RLock()
	defer ds.RUnlock()

	const query = "SELECT * FROM detected_spam WHERE user_id = ? AND gid = ? ORDER BY timestamp DESC LIMIT 1"

	found := DetectedSpamInfo{}
	switch err := ds.db.GetContext(ctx, &found, query, userID, ds.db.GID()); {
	case errors.Is(err, sql.ErrNoRows):
		// no row for this user: report "not found" as a nil entry rather than an error
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("failed to get detected spam entry for user_id %d: %w", userID, err)
	}

	var parsed []spamcheck.Response
	if err := json.Unmarshal([]byte(found.ChecksJSON), &parsed); err != nil {
		return nil, fmt.Errorf("failed to unmarshal checks for entry: %w", err)
	}

	found.Checks = parsed
	found.Timestamp = found.Timestamp.Local()
	return &found, nil
}

func migrateDetectedSpamTx(ctx context.Context, tx *sqlx.Tx, gid string) error {
// check if gid column exists
var cols []struct {
Expand Down
125 changes: 125 additions & 0 deletions app/storage/detected_spam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -914,3 +914,128 @@ func TestDetectedSpam_ReadAfterCleanup(t *testing.T) {
"entries should be ordered by timestamp descending")
}
}

// TestDetectedSpam_FindByUserID exercises DetectedSpam.FindByUserID: missing user, a single entry,
// latest-of-many selection, corrupted checks JSON and lookups from stores with different gids.
// All subtests share the same store (ds) and use distinct user IDs (123, 456, 789, 999, 111),
// so rows written by one subtest are not returned to another.
func TestDetectedSpam_FindByUserID(t *testing.T) {
db, teardown := setupTestDB(t)
defer teardown()

ctx := context.Background()
ds, err := NewDetectedSpam(ctx, db)
require.NoError(t, err)

// nothing is written for user 123 in this test; a missing user must yield a nil entry and no error
t.Run("user not found", func(t *testing.T) {
entry, err := ds.FindByUserID(ctx, 123)
require.NoError(t, err)
assert.Nil(t, entry)
})

// write one entry and verify that every field, including decoded checks, is read back
t.Run("basic case", func(t *testing.T) {
// truncated to whole seconds; presumably to avoid sub-second differences after the DB round-trip — confirm
ts := time.Now().Truncate(time.Second)
expected := DetectedSpamInfo{
GID: db.GID(),
Text: "test spam",
UserID: 456,
UserName: "spammer",
Timestamp: ts,
}
checks := []spamcheck.Response{{
Name: "test",
Spam: true,
Details: "test details",
}}

err := ds.Write(ctx, expected, checks)
require.NoError(t, err)

entry, err := ds.FindByUserID(ctx, 456)
require.NoError(t, err)
require.NotNil(t, entry)
assert.Equal(t, expected.GID, entry.GID)
assert.Equal(t, expected.Text, entry.Text)
assert.Equal(t, expected.UserID, entry.UserID)
assert.Equal(t, expected.UserName, entry.UserName)
assert.Equal(t, checks, entry.Checks)
// FindByUserID converts the timestamp with Local(), so compare against the local form
assert.Equal(t, ts.Local(), entry.Timestamp)
})

// with several entries for one user, only the one with the newest timestamp is returned
t.Run("multiple entries", func(t *testing.T) {
// write two entries for same user
for i := 0; i < 2; i++ {
entry := DetectedSpamInfo{
GID: db.GID(),
Text: fmt.Sprintf("spam %d", i),
UserID: 789,
UserName: "spammer",
// the second entry (i=1) is stamped one hour later than the first, making it the latest
Timestamp: time.Now().Add(time.Duration(i) * time.Hour),
}
checks := []spamcheck.Response{{Name: fmt.Sprintf("check%d", i), Spam: true}}
err := ds.Write(ctx, entry, checks)
require.NoError(t, err)
}

// should get the latest one
entry, err := ds.FindByUserID(ctx, 789)
require.NoError(t, err)
require.NotNil(t, entry)
assert.Equal(t, "spam 1", entry.Text)
assert.Equal(t, "check1", entry.Checks[0].Name)
})

// a row whose checks column is not valid JSON must surface as an unmarshal error with a nil entry
t.Run("invalid checks json", func(t *testing.T) {
// insert invalid json directly to db
_, err := db.Exec(`INSERT INTO detected_spam
(gid, text, user_id, user_name, timestamp, checks)
VALUES (?, ?, ?, ?, ?, ?)`,
db.GID(), "test", 999, "test", time.Now(), "{invalid}")
require.NoError(t, err)

entry, err := ds.FindByUserID(ctx, 999)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to unmarshal checks")
assert.Nil(t, entry)
})

// the same user ID written through two stores with different gids must resolve to each store's own entry
t.Run("gid isolation", func(t *testing.T) {
// create another db instance with different gid
// NOTE(review): db2 is opened on its own ":memory:" DSN, so the two stores presumably do not share
// a table; if so, this subtest does not exercise the gid filter of the query itself — confirm
db2, err := NewSqliteDB(":memory:", "other_gid")
require.NoError(t, err)
defer db2.Close()

ds2, err := NewDetectedSpam(ctx, db2)
require.NoError(t, err)

// both entries use user ID 111 and leave Timestamp unset; they differ in GID, Text and UserName
entry1 := DetectedSpamInfo{
GID: db.GID(),
Text: "spam1",
UserID: 111,
UserName: "spammer1",
}
entry2 := DetectedSpamInfo{
GID: "other_gid",
Text: "spam2",
UserID: 111,
UserName: "spammer2",
}
checks := []spamcheck.Response{{Name: "test", Spam: true}}

// write different entries to each db
err = ds.Write(ctx, entry1, checks)
require.NoError(t, err)
err = ds2.Write(ctx, entry2, checks)
require.NoError(t, err)

// first db should not find entry with other gid
res1, err := ds.FindByUserID(ctx, 111)
require.NoError(t, err)
require.NotNil(t, res1)
assert.Equal(t, "spam1", res1.Text)
assert.Equal(t, db.GID(), res1.GID)

// second db should find its own entry
res2, err := ds2.FindByUserID(ctx, 111)
require.NoError(t, err)
require.NotNil(t, res2)
assert.Equal(t, "spam2", res2.Text)
assert.Equal(t, "other_gid", res2.GID)
})
}
61 changes: 61 additions & 0 deletions app/webapi/mocks/detected_spam.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 47 additions & 4 deletions app/webapi/webapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ type Locator interface {
type DetectedSpam interface {
// Read returns detected spam entries.
Read(ctx context.Context) ([]storage.DetectedSpamInfo, error)
// SetAddedToSamplesFlag presumably marks the entry with the given id as added to samples — confirm against the implementation.
SetAddedToSamplesFlag(ctx context.Context, id int64) error
// FindByUserID returns the detected spam entry for userID.
// checkIDHandler treats a nil entry with a nil error as "user not found" (status "ham").
FindByUserID(ctx context.Context, userID int64) (*storage.DetectedSpamInfo, error)
}

// NewServer creates a new web API server.
Expand Down Expand Up @@ -167,7 +168,8 @@ func (s *Server) routes(router *chi.Mux) *chi.Mux {
// auth api routes
router.Group(func(authApi chi.Router) {
authApi.Use(s.authMiddleware(rest.BasicAuthWithUserPasswd("tg-spam", s.AuthPasswd)))
authApi.Post("/check", s.checkHandler) // check a message for spam
authApi.Post("/check", s.checkMsgHandler) // check a message for spam
authApi.Get("/check/{user_id}", s.checkIDHandler) // check user id for spam

authApi.Route("/update", func(r chi.Router) { // update spam/ham samples
r.Post("/spam", s.updateSampleHandler(s.SpamFilter.UpdateSpam)) // update spam samples
Expand Down Expand Up @@ -218,10 +220,9 @@ func (s *Server) routes(router *chi.Mux) *chi.Mux {
return router
}

// checkHandler handles POST /check request.
// checkMsgHandler handles POST /check request.
// it gets message text and user id from request body and returns spam status and check results.
func (s *Server) checkHandler(w http.ResponseWriter, r *http.Request) {

func (s *Server) checkMsgHandler(w http.ResponseWriter, r *http.Request) {
type CheckResultDisplay struct {
Spam bool
Checks []spamcheck.Response
Expand Down Expand Up @@ -271,6 +272,48 @@ func (s *Server) checkHandler(w http.ResponseWriter, r *http.Request) {
}
}

// checkIDHandler handles GET /check/{user_id} request.
// It looks the user id up in the detected spam storage and responds with JSON:
// {"status": "ham"} when no entry is found, or {"status": "spam", "info": {...}} where info
// carries user name, message, timestamp and check results of the found entry.
// It responds with 400 if user_id is not a base-10 int64 and with 500 if the lookup fails;
// both error responses are JSON objects with "error" and "details" fields.
//
// The timestamp field has no omitempty option: encoding/json never treats a struct value
// such as time.Time as empty, so the option had no effect on the output.
func (s *Server) checkIDHandler(w http.ResponseWriter, r *http.Request) {
	type info struct {
		UserName  string               `json:"user_name,omitempty"`
		Message   string               `json:"message,omitempty"`
		Timestamp time.Time            `json:"timestamp"`
		Checks    []spamcheck.Response `json:"checks,omitempty"`
	}
	resp := struct {
		Status string `json:"status"`
		Info   *info  `json:"info,omitempty"`
	}{
		Status: "ham",
	}

	// renderErr writes a JSON error body with the given status code.
	// net/http ignores header changes made after WriteHeader, so the content type has to be set
	// before the status is committed; a Content-Type set later during rendering would be dropped.
	renderErr := func(code int, msg string, err error) {
		w.Header().Set("Content-Type", "application/json; charset=utf-8")
		w.WriteHeader(code)
		rest.RenderJSON(w, rest.JSON{"error": msg, "details": err.Error()})
	}

	userID, err := strconv.ParseInt(chi.URLParam(r, "user_id"), 10, 64)
	if err != nil {
		renderErr(http.StatusBadRequest, "can't parse user id", err)
		return
	}

	si, err := s.DetectedSpam.FindByUserID(r.Context(), userID)
	if err != nil {
		renderErr(http.StatusInternalServerError, "can't get user info", err)
		return
	}

	// a nil entry means the user was never detected as a spammer; keep the default "ham" status
	if si != nil {
		resp.Status = "spam"
		resp.Info = &info{
			UserName:  si.UserName,
			Message:   si.Text,
			Timestamp: si.Timestamp,
			Checks:    si.Checks,
		}
	}
	rest.RenderJSON(w, resp)
}

// getDynamicSamplesHandler handles GET /samples request. It returns dynamic samples both for spam and ham.
func (s *Server) getDynamicSamplesHandler(w http.ResponseWriter, _ *http.Request) {
spam, ham, err := s.SpamFilter.DynamicSamples()
Expand Down
Loading

0 comments on commit 35b6288

Please sign in to comment.