package repositoryimp import ( "context" "errors" "fmt" "github.com/tech/sendico/pkg/db/repository/builder" rd "github.com/tech/sendico/pkg/db/repository/decoder" "github.com/tech/sendico/pkg/db/storable" "github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/model" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) type MongoRepository struct { collectionName string collection *mongo.Collection } func idFilter(id primitive.ObjectID) bson.D { return bson.D{ {Key: storable.IDField, Value: id}, } } func NewMongoRepository(db *mongo.Database, collection string) *MongoRepository { return &MongoRepository{ collectionName: collection, collection: db.Collection(collection), } } func (r *MongoRepository) Collection() string { return r.collectionName } func (r *MongoRepository) Insert(ctx context.Context, obj storable.Storable, getFilter builder.Query) error { if (obj.GetID() == nil) || (obj.GetID().IsZero()) { obj.SetID(primitive.NewObjectID()) } obj.Update() _, err := r.collection.InsertOne(ctx, obj) if mongo.IsDuplicateKeyError(err) { if getFilter != nil { if err = r.FindOneByFilter(ctx, getFilter, obj); err != nil { return err } } return merrors.DataConflict("duplicate_key") } return err } func (r *MongoRepository) InsertMany(ctx context.Context, objects []storable.Storable) error { if len(objects) == 0 { return nil } docs := make([]interface{}, len(objects)) for i, obj := range objects { if (obj.GetID() == nil) || (obj.GetID().IsZero()) { obj.SetID(primitive.NewObjectID()) } obj.Update() docs[i] = obj } _, err := r.collection.InsertMany(ctx, docs) return err } func (r *MongoRepository) findOneByFilterImp(ctx context.Context, filter bson.D, errMessage string, result storable.Storable) error { err := r.collection.FindOne(ctx, filter).Decode(result) if errors.Is(err, mongo.ErrNoDocuments) { return merrors.NoData(errMessage) } return err } func (r *MongoRepository) Get(ctx context.Context, id primitive.ObjectID, result storable.Storable) error { if id.IsZero() { return merrors.InvalidArgument("zero id provided while fetching "+result.Collection(), "id") } return r.findOneByFilterImp(ctx, idFilter(id), fmt.Sprintf("%s with ID = %s not found", result.Collection(), id.Hex()), result) } type QueryFunc func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) func (r *MongoRepository) executeQuery(ctx context.Context, queryFunc QueryFunc, decoder rd.DecodingFunc) error { cursor, err := queryFunc(ctx, r.collection) if errors.Is(err, mongo.ErrNoDocuments) { return merrors.NoData("no_items_in_array") } if err != nil { return err } defer cursor.Close(ctx) for cursor.Next(ctx) { if err = decoder(cursor); err != nil { return err } } return nil } func (r *MongoRepository) Aggregate(ctx context.Context, pipeline builder.Pipeline, decoder rd.DecodingFunc) error { queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) { return collection.Aggregate(ctx, pipeline.Build()) } return r.executeQuery(ctx, queryFunc, decoder) } func (r *MongoRepository) FindManyByFilter(ctx context.Context, query builder.Query, decoder rd.DecodingFunc) error { queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) { return collection.Find(ctx, query.BuildQuery(), query.BuildOptions()) } return r.executeQuery(ctx, queryFunc, decoder) } func (r *MongoRepository) FindOneByFilter(ctx context.Context, query builder.Query, result storable.Storable) error { return r.findOneByFilterImp(ctx, query.BuildQuery(), result.Collection()+" not found by filter", result) } func (r *MongoRepository) Update(ctx context.Context, obj storable.Storable) error { obj.Update() return r.collection.FindOneAndReplace(ctx, idFilter(*obj.GetID()), obj).Err() } func (r *MongoRepository) Patch(ctx context.Context, id primitive.ObjectID, patch builder.Patch) error { if id.IsZero() { return merrors.InvalidArgument("zero id provided while patching", "id") } _, err := r.collection.UpdateByID(ctx, id, patch.Build()) return err } func (r *MongoRepository) PatchMany(ctx context.Context, query builder.Query, patch builder.Patch) (int, error) { result, err := r.collection.UpdateMany(ctx, query.BuildQuery(), patch.Build()) if err != nil { return 0, err } return int(result.ModifiedCount), nil } func (r *MongoRepository) ListIDs(ctx context.Context, query builder.Query) ([]primitive.ObjectID, error) { filter := query.BuildQuery() findOptions := options.Find().SetProjection(bson.M{storable.IDField: 1}) cursor, err := r.collection.Find(ctx, filter, findOptions) if err != nil { return nil, err } defer cursor.Close(ctx) var ids []primitive.ObjectID for cursor.Next(ctx) { var doc struct { ID primitive.ObjectID `bson:"_id"` } if err := cursor.Decode(&doc); err != nil { return nil, err } ids = append(ids, doc.ID) } if err := cursor.Err(); err != nil { return nil, err } return ids, nil } func (r *MongoRepository) ListPermissionBound(ctx context.Context, query builder.Query) ([]model.PermissionBoundStorable, error) { filter := query.BuildQuery() findOptions := options.Find().SetProjection(bson.M{ storable.IDField: 1, storable.PermissionRefField: 1, storable.OrganizationRefField: 1, }) cursor, err := r.collection.Find(ctx, filter, findOptions) if err != nil { return nil, err } defer cursor.Close(ctx) result := make([]model.PermissionBoundStorable, 0) for cursor.Next(ctx) { var doc model.PermissionBound if err := cursor.Decode(&doc); err != nil { return nil, err } result = append(result, &doc) } if err := cursor.Err(); err != nil { return nil, err } return result, nil } func (r *MongoRepository) ListAccountBound(ctx context.Context, query builder.Query) ([]model.AccountBoundStorable, error) { filter := query.BuildQuery() findOptions := options.Find().SetProjection(bson.M{ storable.IDField: 1, model.AccountRefField: 1, model.OrganizationRefField: 1, }) cursor, err := r.collection.Find(ctx, filter, findOptions) if err != nil { return nil, err } defer cursor.Close(ctx) result := make([]model.AccountBoundStorable, 0) for cursor.Next(ctx) { var doc model.AccountBoundBase if err := cursor.Decode(&doc); err != nil { return nil, err } result = append(result, &doc) } if err := cursor.Err(); err != nil { return nil, err } return result, nil } func (r *MongoRepository) Delete(ctx context.Context, id primitive.ObjectID) error { _, err := r.collection.DeleteOne(ctx, idFilter(id)) return err } func (r *MongoRepository) DeleteMany(ctx context.Context, query builder.Query) error { _, err := r.collection.DeleteMany(ctx, query.BuildQuery()) return err } func (r *MongoRepository) Name() string { return r.collection.Name() }