320 lines
13 KiB
Go
320 lines
13 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/tech/sendico/pkg/db/policy"
|
|
"github.com/tech/sendico/pkg/db/repository"
|
|
"github.com/tech/sendico/pkg/db/repository/builder"
|
|
ri "github.com/tech/sendico/pkg/db/repository/index"
|
|
"github.com/tech/sendico/pkg/db/storable"
|
|
"github.com/tech/sendico/pkg/db/template"
|
|
"github.com/tech/sendico/pkg/merrors"
|
|
"github.com/tech/sendico/pkg/mlogger"
|
|
"github.com/tech/sendico/pkg/model"
|
|
"github.com/tech/sendico/pkg/mservice"
|
|
"github.com/tech/sendico/pkg/mutil/mzap"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
"go.mongodb.org/mongo-driver/mongo"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type ProtectedDBImp[T model.PermissionBoundStorable] struct {
|
|
DBImp *template.DBImp[T]
|
|
Enforcer Enforcer
|
|
PermissionRef primitive.ObjectID
|
|
Collection mservice.Type
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) enforce(ctx context.Context, action model.Action, object model.PermissionBoundStorable, accountRef, objectRef primitive.ObjectID) error {
|
|
res, err := db.Enforcer.Enforce(ctx, object.GetPermissionRef(), accountRef, object.GetOrganizationRef(), objectRef, action)
|
|
if err != nil {
|
|
db.DBImp.Logger.Warn("Failed to enforce permission",
|
|
zap.Error(err), mzap.ObjRef("permission_ref", object.GetPermissionRef()),
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()),
|
|
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
|
return err
|
|
}
|
|
if !res {
|
|
db.DBImp.Logger.Debug("Access denied", mzap.ObjRef("permission_ref", object.GetPermissionRef()),
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()),
|
|
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
|
return merrors.AccessDenied(db.Collection, string(action), objectRef)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, object T) error {
|
|
db.DBImp.Logger.Debug("Attempting to create object", mzap.ObjRef("account_ref", accountRef),
|
|
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)))
|
|
|
|
if object.GetPermissionRef() == primitive.NilObjectID {
|
|
object.SetPermissionRef(db.PermissionRef)
|
|
}
|
|
object.SetOrganizationRef(organizationRef)
|
|
|
|
if err := db.enforce(ctx, model.ActionCreate, object, accountRef, primitive.NilObjectID); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := db.DBImp.Create(ctx, object); err != nil {
|
|
db.DBImp.Logger.Warn("Failed to create object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
|
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)))
|
|
return err
|
|
}
|
|
|
|
db.DBImp.Logger.Debug("Successfully created object", mzap.ObjRef("account_ref", accountRef),
|
|
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)))
|
|
return nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) InsertMany(ctx context.Context, accountRef, organizationRef primitive.ObjectID, objects []T) error {
|
|
if len(objects) == 0 {
|
|
return nil
|
|
}
|
|
|
|
db.DBImp.Logger.Debug("Attempting to insert many objects", mzap.ObjRef("account_ref", accountRef),
|
|
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)),
|
|
zap.Int("count", len(objects)))
|
|
|
|
// Set permission and organization refs for all objects and enforce permissions
|
|
for _, object := range objects {
|
|
if object.GetPermissionRef() == primitive.NilObjectID {
|
|
object.SetPermissionRef(db.PermissionRef)
|
|
}
|
|
object.SetOrganizationRef(organizationRef)
|
|
|
|
if err := db.enforce(ctx, model.ActionCreate, object, accountRef, primitive.NilObjectID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := db.DBImp.InsertMany(ctx, objects); err != nil {
|
|
db.DBImp.Logger.Warn("Failed to insert many objects", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
|
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)),
|
|
zap.Int("count", len(objects)))
|
|
return err
|
|
}
|
|
|
|
db.DBImp.Logger.Debug("Successfully inserted many objects", mzap.ObjRef("account_ref", accountRef),
|
|
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)),
|
|
zap.Int("count", len(objects)))
|
|
return nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) enforceObject(ctx context.Context, action model.Action, accountRef, objectRef primitive.ObjectID) error {
|
|
l, err := db.ListIDs(ctx, action, accountRef, repository.IDFilter(objectRef))
|
|
if err != nil {
|
|
db.DBImp.Logger.Warn("Error occured while checking access rights", zap.Error(err),
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
|
return err
|
|
}
|
|
if len(l) == 0 {
|
|
db.DBImp.Logger.Debug("Access denied", zap.String("action", string(action)), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
|
return merrors.AccessDenied(db.Collection, string(action), objectRef)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error {
|
|
db.DBImp.Logger.Debug("Attempting to get object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
|
|
|
if err := db.enforceObject(ctx, model.ActionRead, accountRef, objectRef); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := db.DBImp.Get(ctx, objectRef, result); err != nil {
|
|
db.DBImp.Logger.Warn("Failed to get object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
|
mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection)))
|
|
return err
|
|
}
|
|
|
|
db.DBImp.Logger.Debug("Successfully retrieved object",
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", result.GetOrganizationRef()),
|
|
mzap.StorableRef(result), mzap.ObjRef("permission_ref", result.GetPermissionRef()))
|
|
return nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) Update(ctx context.Context, accountRef primitive.ObjectID, object T) error {
|
|
db.DBImp.Logger.Debug("Attempting to update object", mzap.ObjRef("account_ref", accountRef), mzap.StorableRef(object))
|
|
|
|
if err := db.enforceObject(ctx, model.ActionUpdate, accountRef, *object.GetID()); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := db.DBImp.Update(ctx, object); err != nil {
|
|
db.DBImp.Logger.Warn("Failed to update object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
|
mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.StorableRef(object))
|
|
return err
|
|
}
|
|
|
|
db.DBImp.Logger.Debug("Successfully updated object",
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()),
|
|
mzap.StorableRef(object), mzap.ObjRef("permission_ref", object.GetPermissionRef()))
|
|
return nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
|
|
db.DBImp.Logger.Debug("Attempting to delete object",
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
|
|
|
if err := db.enforceObject(ctx, model.ActionDelete, accountRef, objectRef); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := db.DBImp.Delete(ctx, objectRef); err != nil {
|
|
db.DBImp.Logger.Warn("Failed to delete object", zap.Error(err),
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
|
return err
|
|
}
|
|
|
|
db.DBImp.Logger.Debug("Successfully deleted object",
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
|
return nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) ListIDs(
|
|
ctx context.Context,
|
|
action model.Action,
|
|
accountRef primitive.ObjectID,
|
|
query builder.Query,
|
|
) ([]primitive.ObjectID, error) {
|
|
db.DBImp.Logger.Debug("Attempting to list object IDs",
|
|
mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)), zap.Any("filter", query.BuildQuery()))
|
|
|
|
// 1. Fetch all candidate IDs from the underlying DB
|
|
allIDs, err := db.DBImp.ListPermissionBound(ctx, query)
|
|
if err != nil {
|
|
db.DBImp.Logger.Warn("Failed to list object IDs", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
|
zap.String("collection", string(db.Collection)), zap.String("action", string(action)))
|
|
return nil, err
|
|
}
|
|
if len(allIDs) == 0 {
|
|
db.DBImp.Logger.Debug("No objects found matching filter", mzap.ObjRef("account_ref", accountRef),
|
|
zap.String("collection", string(db.Collection)), zap.Any("filter", query.BuildQuery()))
|
|
return []primitive.ObjectID{}, merrors.NoData(fmt.Sprintf("no %s found", db.Collection))
|
|
}
|
|
|
|
// 2. Check read permission for each ID
|
|
var allowedIDs []primitive.ObjectID
|
|
for _, desc := range allIDs {
|
|
enforceErr := db.enforce(ctx, action, desc, accountRef, *desc.GetID())
|
|
if enforceErr == nil {
|
|
allowedIDs = append(allowedIDs, *desc.GetID())
|
|
} else if !errors.Is(err, merrors.ErrAccessDenied) {
|
|
// If the error is something other than AccessDenied, we want to fail
|
|
db.DBImp.Logger.Warn("Error while enforcing read permission", zap.Error(enforceErr),
|
|
mzap.ObjRef("permission_ref", desc.GetPermissionRef()), zap.String("action", string(action)),
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", desc.GetOrganizationRef()),
|
|
mzap.ObjRef("object_ref", *desc.GetID()), zap.String("collection", string(db.Collection)),
|
|
)
|
|
return nil, enforceErr
|
|
}
|
|
// If AccessDenied, we simply skip that ID.
|
|
}
|
|
|
|
db.DBImp.Logger.Debug("Successfully enforced read permission on IDs", zap.Int("fetched_count", len(allIDs)),
|
|
zap.Int("allowed_count", len(allowedIDs)), mzap.ObjRef("account_ref", accountRef),
|
|
zap.String("collection", string(db.Collection)), zap.String("action", string(action)))
|
|
|
|
// 3. Return only the IDs that passed permission checks
|
|
return allowedIDs, nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) Unprotected() template.DB[T] {
|
|
return db.DBImp
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) DeleteCascadeAuth(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
|
|
if err := db.enforceObject(ctx, model.ActionDelete, accountRef, objectRef); err != nil {
|
|
return err
|
|
}
|
|
if err := db.DBImp.DeleteCascade(ctx, objectRef); err != nil {
|
|
db.DBImp.Logger.Warn("Failed to delete dependent object", zap.Error(err))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func CreateDBImp[T model.PermissionBoundStorable](
|
|
ctx context.Context,
|
|
l mlogger.Logger,
|
|
pdb policy.DB,
|
|
enforcer Enforcer,
|
|
collection mservice.Type,
|
|
db *mongo.Database,
|
|
) (*ProtectedDBImp[T], error) {
|
|
logger := l.Named("protected")
|
|
var policy model.PolicyDescription
|
|
if err := pdb.GetBuiltInPolicy(ctx, collection, &policy); err != nil {
|
|
logger.Warn("Failed to fetch policy description", zap.Error(err), zap.String("resource_type", string(collection)))
|
|
return nil, err
|
|
}
|
|
p := &ProtectedDBImp[T]{
|
|
DBImp: template.Create[T](logger, collection, db),
|
|
PermissionRef: policy.ID,
|
|
Collection: collection,
|
|
Enforcer: enforcer,
|
|
}
|
|
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
|
Keys: []ri.Key{{Field: storable.OrganizationRefField, Sort: ri.Asc}},
|
|
}); err != nil {
|
|
logger.Warn("Failed to create index", zap.Error(err), zap.String("resource_type", string(collection)))
|
|
return nil, err
|
|
}
|
|
|
|
return p, nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error {
|
|
db.DBImp.Logger.Debug("Attempting to patch object",
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
|
|
|
if err := db.enforceObject(ctx, model.ActionUpdate, accountRef, objectRef); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := db.DBImp.Repository.Patch(ctx, objectRef, patch); err != nil {
|
|
db.DBImp.Logger.Warn("Failed to patch object", zap.Error(err),
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
|
return err
|
|
}
|
|
|
|
db.DBImp.Logger.Debug("Successfully patched object",
|
|
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
|
return nil
|
|
}
|
|
|
|
func (db *ProtectedDBImp[T]) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) {
|
|
db.DBImp.Logger.Debug("Attempting to patch many objects",
|
|
mzap.ObjRef("account_ref", accountRef), zap.Any("filter", query.BuildQuery()))
|
|
|
|
ids, err := db.ListIDs(ctx, model.ActionUpdate, accountRef, query)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if len(ids) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
values := make([]any, len(ids))
|
|
for i, id := range ids {
|
|
values[i] = id
|
|
}
|
|
idFilter := repository.Query().In(repository.IDField(), values...)
|
|
finalQuery := query.And(idFilter)
|
|
|
|
modified, err := db.DBImp.Repository.PatchMany(ctx, finalQuery, patch)
|
|
if err != nil {
|
|
db.DBImp.Logger.Warn("Failed to patch many objects", zap.Error(err),
|
|
mzap.ObjRef("account_ref", accountRef))
|
|
return 0, err
|
|
}
|
|
|
|
db.DBImp.Logger.Debug("Successfully patched many objects",
|
|
mzap.ObjRef("account_ref", accountRef), zap.Int("modified_count", modified))
|
|
return modified, nil
|
|
}
|