forked from google/safebrowsing
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatabase.go
631 lines (557 loc) · 17.7 KB
/
database.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
// Copyright 2016 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package safebrowsing
import (
"bytes"
"compress/gzip"
"context"
"encoding/gob"
"errors"
"fmt"
"io/ioutil"
"log"
"math"
"math/rand"
"os"
"sync"
"time"
pb "github.com/teamnsrg/safebrowsing/internal/safebrowsing_proto"
"path/filepath"
)
// Timing constants used by the database update and staleness logic.
const (
// maxRetryDelay is the upper bound on the retry delay that Update returns
// after a failed API call.
maxRetryDelay = 24 * time.Hour
// baseRetryDelay is the delay unit that Update scales up, with random
// jitter, for each consecutive failed API call.
baseRetryDelay = 15 * time.Minute
// jitter is the maximum amount of time that we expect an API list update to
// actually take. isStale adds this time to the update period to give some
// leeway before declaring the database as stale.
jitter = 30 * time.Second
)
// database tracks the state of the threat lists published by the Safe Browsing
// API. Since the global blacklist is constantly changing, the contents of the
// database need to be periodically synced with the Safe Browsing servers in
// order to provide protection for the latest threats.
//
// The process for updating the database is as follows:
//   - At startup, Init loads the database file if one is configured. If it
//     loads cleanly and contains every configured threat list, its contents
//     become tfu. Otherwise Init reports failure and the database stays in
//     an error state until an Update succeeds.
//   - Update synchronizes the database with the Safe Browsing API. It sends
//     the State of each threat list so that only the parts of the list that
//     have changed since the last sync are updated.
//   - Any time tfu is updated, a new tfl is generated from it.
//
// The process for querying the database is as follows:
//   - Check if the requested full hash matches any partial hash in tfl.
//     If a match is found, return a set of ThreatDescriptors with a partial
//     match.
type database struct {
config *Config
// threatsForUpdate maps ThreatDescriptors to lists of partial hashes.
// This data structure is in a format that is easily updated by the API.
// It is also the form that is written to disk.
tfu threatsForUpdate
mu sync.Mutex // Protects tfu; held for the whole of Init and Update
// threatsForLookup maps ThreatDescriptors to sets of partial hashes.
// This data structure is in a format that is easily queried.
tfl threatsForLookup
ml sync.RWMutex // Protects tfl, err, and last; readyCh is also replaced and closed under it
err error // Last error encountered; nil while the database is healthy
readyCh chan struct{} // Closed by clearError when the error state ends; replaced when a new one begins
last time.Time // Time of the last threat list sync; reset to the zero time by setError
updateAPIErrors uint // Number of consecutive failed API calls; Update resets it to 0 on success
log *log.Logger
}
// threatsForUpdate maps each ThreatDescriptor to the partial hashes, checksum
// and sync state of that threat list. Its update method applies an API
// response to it.
type threatsForUpdate map[ThreatDescriptor]partialHashes
// partialHashes is the per-threat-list record kept in threatsForUpdate and
// written to disk as part of databaseFormat.
type partialHashes struct {
// Since the Hashes field is only needed when storing to disk and when
// updating, this field is cleared except for when it is in use.
// This is done to reduce memory usage as the contents of this can be
// regenerated from the tfl (see generateThreatsForUpdate and
// generateThreatsForLookups).
Hashes hashPrefixes
SHA256 []byte // The SHA256 over Hashes
State []byte // Arbitrary binary blob to synchronize state with API
}
// threatsForLookup maps each ThreatDescriptor to a hashSet holding the partial
// hashes of that threat list; Lookup queries it.
type threatsForLookup map[ThreatDescriptor]hashSet
// databaseFormat is a light struct used only for gob encoding and decoding.
// As written to disk, the format of the database file is basically the gzip
// compressed version of the gob encoding of databaseFormat.
//
// Update constructs this type with an unkeyed composite literal, so the order
// of the fields must not change.
type databaseFormat struct {
Table threatsForUpdate // Threat lists, including their partial hashes
Time time.Time // Time at which Update requested the saved lists from the API
}
// Init stores config and logger on the database and loads the threat lists
// from the path named by config.DBPath. It reports true if the database was
// successfully loaded. On any failure the reason is logged, the database is
// left in an error state (see Status) and Init reports false.
func (db *database) Init(config *Config, logger *log.Logger) bool {
	db.mu.Lock()
	defer db.mu.Unlock()

	// Start out faulted; only a complete load below clears the error.
	db.setError(errors.New("not intialized"))
	db.config, db.log = config, logger

	dbPath := db.config.DBPath
	if dbPath == "" {
		db.log.Printf("no database file specified")
		db.setError(errors.New("no database loaded"))
		return false
	}
	loaded, loadErr := loadDatabase(dbPath)
	if loadErr != nil {
		db.log.Printf("load failure: %v", loadErr)
		db.setError(loadErr)
		return false
	}

	// The table on disk must hold at least every configured threat list;
	// lists that are on disk but not configured are dropped.
	selected := make(threatsForUpdate)
	for _, td := range db.config.ThreatLists {
		row, found := loaded.Table[td]
		if !found {
			db.log.Printf("database configuration mismatch, missing %v", td)
			db.setError(errors.New("database configuration mismatch"))
			return false
		}
		selected[td] = row
	}
	db.tfu = selected

	// Log how many prefixes of each byte length were loaded, along with the
	// fraction of the 256**length prefix space that they cover.
	countByLength := make(map[int]int)
	for _, row := range db.tfu {
		for _, prefix := range row.Hashes {
			countByLength[len(prefix)]++
		}
	}
	for length, count := range countByLength {
		searchSpace := math.Pow(256, float64(length))
		db.log.Printf("Length: %d, count %d\n", length, count)
		db.log.Printf("Likelihood of random hit: %f\n", float64(count)/searchSpace)
	}

	// Building the lookup tables also clears the error state set above.
	db.generateThreatsForLookups(loaded.Time)
	return true
}
// Status reports the health of the database. It returns the last error that
// was recorded, or nil if the database is healthy. If in a faulted state, the
// db may repair itself on the next Update.
func (db *database) Status() error {
	db.ml.RLock()
	lastErr := db.err
	db.ml.RUnlock()
	return lastErr
}
// UpdateLag reports how far the database is past its scheduled update: the
// time since the last update minus the configured update period, or zero if
// less than one update period has elapsed.
func (db *database) UpdateLag() time.Duration {
	elapsed := db.SinceLastUpdate()
	if period := db.config.UpdatePeriod; elapsed >= period {
		return elapsed - period
	}
	return 0
}
// SinceLastUpdate reports how much time has passed, according to the
// configured clock, since the threat lists were last synced.
func (db *database) SinceLastUpdate() time.Duration {
	db.ml.RLock()
	defer db.ml.RUnlock()

	now := db.config.now()
	return now.Sub(db.last)
}
// Ready returns a channel that's closed when the database is ready for queries.
//
// The channel is replaced each time the database enters a new error state, so
// callers should call Ready again rather than cache the result. readyCh is
// written while db.ml is held (see setError and clearError), so it is read
// under the read lock here to avoid a data race with those writers.
func (db *database) Ready() <-chan struct{} {
	db.ml.RLock()
	defer db.ml.RUnlock()
	return db.readyCh
}
// Update synchronizes the local threat lists with those maintained by the
// global Safe Browsing API servers. If the update is successful, Status should
// report a nil error.
//
// It returns the time the caller should wait before calling Update again and
// whether the update succeeded. After a failed API call the wait is an
// exponentially growing, randomized backoff capped at maxRetryDelay; otherwise
// it is the configured update period, or the server's minimum wait duration if
// that is longer. Any failure puts the database into an error state.
func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) {
	db.mu.Lock()
	defer db.mu.Unlock()

	// Construct the request: one entry per configured threat list, carrying
	// the state of the previous sync when there is one.
	var s []*pb.FetchThreatListUpdatesRequest_ListUpdateRequest
	for _, td := range db.config.ThreatLists {
		var state []byte
		if row, ok := db.tfu[td]; ok {
			state = row.State
		}

		s = append(s, &pb.FetchThreatListUpdatesRequest_ListUpdateRequest{
			ThreatType:      pb.ThreatType(td.ThreatType),
			PlatformType:    pb.PlatformType(td.PlatformType),
			ThreatEntryType: pb.ThreatEntryType(td.ThreatEntryType),
			Constraints: &pb.FetchThreatListUpdatesRequest_ListUpdateRequest_Constraints{
				SupportedCompressions: db.config.compressionTypes},
			State: state,
		})
	}
	numTypes := len(s)
	req := &pb.FetchThreatListUpdatesRequest{
		Client: &pb.ClientInfo{
			ClientId:      db.config.ID,
			ClientVersion: db.config.Version,
		},
		ListUpdateRequests: s,
	}

	// Query the API for the threat list and update the database.
	last := db.config.now()
	resp, err := api.ListUpdate(ctx, req)
	if err != nil {
		db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err)
		db.setError(err)
		delay := backoffDelay(db.updateAPIErrors)
		db.updateAPIErrors++
		return delay, false
	}
	db.updateAPIErrors = 0

	// Never poll more often than the server allows.
	nextUpdateWait := db.config.UpdatePeriod
	if resp.MinimumWaitDuration != nil {
		serverMinWait := time.Duration(resp.MinimumWaitDuration.Seconds)*time.Second + time.Duration(resp.MinimumWaitDuration.Nanos)
		if serverMinWait > nextUpdateWait {
			nextUpdateWait = serverMinWait
			db.log.Printf("Server requested next update in %v", nextUpdateWait)
		}
	}

	if len(resp.ListUpdateResponses) != numTypes {
		db.setError(errors.New("safebrowsing: threat list count mismatch"))
		db.log.Printf("invalid server response: got %d, want %d threat lists",
			len(resp.ListUpdateResponses), numTypes)
		return nextUpdateWait, false
	}

	// Update the threat database with the response. The hash lists are
	// restored from tfl first because they are not kept in tfu between
	// updates.
	db.generateThreatsForUpdate()
	if err := db.tfu.update(resp); err != nil {
		// setError also discards the partially updated tfu.
		db.setError(err)
		db.log.Printf("update failure: %v", err)
		return nextUpdateWait, false
	}
	dbf := databaseFormat{make(threatsForUpdate), last}
	for td, phs := range db.tfu {
		// Copy of partialHashes before generateThreatsForLookups clobbers it.
		dbf.Table[td] = phs
	}
	db.generateThreatsForLookups(last)

	// Optionally keep a timestamped snapshot of every successful update.
	if db.config.DBArchive {
		filename := db.config.now().UTC().Format(time.RFC3339) + ".db"
		if err := saveDatabase(filepath.Join(db.config.DBArchiveDirectory, filename), dbf); err != nil {
			db.log.Printf("save failure: %v", err)
		}
	}

	// Regenerate the database and store it.
	if db.config.DBPath != "" {
		// Semantically, we ignore save errors, but we do log them.
		if err := saveDatabase(db.config.DBPath, dbf); err != nil {
			db.log.Printf("save failure: %v", err)
		}
	}

	return nextUpdateWait, true
}

