207 lines
7.4 KiB
Go
207 lines
7.4 KiB
Go
// casbin_enforcer.go
|
||
package casbin
|
||
|
||
import (
|
||
"context"
|
||
|
||
"github.com/casbin/casbin/v2"
|
||
"github.com/tech/sendico/pkg/auth/anyobject"
|
||
cc "github.com/tech/sendico/pkg/auth/internal/casbin/config"
|
||
"github.com/tech/sendico/pkg/auth/internal/casbin/serialization"
|
||
"github.com/tech/sendico/pkg/merrors"
|
||
"github.com/tech/sendico/pkg/mlogger"
|
||
"github.com/tech/sendico/pkg/model"
|
||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||
"github.com/mitchellh/mapstructure"
|
||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||
"go.mongodb.org/mongo-driver/mongo"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// CasbinEnforcer implements the Enforcer interface using Casbin.
|
||
type CasbinEnforcer struct {
|
||
logger mlogger.Logger
|
||
enforcer *casbin.Enforcer
|
||
roleSerializer serialization.Role
|
||
permissionSerializer serialization.Policy
|
||
}
|
||
|
||
// NewCasbinEnforcer initializes a new CasbinEnforcer with a MongoDB adapter, logger, and PolicySerializer.
|
||
// The 'domain' parameter is no longer stored internally, as the interface requires passing a domainRef per method call.
|
||
func NewEnforcer(
|
||
logger mlogger.Logger,
|
||
client *mongo.Client,
|
||
settings model.SettingsT,
|
||
) (*CasbinEnforcer, error) {
|
||
var config cc.Config
|
||
if err := mapstructure.Decode(settings, &config); err != nil {
|
||
logger.Warn("Failed to decode Casbin configuration", zap.Error(err), zap.Any("settings", settings))
|
||
return nil, merrors.Internal("failed to decode Casbin configuration")
|
||
}
|
||
|
||
// Create a Casbin adapter + enforcer from your config and client.
|
||
l := logger.Named("enforcer")
|
||
e, err := createAdapter(l, &config, client)
|
||
if err != nil {
|
||
logger.Warn("Failed to create Casbin enforcer", zap.Error(err))
|
||
return nil, merrors.Internal("failed to create Casbin enforcer")
|
||
}
|
||
|
||
logger.Info("Casbin enforcer created")
|
||
return &CasbinEnforcer{
|
||
logger: l,
|
||
enforcer: e,
|
||
permissionSerializer: serialization.NewPolicySerializer(),
|
||
roleSerializer: serialization.NewRoleSerializer(),
|
||
}, nil
|
||
}
|
||
|
||
// Enforce reports whether accountRef may perform action on objectRef under
// permissionRef inside organizationRef.
//
// All references are passed to Casbin as hex strings. When objectRef is the
// zero ObjectID (primitive.NilObjectID) the wildcard anyobject.ID is used in its
// place. Errors from Casbin are logged and returned unchanged. The context
// argument is currently ignored.
func (c *CasbinEnforcer) Enforce(
	_ context.Context,
	permissionRef, accountRef, organizationRef, objectRef primitive.ObjectID,
	action model.Action,
) (bool, error) {
	sub, dom, perm := accountRef.Hex(), organizationRef.Hex(), permissionRef.Hex()

	obj := objectRef.Hex()
	if objectRef == primitive.NilObjectID {
		// No concrete object: check the permission against "any object".
		obj = anyobject.ID
	}
	act := string(action)

	// Same request fields are attached to both the trace and the failure log.
	fields := []zap.Field{
		zap.String("account", sub), zap.String("organization", dom),
		zap.String("permission", perm), zap.String("object", obj),
		zap.String("action", act),
	}
	c.logger.Debug("Enforcing policy", fields...)

	allowed, err := c.enforcer.Enforce(sub, dom, perm, obj, act)
	if err != nil {
		c.logger.Warn("Failed to enforce policy", append([]zap.Field{zap.Error(err)}, fields...)...)
		return false, err
	}

	c.logger.Debug("Policy enforcement result", zap.Bool("result", allowed))
	return allowed, nil
}
|
||
|
||
// EnforceBatch checks accountRef's permission to perform action on every object
// in objectRefs and returns a map from object ID to the access decision.
//
// Each object is checked via Enforce using that object's own permission and
// organization references. The call fails on the first problem and returns a
// nil map: an object whose GetID() is nil yields an internal error, and an
// enforcement error is returned unchanged. An empty objectRefs produces an
// empty, non-nil map.
func (c *CasbinEnforcer) EnforceBatch(
	ctx context.Context,
	objectRefs []model.PermissionBoundStorable,
	accountRef primitive.ObjectID,
	action model.Action,
) (map[primitive.ObjectID]bool, error) {
	results := make(map[primitive.ObjectID]bool, len(objectRefs))
	for _, desc := range objectRefs {
		id := desc.GetID()
		if id == nil {
			// Dereferencing would panic; substituting a zero ID would make Enforce
			// check anyobject.ID instead of the object the caller asked about.
			c.logger.Warn("Object without ID passed to batch enforcement",
				mzap.ObjRef("permission_ref", desc.GetPermissionRef()),
				mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", desc.GetOrganizationRef()),
				zap.String("action", string(action)))
			return nil, merrors.Internal("cannot enforce permission on object without ID")
		}
		objectRef := *id

		ok, err := c.Enforce(ctx, desc.GetPermissionRef(), accountRef, desc.GetOrganizationRef(), objectRef, action)
		if err != nil {
			c.logger.Warn("Failed to enforce", zap.Error(err), mzap.ObjRef("permission_ref", desc.GetPermissionRef()),
				mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", desc.GetOrganizationRef()),
				mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
			return nil, err
		}
		results[objectRef] = ok
	}

	return results, nil
}
|
||
|
||
// GetRoles returns every role assigned to accountRef within the organization orgRef.
//
// Roles are read from Casbin grouping policies whose field 0 is the account hex
// and field 2 is the organization hex (field 1 is left unconstrained). Each row
// is decoded with the role serializer. A failed lookup is reported as an
// internal error; a decoding error is returned as-is. ctx is currently unused.
func (c *CasbinEnforcer) GetRoles(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, error) {
	account, org := accountRef.Hex(), orgRef.Hex()
	c.logger.Debug("Fetching roles for user", zap.String("subject", account), zap.String("domain", org))

	rows, err := c.enforcer.GetFilteredGroupingPolicy(0, account, "", org)
	if err != nil {
		c.logger.Warn("Failed to get roles from policies", zap.Error(err),
			zap.String("account_ref", account), zap.String("organization_ref", org),
		)
		return nil, merrors.Internal("failed to fetch roles from policies")
	}

	// One role per grouping row; the slice is non-nil even when rows is empty.
	roles := make([]model.Role, len(rows))
	for i := range rows {
		decoded, derr := c.roleSerializer.Deserialize(rows[i])
		if derr != nil {
			c.logger.Warn("Failed to deserialize role", zap.Error(derr))
			return nil, derr
		}
		roles[i] = *decoded
	}

	c.logger.Debug("Roles fetched successfully", zap.Int("count", len(roles)))
	return roles, nil
}
|
||
|
||
// GetPermissions returns the roles assigned to accountRef within orgRef together
// with the de-duplicated permissions those roles carry.
//
// Roles come from GetRoles; an error there aborts the call. For each role, the
// Casbin policies whose field 0 equals the role's DescriptionRef hex are decoded
// into permissions bound to accountRef. Policies that cannot be fetched, have
// fewer than five fields, or fail to decode are logged and skipped (best effort).
// A permission is kept once per (DescriptionRef, action) pair. The first
// occurrence wins, and the result lists permissions in the order they were first
// seen, so repeated calls return the same order. The permission slice is non-nil
// even when empty.
func (c *CasbinEnforcer) GetPermissions(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, []model.Permission, error) {
	c.logger.Debug("Fetching policies for user", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", orgRef))

	roles, err := c.GetRoles(ctx, accountRef, orgRef)
	if err != nil {
		c.logger.Warn("Failed to get roles", zap.Error(err))
		return nil, nil, err
	}

	// seen is a set of (DescriptionRef, action) keys already emitted. permissions
	// keeps insertion order; ranging over a map here would make the order random.
	seen := make(map[string]struct{})
	permissions := make([]model.Permission, 0)
	for _, role := range roles {
		// NOTE(review): policies are filtered by role only, not by organization;
		// presumably role descriptions are organization-scoped — confirm.
		policies, err := c.enforcer.GetFilteredPolicy(0, role.DescriptionRef.Hex())
		if err != nil {
			c.logger.Warn("Failed to get policies for role", zap.Error(err), mzap.ObjRef("role_ref", role.DescriptionRef))
			continue
		}

		for _, policy := range policies {
			// Rows shorter than five fields are treated as incomplete and never
			// reach the serializer.
			if len(policy) < 5 {
				c.logger.Warn("Incomplete policy encountered", zap.Strings("policy", policy))
				continue
			}

			rolePolicy, err := c.permissionSerializer.Deserialize(policy)
			if err != nil {
				c.logger.Warn("Failed to deserialize policy", zap.Error(err), zap.Strings("policy", policy))
				continue
			}

			key := rolePolicy.DescriptionRef.Hex() + ":" + string(rolePolicy.Effect.Action)
			if _, dup := seen[key]; dup {
				continue // this permission/action pair was already collected
			}
			seen[key] = struct{}{}

			permissions = append(permissions, model.Permission{
				RolePolicy: *rolePolicy,
				AccountRef: accountRef,
			})
			c.logger.Debug("Policy added to policyMap", zap.String("policy_key", key))
		}
	}

	c.logger.Debug("Permissions fetched successfully", zap.Int("count", len(permissions)))
	return roles, permissions, nil
}
|