-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcollection_api.go
143 lines (124 loc) · 4.94 KB
/
collection_api.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package main
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"log"
)
// CollectionAPI defines an interface for MongoDB operations, allowing for testing
type CollectionAPI interface {
InsertOne(ctx context.Context, document interface{}) (*mongo.InsertOneResult, error)
UpdateOne(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error)
DeleteOne(ctx context.Context, filter interface{}) (*mongo.DeleteResult, error)
CountDocuments(ctx context.Context, filter interface{}) (int64, error)
Aggregate(ctx context.Context, pipeline interface{}, opts ...*options.AggregateOptions) (*mongo.Cursor, error)
Drop(ctx context.Context) error
Find(ctx context.Context, filter interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error)
}
// MongoDBCollection wraps a *mongo.Collection so that it can be used wherever
// a CollectionAPI is expected. The collection is embedded, so it is reachable
// as the Collection field.
type MongoDBCollection struct {
// Embedded driver collection that every method on this type delegates to.
*mongo.Collection
}
// InsertOne passes document to the embedded collection's InsertOne, supplying
// no driver options, and returns its result and error unchanged.
func (c *MongoDBCollection) InsertOne(ctx context.Context, document interface{}) (*mongo.InsertOneResult, error) {
	inserted, err := c.Collection.InsertOne(ctx, document)
	return inserted, err
}
// UpdateOne forwards filter, update and any opts to the embedded collection's
// UpdateOne and returns its result and error unchanged.
func (c *MongoDBCollection) UpdateOne(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
	updated, err := c.Collection.UpdateOne(ctx, filter, update, opts...)
	return updated, err
}
// DeleteOne passes filter to the embedded collection's DeleteOne, supplying no
// driver options, and returns its result and error unchanged.
func (c *MongoDBCollection) DeleteOne(ctx context.Context, filter interface{}) (*mongo.DeleteResult, error) {
	deleted, err := c.Collection.DeleteOne(ctx, filter)
	return deleted, err
}
// CountDocuments passes filter to the embedded collection's CountDocuments,
// supplying no driver options, and returns its count and error unchanged.
func (c *MongoDBCollection) CountDocuments(ctx context.Context, filter interface{}) (int64, error) {
	count, err := c.Collection.CountDocuments(ctx, filter)
	return count, err
}
// Drop calls Drop on the embedded collection and returns its error unchanged.
func (c *MongoDBCollection) Drop(ctx context.Context) error {
	err := c.Collection.Drop(ctx)
	return err
}
// Find forwards filter and any opts to the embedded collection's Find and
// returns the resulting cursor and error unchanged.
func (c *MongoDBCollection) Find(ctx context.Context, filter interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) {
	cur, err := c.Collection.Find(ctx, filter, opts...)
	return cur, err
}
// Aggregate forwards pipeline and any opts to the embedded collection's
// Aggregate and returns the resulting cursor and error unchanged.
func (c *MongoDBCollection) Aggregate(ctx context.Context, pipeline interface{}, opts ...*options.AggregateOptions) (*mongo.Cursor, error) {
	cur, err := c.Collection.Aggregate(ctx, pipeline, opts...)
	return cur, err
}
// fetchDocumentIDs returns the ObjectID `_id` values of documents in
// collection. The query it runs depends on testType:
//
//   - "insert", "upsert", "delete": Find over every document, projecting only
//     `_id`, capped at limit when limit > 0.
//   - "update": when limit > 0, a `$sample` aggregation of size limit;
//     otherwise a Find over every document, projecting only `_id`.
//
// Any other testType runs no query and yields (nil, nil).
//
// Documents that fail to decode, or whose `_id` is not a primitive.ObjectID,
// are logged and skipped rather than treated as fatal. An error is returned
// when the query fails, when the collection hands back a nil cursor without an
// error, or when the cursor reports an error once iteration has finished.
func fetchDocumentIDs(collection CollectionAPI, limit int64, testType string) ([]primitive.ObjectID, error) {
	ctx := context.Background()

	var (
		cursor *mongo.Cursor
		err    error
	)
	switch testType {
	case "insert", "upsert", "delete":
		findOpts := options.Find().SetProjection(bson.M{"_id": 1})
		if limit > 0 {
			findOpts.SetLimit(limit)
		}
		cursor, err = collection.Find(ctx, bson.M{}, findOpts)
		if err != nil {
			return nil, fmt.Errorf("failed to fetch document IDs: %w", err)
		}
	case "update":
		if limit > 0 {
			pipeline := []bson.M{{"$sample": bson.M{"size": limit}}}
			cursor, err = collection.Aggregate(ctx, pipeline)
			if err != nil {
				return nil, fmt.Errorf("failed to aggregate documents: %w", err)
			}
		} else {
			cursor, err = collection.Find(ctx, bson.M{}, options.Find().SetProjection(bson.M{"_id": 1}))
			if err != nil {
				return nil, fmt.Errorf("failed to fetch document IDs: %w", err)
			}
		}
	default:
		// Unrecognized test types have no IDs to fetch.
		return nil, nil
	}

	// An implementation of CollectionAPI may return a nil cursor together with
	// a nil error; calling Next or Close on it would panic.
	if cursor == nil {
		return nil, fmt.Errorf("failed to fetch document IDs: collection returned a nil cursor")
	}
	// The Close error is deliberately ignored: all results have been read (or
	// an error is already being returned) by the time it runs.
	defer cursor.Close(ctx)

	return collectObjectIDs(ctx, cursor)
}

// collectObjectIDs iterates cursor to exhaustion and returns every `_id` that
// is a primitive.ObjectID, logging and skipping documents that cannot be
// decoded or carry another `_id` type. It returns the cursor's error, if any,
// after iteration stops.
func collectObjectIDs(ctx context.Context, cursor *mongo.Cursor) ([]primitive.ObjectID, error) {
	var docIDs []primitive.ObjectID
	for cursor.Next(ctx) {
		var result bson.M
		if err := cursor.Decode(&result); err != nil {
			log.Printf("Failed to decode document: %v", err)
			continue
		}
		id, ok := result["_id"].(primitive.ObjectID)
		if !ok {
			log.Printf("Skipping document with unsupported _id type: %T", result["_id"])
			continue
		}
		docIDs = append(docIDs, id)
	}
	// Next returns false both at end of data and on failure; Err tells them apart.
	if err := cursor.Err(); err != nil {
		return nil, fmt.Errorf("cursor error: %w", err)
	}
	return docIDs, nil
}