// backoffDelay returns the time to wait before retrying after an API failure
// that was preceded by the given number of consecutive failures:
//
//	MIN((2**failures * baseRetryDelay) * (RAND + 1), maxRetryDelay)
//
// The result saturates at maxRetryDelay instead of overflowing, so a long
// outage can never produce a zero or negative delay.
func backoffDelay(failures uint) time.Duration {
	// 2**63 nanoseconds already exceeds every representable Duration, and a
	// shift by 63 or more would overflow the multiplier.
	if failures >= 63 {
		return maxRetryDelay
	}
	delay := float64(uint64(1)<<failures) * (rand.Float64() + 1) * float64(baseRetryDelay)
	// Compare as floats so that an out-of-range value is never converted.
	if delay > float64(maxRetryDelay) {
		return maxRetryDelay
	}
	return time.Duration(delay)
}
// Lookup searches every threat list for the given full hash. It returns the
// ThreatDescriptors of all lists that contain a prefix of the hash, together
// with a matching partial hash. If several lists match, the partial hash
// comes from whichever of them was visited last.
//
// Lookup panics if hash is not a full hash.
func (db *database) Lookup(hash hashPrefix) (h hashPrefix, tds []ThreatDescriptor) {
	if !hash.IsFull() {
		panic("hash is not full")
	}

	db.ml.RLock()
	defer db.ml.RUnlock()
	for td, set := range db.tfl {
		n := set.Lookup(hash)
		if n <= 0 {
			continue
		}
		h, tds = hash[:n], append(tds, td)
	}
	return h, tds
}
// setError discards the threat lists (tfu and tfl), resets the last sync time
// and records err as the last error. If the database was healthy until now, a
// fresh ready channel is created for this error state.
//
// This assumes that the db.mu lock is already held.
func (db *database) setError(err error) {
	db.tfu = nil

	db.ml.Lock()
	defer db.ml.Unlock()
	if wasHealthy := db.err == nil; wasHealthy {
		db.readyCh = make(chan struct{})
	}
	db.tfl = nil
	db.err = err
	db.last = time.Time{}
}
// isStale reports whether lastUpdate should be considered stale, which is the
// case when it lies more than twice the sum of the configured update period
// and jitter in the past.
func (db *database) isStale(lastUpdate time.Time) bool {
	age := db.config.now().Sub(lastUpdate)
	limit := 2 * (db.config.UpdatePeriod + jitter)
	return age > limit
}
// setStale marks the database as faulted with errStale while keeping the
// threat lists in place. If the database was healthy until now, a fresh ready
// channel is created for this error state.
//
// This assumes that the db.ml lock is already held.
func (db *database) setStale() {
	wasHealthy := db.err == nil
	db.err = errStale
	if wasHealthy {
		db.readyCh = make(chan struct{})
	}
}
// clearError ends the error state: it closes the ready channel, which unblocks
// everyone waiting on the channel returned by Ready, and resets the last
// error. It does nothing if the database is already healthy.
//
// This assumes that the db.mu lock is already held.
func (db *database) clearError() {
	db.ml.Lock()
	defer db.ml.Unlock()

	if db.err == nil {
		return
	}
	close(db.readyCh)
	db.err = nil
}
// generateThreatsForUpdate restores the hash lists in threatsForUpdate by
// exporting them from the sets in threatsForLookup. The lists are not kept in
// tfu between updates so that they do not needlessly occupy lots of memory
// for a long time.
//
// This assumes that the db.mu lock is already held.
func (db *database) generateThreatsForUpdate() {
	if db.tfu == nil {
		db.tfu = make(threatsForUpdate)
	}

	db.ml.RLock()
	defer db.ml.RUnlock()
	for td, set := range db.tfl {
		// Map values are not addressable: copy the row, fill it, store it.
		row := db.tfu[td]
		row.Hashes = set.Export()
		db.tfu[td] = row
	}
}
// generateThreatsForLookups rebuilds the threatsForLookup data structure from
// the threatsForUpdate data structure and records last as the time of the
// last sync. Because the hashes then live in the sets of threatsForLookup,
// the hash slices in threatsForUpdate are cleared so that they can be GCed.
// If the database was in an error state, the error is cleared.
//
// This assumes that the db.mu lock is already held.
func (db *database) generateThreatsForLookups(last time.Time) {
	lookup := make(threatsForLookup)
	for td := range db.tfu {
		row := db.tfu[td]

		var set hashSet
		set.Import(row.Hashes)
		lookup[td] = set

		row.Hashes = nil // Clear hashes to keep memory usage low
		db.tfu[td] = row
	}

	db.ml.Lock()
	recovering := db.err != nil
	db.tfl = lookup
	db.last = last
	db.ml.Unlock()

	if !recovering {
		return
	}
	db.clearError()
	db.log.Printf("database is now healthy")
}
// saveDatabase writes db to the file at path as a gzip-compressed gob stream,
// creating the file or truncating an existing one. It returns the first error
// encountered while creating the file, encoding, finishing the gzip stream or
// closing the file.
func saveDatabase(path string, db databaseFormat) (err error) {
	out, err := os.Create(path)
	if err != nil {
		return err
	}
	// Deferred calls run in reverse order, so the gzip stream below is
	// finished before the file is closed. A close error is only reported if
	// nothing failed earlier.
	defer func() {
		closeErr := out.Close()
		if err == nil {
			err = closeErr
		}
	}()

	zw, err := gzip.NewWriterLevel(out, gzip.BestCompression)
	if err != nil {
		return err
	}
	defer func() {
		closeErr := zw.Close()
		if err == nil {
			err = closeErr
		}
	}()

	return gob.NewEncoder(zw).Encode(db)
}
// deduplicateHashes returns a new slice that holds the first occurrence of
// every distinct prefix in prefixes, in their original order. The input is
// left unmodified and the result is never nil.
func deduplicateHashes(prefixes hashPrefixes) hashPrefixes {
	seen := make(map[hashPrefix]bool)
	unique := []hashPrefix{}
	for i := range prefixes {
		prefix := prefixes[i]
		if seen[prefix] {
			continue
		}
		seen[prefix] = true
		unique = append(unique, prefix)
	}
	return unique
}
// loadDatabase loads the database state from path, which is either a single
// database file or a directory of database files (subdirectories are
// skipped). Every file must be a gzip-compressed gob encoding of
// databaseFormat whose threat lists match their stored SHA256 checksums.
//
// The threat lists of all files are merged: for each ThreatDescriptor the
// hashes are combined, sorted and deduplicated, and SHA256 is set to the
// checksum of the merged list. The State of the merged lists and the Time of
// the result are left unset.
//
// On any error the zero databaseFormat is returned together with the error.
func loadDatabase(path string) (db databaseFormat, err error) {
	fi, err := os.Stat(path)
	if err != nil {
		return databaseFormat{}, err
	}

	var filepaths []string
	switch mode := fi.Mode(); {
	case mode.IsDir():
		entries, rerr := ioutil.ReadDir(path)
		if rerr != nil {
			return databaseFormat{}, rerr
		}
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			filepaths = append(filepaths, filepath.Join(path, entry.Name()))
		}
	case mode.IsRegular():
		filepaths = append(filepaths, path)
	default:
		return databaseFormat{}, fmt.Errorf("safebrowsing: database path %q is neither a file nor a directory", path)
	}

	db = databaseFormat{Table: make(threatsForUpdate)}
	for _, fpath := range filepaths {
		fmt.Printf("Loading database file %s \n", fpath)
		part, lerr := loadDatabaseFile(fpath)
		if lerr != nil {
			return databaseFormat{}, lerr
		}
		for td, row := range part.Table {
			if !bytes.Equal(row.SHA256, row.Hashes.SHA256()) {
				return databaseFormat{}, errors.New("safebrowsing: threat list SHA256 mismatch")
			}
			// Appending to the (initially nil) merged slice copies the
			// hashes, so the decoded file contents are never aliased.
			merged := db.Table[td]
			merged.Hashes = append(merged.Hashes, row.Hashes...)
			db.Table[td] = merged
			// TODO: figure out how to combine the timestamps of the files.
		}
	}

	hashCount := 0
	for td, row := range db.Table {
		row.Hashes.Sort()
		row.Hashes = deduplicateHashes(row.Hashes)
		// The checksum has to describe the merged list, not any one file.
		row.SHA256 = row.Hashes.SHA256()
		db.Table[td] = row
		hashCount += len(row.Hashes)
	}
	fmt.Printf("Loaded database of %d hash prefixes from %d files\n", hashCount, len(filepaths))
	return db, nil
}

