package auth import ( "context" "errors" "fmt" "github.com/tech/sendico/pkg/db/policy" "github.com/tech/sendico/pkg/db/repository" "github.com/tech/sendico/pkg/db/repository/builder" ri "github.com/tech/sendico/pkg/db/repository/index" "github.com/tech/sendico/pkg/db/storable" "github.com/tech/sendico/pkg/db/template" "github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/mlogger" "github.com/tech/sendico/pkg/model" "github.com/tech/sendico/pkg/mservice" "github.com/tech/sendico/pkg/mutil/mzap" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.uber.org/zap" ) type ProtectedDBImp[T model.PermissionBoundStorable] struct { DBImp *template.DBImp[T] Enforcer Enforcer PermissionRef primitive.ObjectID Collection mservice.Type } func (db *ProtectedDBImp[T]) enforce(ctx context.Context, action model.Action, object model.PermissionBoundStorable, accountRef, objectRef primitive.ObjectID) error { res, err := db.Enforcer.Enforce(ctx, object.GetPermissionRef(), accountRef, object.GetOrganizationRef(), objectRef, action) if err != nil { db.DBImp.Logger.Warn("Failed to enforce permission", zap.Error(err), mzap.ObjRef("permission_ref", object.GetPermissionRef()), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action))) return err } if !res { db.DBImp.Logger.Debug("Access denied", mzap.ObjRef("permission_ref", object.GetPermissionRef()), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action))) return merrors.AccessDenied(db.Collection, string(action), objectRef) } return nil } func (db *ProtectedDBImp[T]) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, object T) error { db.DBImp.Logger.Debug("Attempting to create object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection))) if object.GetPermissionRef() == primitive.NilObjectID { object.SetPermissionRef(db.PermissionRef) } object.SetOrganizationRef(organizationRef) if err := db.enforce(ctx, model.ActionCreate, object, accountRef, primitive.NilObjectID); err != nil { return err } if err := db.DBImp.Create(ctx, object); err != nil { db.DBImp.Logger.Warn("Failed to create object", zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection))) return err } db.DBImp.Logger.Debug("Successfully created object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection))) return nil } func (db *ProtectedDBImp[T]) InsertMany(ctx context.Context, accountRef, organizationRef primitive.ObjectID, objects []T) error { if len(objects) == 0 { return nil } db.DBImp.Logger.Debug("Attempting to insert many objects", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)), zap.Int("count", len(objects))) // Set permission and organization refs for all objects and enforce permissions for _, object := range objects { if object.GetPermissionRef() == primitive.NilObjectID { object.SetPermissionRef(db.PermissionRef) } object.SetOrganizationRef(organizationRef) if err := db.enforce(ctx, model.ActionCreate, object, accountRef, primitive.NilObjectID); err != nil { return err } } if err := db.DBImp.InsertMany(ctx, objects); err != nil { db.DBImp.Logger.Warn("Failed to insert many objects", zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)), zap.Int("count", len(objects))) return err } db.DBImp.Logger.Debug("Successfully inserted many objects", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)), zap.Int("count", len(objects))) return nil } func (db *ProtectedDBImp[T]) enforceObject(ctx context.Context, action model.Action, accountRef, objectRef primitive.ObjectID) error { l, err := db.ListIDs(ctx, action, accountRef, repository.IDFilter(objectRef)) if err != nil { db.DBImp.Logger.Warn("Error occured while checking access rights", zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action))) return err } if len(l) == 0 { db.DBImp.Logger.Debug("Access denied", zap.String("action", string(action)), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef)) return merrors.AccessDenied(db.Collection, string(action), objectRef) } return nil } func (db *ProtectedDBImp[T]) Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error { db.DBImp.Logger.Debug("Attempting to get object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef)) if err := db.enforceObject(ctx, model.ActionRead, accountRef, objectRef); err != nil { return err } if err := db.DBImp.Get(ctx, objectRef, result); err != nil { db.DBImp.Logger.Warn("Failed to get object", zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection))) return err } db.DBImp.Logger.Debug("Successfully retrieved object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", result.GetOrganizationRef()), mzap.StorableRef(result), mzap.ObjRef("permission_ref", result.GetPermissionRef())) return nil } func (db *ProtectedDBImp[T]) Update(ctx context.Context, accountRef primitive.ObjectID, object T) error { db.DBImp.Logger.Debug("Attempting to update object", mzap.ObjRef("account_ref", accountRef), mzap.StorableRef(object)) if err := db.enforceObject(ctx, model.ActionUpdate, accountRef, *object.GetID()); err != nil { return err } if err := db.DBImp.Update(ctx, object); err != nil { db.DBImp.Logger.Warn("Failed to update object", zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.StorableRef(object)) return err } db.DBImp.Logger.Debug("Successfully updated object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.StorableRef(object), mzap.ObjRef("permission_ref", object.GetPermissionRef())) return nil } func (db *ProtectedDBImp[T]) Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error { db.DBImp.Logger.Debug("Attempting to delete object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef)) if err := db.enforceObject(ctx, model.ActionDelete, accountRef, objectRef); err != nil { return err } if err := db.DBImp.Delete(ctx, objectRef); err != nil { db.DBImp.Logger.Warn("Failed to delete object", zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef)) return err } db.DBImp.Logger.Debug("Successfully deleted object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef)) return nil } func (db *ProtectedDBImp[T]) ListIDs( ctx context.Context, action model.Action, accountRef primitive.ObjectID, query builder.Query, ) ([]primitive.ObjectID, error) { db.DBImp.Logger.Debug("Attempting to list object IDs", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)), zap.Any("filter", query.BuildQuery())) // 1. Fetch all candidate IDs from the underlying DB allIDs, err := db.DBImp.ListPermissionBound(ctx, query) if err != nil { db.DBImp.Logger.Warn("Failed to list object IDs", zap.Error(err), mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)), zap.String("action", string(action))) return nil, err } if len(allIDs) == 0 { db.DBImp.Logger.Debug("No objects found matching filter", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)), zap.Any("filter", query.BuildQuery())) return []primitive.ObjectID{}, merrors.NoData(fmt.Sprintf("no %s found", db.Collection)) } // 2. Check read permission for each ID var allowedIDs []primitive.ObjectID for _, desc := range allIDs { enforceErr := db.enforce(ctx, action, desc, accountRef, *desc.GetID()) if enforceErr == nil { allowedIDs = append(allowedIDs, *desc.GetID()) } else if !errors.Is(err, merrors.ErrAccessDenied) { // If the error is something other than AccessDenied, we want to fail db.DBImp.Logger.Warn("Error while enforcing read permission", zap.Error(enforceErr), mzap.ObjRef("permission_ref", desc.GetPermissionRef()), zap.String("action", string(action)), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", desc.GetOrganizationRef()), mzap.ObjRef("object_ref", *desc.GetID()), zap.String("collection", string(db.Collection)), ) return nil, enforceErr } // If AccessDenied, we simply skip that ID. } db.DBImp.Logger.Debug("Successfully enforced read permission on IDs", zap.Int("fetched_count", len(allIDs)), zap.Int("allowed_count", len(allowedIDs)), mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)), zap.String("action", string(action))) // 3. Return only the IDs that passed permission checks return allowedIDs, nil } func (db *ProtectedDBImp[T]) Unprotected() template.DB[T] { return db.DBImp } func (db *ProtectedDBImp[T]) DeleteCascadeAuth(ctx context.Context, accountRef, objectRef primitive.ObjectID) error { if err := db.enforceObject(ctx, model.ActionDelete, accountRef, objectRef); err != nil { return err } if err := db.DBImp.DeleteCascade(ctx, objectRef); err != nil { db.DBImp.Logger.Warn("Failed to delete dependent object", zap.Error(err)) return err } return nil } func CreateDBImp[T model.PermissionBoundStorable]( ctx context.Context, l mlogger.Logger, pdb policy.DB, enforcer Enforcer, collection mservice.Type, db *mongo.Database, ) (*ProtectedDBImp[T], error) { logger := l.Named("protected") var policy model.PolicyDescription if err := pdb.GetBuiltInPolicy(ctx, collection, &policy); err != nil { logger.Warn("Failed to fetch policy description", zap.Error(err), zap.String("resource_type", string(collection))) return nil, err } p := &ProtectedDBImp[T]{ DBImp: template.Create[T](logger, collection, db), PermissionRef: policy.ID, Collection: collection, Enforcer: enforcer, } if err := p.DBImp.Repository.CreateIndex(&ri.Definition{ Keys: []ri.Key{{Field: storable.OrganizationRefField, Sort: ri.Asc}}, }); err != nil { logger.Warn("Failed to create index", zap.Error(err), zap.String("resource_type", string(collection))) return nil, err } return p, nil } func (db *ProtectedDBImp[T]) Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error { db.DBImp.Logger.Debug("Attempting to patch object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef)) if err := db.enforceObject(ctx, model.ActionUpdate, accountRef, objectRef); err != nil { return err } if err := db.DBImp.Repository.Patch(ctx, objectRef, patch); err != nil { db.DBImp.Logger.Warn("Failed to patch object", zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef)) return err } db.DBImp.Logger.Debug("Successfully patched object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef)) return nil } func (db *ProtectedDBImp[T]) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) { db.DBImp.Logger.Debug("Attempting to patch many objects", mzap.ObjRef("account_ref", accountRef), zap.Any("filter", query.BuildQuery())) ids, err := db.ListIDs(ctx, model.ActionUpdate, accountRef, query) if err != nil { return 0, err } if len(ids) == 0 { return 0, nil } values := make([]any, len(ids)) for i, id := range ids { values[i] = id } idFilter := repository.Query().In(repository.IDField(), values...) finalQuery := query.And(idFilter) modified, err := db.DBImp.Repository.PatchMany(ctx, finalQuery, patch) if err != nil { db.DBImp.Logger.Warn("Failed to patch many objects", zap.Error(err), mzap.ObjRef("account_ref", accountRef)) return 0, err } db.DBImp.Logger.Debug("Successfully patched many objects", mzap.ObjRef("account_ref", accountRef), zap.Int("modified_count", modified)) return modified, nil }