package casbin import ( "context" "github.com/tech/sendico/pkg/db/role" "github.com/tech/sendico/pkg/db/storable" "github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/mlogger" "github.com/tech/sendico/pkg/model" "github.com/tech/sendico/pkg/mutil/mzap" "go.mongodb.org/mongo-driver/bson/primitive" "go.uber.org/zap" ) // RoleManager manages roles using Casbin. type RoleManager struct { logger mlogger.Logger enforcer *CasbinEnforcer rdb role.DB rolePermissionRef primitive.ObjectID } // NewRoleManager creates a new RoleManager. func NewRoleManager(logger mlogger.Logger, enforcer *CasbinEnforcer, rolePermissionRef primitive.ObjectID, rdb role.DB) *RoleManager { return &RoleManager{ logger: logger.Named("role"), enforcer: enforcer, rdb: rdb, rolePermissionRef: rolePermissionRef, } } // validateObjectIDs ensures that all provided ObjectIDs are non-zero. func (rm *RoleManager) validateObjectIDs(ids ...primitive.ObjectID) error { for _, id := range ids { if id.IsZero() { return merrors.InvalidArgument("Object references cannot be zero", "objectRef") } } return nil } // removePolicies removes policies based on the provided filter and logs the results. func (rm *RoleManager) removePolicies(policyType, role string, roleRef primitive.ObjectID) error { filterIndex := 1 if policyType == "permission" { filterIndex = 0 } policies, err := rm.enforcer.enforcer.GetFilteredPolicy(filterIndex, role) if err != nil { rm.logger.Warn("Failed to fetch "+policyType+" policies", zap.Error(err), mzap.ObjRef("role_ref", roleRef)) return err } for _, policy := range policies { args := make([]any, len(policy)) for i, v := range policy { args[i] = v } var removed bool var removeErr error if policyType == "grouping" { removed, removeErr = rm.enforcer.enforcer.RemoveGroupingPolicy(args...) } else { removed, removeErr = rm.enforcer.enforcer.RemovePolicy(args...) } if removeErr != nil { rm.logger.Warn("Failed to remove "+policyType+" policy for role", zap.Error(removeErr), mzap.ObjRef("role_ref", roleRef), zap.Strings("policy", policy)) return removeErr } if removed { rm.logger.Info("Removed "+policyType+" policy for role", mzap.ObjRef("role_ref", roleRef), zap.Strings("policy", policy)) } } return nil } // fetchRolesFromPolicies retrieves and converts policies to roles. func (rm *RoleManager) fetchRolesFromPolicies(policies [][]string, orgRef primitive.ObjectID) []model.RoleDescription { roles := make([]model.RoleDescription, 0, len(policies)) for _, policy := range policies { if len(policy) < 2 { continue } roleID, err := primitive.ObjectIDFromHex(policy[1]) if err != nil { rm.logger.Warn("Invalid role ID", zap.String("roleID", policy[1])) continue } roles = append(roles, model.RoleDescription{Base: storable.Base{ID: roleID}, OrganizationRef: orgRef}) } return roles } // Create creates a new role in an organization. func (rm *RoleManager) Create(ctx context.Context, orgRef primitive.ObjectID, description *model.Describable) (*model.RoleDescription, error) { if err := rm.validateObjectIDs(orgRef); err != nil { return nil, err } role := &model.RoleDescription{ Describable: *description, OrganizationRef: orgRef, } if err := rm.rdb.Create(ctx, role); err != nil { rm.logger.Warn("Failed to create role", zap.Error(err), mzap.ObjRef("organiztion_ref", orgRef)) return nil, err } rm.logger.Info("Role created successfully", mzap.StorableRef(role), mzap.ObjRef("organization_ref", orgRef)) return role, nil } // Assign assigns a role to a user in the given organization. func (rm *RoleManager) Assign(ctx context.Context, role *model.Role) error { if err := rm.validateObjectIDs(role.DescriptionRef, role.AccountRef, role.OrganizationRef); err != nil { return err } sub := role.AccountRef.Hex() roleID := role.DescriptionRef.Hex() domain := role.OrganizationRef.Hex() added, err := rm.enforcer.enforcer.AddGroupingPolicy(sub, roleID, domain) return rm.logPolicyResult("assign", added, err, role.DescriptionRef, role.AccountRef, role.OrganizationRef) } // Delete removes a role entirely and cleans up associated Casbin policies. func (rm *RoleManager) Delete(ctx context.Context, roleRef primitive.ObjectID) error { if err := rm.validateObjectIDs(roleRef); err != nil { rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef)) return err } if err := rm.rdb.Delete(ctx, roleRef); err != nil { rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef)) return err } role := roleRef.Hex() // Remove grouping policies if err := rm.removePolicies("grouping", role, roleRef); err != nil { return err } // Remove permission policies if err := rm.removePolicies("permission", role, roleRef); err != nil { return err } // // Save changes // if err := rm.enforcer.enforcer.SavePolicy(); err != nil { // rm.logger.Warn("Failed to save Casbin policies after role deletion", // zap.Error(err), // mzap.ObjRef("role_ref", roleRef), // ) // return err // } rm.logger.Info("Role deleted successfully along with associated policies", mzap.ObjRef("role_ref", roleRef)) return nil } // Revoke removes a role from a user. func (rm *RoleManager) Revoke(ctx context.Context, roleRef, accountRef, orgRef primitive.ObjectID) error { if err := rm.validateObjectIDs(roleRef, accountRef, orgRef); err != nil { return err } sub := accountRef.Hex() role := roleRef.Hex() domain := orgRef.Hex() removed, err := rm.enforcer.enforcer.RemoveGroupingPolicy(sub, role, domain) return rm.logPolicyResult("revoke", removed, err, roleRef, accountRef, orgRef) } // logPolicyResult logs results for Assign and Revoke. func (rm *RoleManager) logPolicyResult(action string, result bool, err error, roleRef, accountRef, orgRef primitive.ObjectID) error { if err != nil { rm.logger.Warn("Failed to "+action+" role", zap.Error(err), mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", orgRef)) return err } msg := "Role " + action + "ed successfully" if !result { msg = "Role already " + action + "ed" } rm.logger.Info(msg, mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", orgRef)) return nil } // List retrieves all roles in an organization or all roles if orgRef is zero. func (rm *RoleManager) List(ctx context.Context, orgRef primitive.ObjectID) ([]model.RoleDescription, error) { domain := orgRef.Hex() groupingPolicies, err := rm.enforcer.enforcer.GetFilteredGroupingPolicy(2, domain) if err != nil { rm.logger.Warn("Failed to fetch grouping policies", zap.Error(err), mzap.ObjRef("organization_ref", orgRef)) return nil, err } roles := rm.fetchRolesFromPolicies(groupingPolicies, orgRef) rm.logger.Info("Retrieved roles for organization", mzap.ObjRef("organization_ref", orgRef), zap.Int("count", len(roles))) return roles, nil }