// loadDatabaseFile decodes the single gzip-compressed, gob-encoded database
// file at path. The file is closed again before the function returns.
func loadDatabaseFile(path string) (dbf databaseFormat, err error) {
	file, err := os.Open(path)
	if err != nil {
		return databaseFormat{}, err
	}
	defer func() {
		if cerr := file.Close(); err == nil {
			err = cerr
		}
	}()

	gz, err := gzip.NewReader(file)
	if err != nil {
		return databaseFormat{}, err
	}
	defer func() {
		if zerr := gz.Close(); err == nil {
			err = zerr
		}
	}()

	if err = gob.NewDecoder(gz).Decode(&dbf); err != nil {
		return databaseFormat{}, err
	}
	return dbf, nil
}
// update applies an API response to the threat lists. For every list in the
// response it performs the removals, then the additions, and finally verifies
// the resulting hashes against the checksum before storing the new row.
//
// On error the map may already contain rows from earlier entries of the same
// response; the row that failed is not stored.
func (tfu threatsForUpdate) update(resp *pb.FetchThreatListUpdatesResponse) error {
	for _, lur := range resp.GetListUpdateResponses() {
		key := ThreatDescriptor{
			ThreatType:      ThreatType(lur.ThreatType),
			PlatformType:    PlatformType(lur.PlatformType),
			ThreatEntryType: ThreatEntryType(lur.ThreatEntryType),
		}
		row, known := tfu[key]

		switch lur.ResponseType {
		case pb.FetchThreatListUpdatesResponse_ListUpdateResponse_FULL_UPDATE:
			if len(lur.Removals) > 0 {
				return errors.New("safebrowsing: indices to be removed included in a full update")
			}
			// A full update replaces whatever was known before.
			row = partialHashes{}
		case pb.FetchThreatListUpdatesResponse_ListUpdateResponse_PARTIAL_UPDATE:
			if !known {
				return errors.New("safebrowsing: partial update received for non-existent key")
			}
		default:
			return errors.New("safebrowsing: unknown response type")
		}

		// Removal indices refer to positions in the sorted list.
		row.Hashes.Sort()
		for _, removal := range lur.Removals {
			indices, err := decodeIndices(removal)
			if err != nil {
				return err
			}
			for _, idx := range indices {
				if idx < 0 || idx >= int32(len(row.Hashes)) {
					return errors.New("safebrowsing: invalid removal index")
				}
				// Mark the entry; marked entries are dropped below.
				row.Hashes[idx] = ""
			}
		}

		// Compact the list in place if the response carried any removals.
		if len(lur.Removals) > 0 {
			kept := row.Hashes[:0]
			for _, h := range row.Hashes {
				if h == "" {
					continue
				}
				kept = append(kept, h)
			}
			row.Hashes = kept
		}

		for _, addition := range lur.Additions {
			added, err := decodeHashes(addition)
			if err != nil {
				return err
			}
			row.Hashes = append(row.Hashes, added...)
		}

		// The checksum is computed over the sorted list.
		row.Hashes.Sort()
		if err := row.Hashes.Validate(); err != nil {
			return err
		}
		if checksum := lur.GetChecksum(); checksum != nil {
			row.SHA256 = checksum.Sha256
		}
		if !bytes.Equal(row.SHA256, row.Hashes.SHA256()) {
			return errors.New("safebrowsing: threat list SHA256 mismatch")
		}

		row.State = lur.NewClientState
		tfu[key] = row
	}
	return nil
}