service backend
All checks were successful
ci/woodpecker/push/db Pipeline was successful
ci/woodpecker/push/nats Pipeline was successful

This commit is contained in:
Stephan D
2025-11-07 18:35:26 +01:00
parent 20e8f9acc4
commit 62a6631b9a
537 changed files with 48453 additions and 0 deletions

BIN
api/pkg/.DS_Store vendored Normal file

Binary file not shown.

6
api/pkg/.gitignore vendored Normal file
View File

@@ -0,0 +1,6 @@
proto/billing
proto/common
proto/chain
proto/ledger
proto/oracle
proto/payments

View File

@@ -0,0 +1,36 @@
package api
import "fmt"
type HTTPMethod int
const (
Get HTTPMethod = iota
Post
Put
Patch
Delete
Options
Head
)
func HTTPMethod2String(method HTTPMethod) string {
switch method {
case Get:
return "GET"
case Post:
return "POST"
case Put:
return "PUT"
case Delete:
return "DELETE"
case Patch:
return "PATCH"
case Options:
return "OPTIONS"
case Head:
return "HEAD"
default:
return fmt.Sprintf("unknown: %d", method)
}
}

View File

@@ -0,0 +1,205 @@
package response
import (
"encoding/json"
"errors"
"fmt"
"net/http"
api "github.com/tech/sendico/pkg/api/http"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/mservice"
"go.uber.org/zap"
)
// BaseResponse is a general structure for all API responses.
type BaseResponse struct {
Status string `json:"status"` // "success" or "error"
Data any `json:"data"` // The actual data payload or the error details
}
// ErrorResponse provides more details about an error.
type ErrorResponse struct {
Code int `json:"code"` // A unique identifier for the error type, useful for client handling
Error string `json:"error"`
Source string `json:"source"`
Details string `json:"details"` // Additional details or hints about the error, if necessary
}
func errMessage(err error) string {
if err != nil {
return err.Error()
}
return ""
}
func logRequest(logger mlogger.Logger, r *http.Request, message string) {
logger.Debug(
message,
zap.String("host", r.Host),
zap.String("address", r.RemoteAddr),
zap.String("method", r.Method),
zap.String("request_uri", r.RequestURI),
zap.String("proto", r.Proto),
zap.String("user_agent", r.UserAgent()),
)
}
func writeJSON(logger mlogger.Logger, w http.ResponseWriter, r *http.Request, code int, payload any) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(code)
if err := json.NewEncoder(w).Encode(&payload); err != nil {
logger.Warn("Failed to encode JSON response",
zap.Error(err),
zap.Any("response", payload),
zap.String("host", r.Host),
zap.String("address", r.RemoteAddr),
zap.String("method", r.Method),
zap.String("request_uri", r.RequestURI),
zap.String("proto", r.Proto),
zap.String("user_agent", r.UserAgent()))
}
}
func errorf(
logger mlogger.Logger,
w http.ResponseWriter, r *http.Request,
source mservice.Type, code int, message, details string,
) {
logRequest(logger, r, message)
errorMessage := BaseResponse{
Status: api.MSError,
Data: ErrorResponse{
Code: code,
Details: details,
Source: source,
Error: message,
},
}
writeJSON(logger, w, r, code, errorMessage)
}
func Accepted(logger mlogger.Logger, data any) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
resp := BaseResponse{
Status: api.MSProcessed,
Data: data,
}
writeJSON(logger, w, r, http.StatusAccepted, resp)
}
}
func Ok(logger mlogger.Logger, data any) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
resp := BaseResponse{
Status: api.MSSuccess,
Data: data,
}
writeJSON(logger, w, r, http.StatusOK, resp)
}
}
func Created(logger mlogger.Logger, data any) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
resp := BaseResponse{
Status: api.MSSuccess,
Data: data,
}
writeJSON(logger, w, r, http.StatusCreated, resp)
}
}
func Auto(logger mlogger.Logger, source mservice.Type, err error) http.HandlerFunc {
if err == nil {
return Success(logger)
}
if errors.Is(err, merrors.ErrAccessDenied) {
return AccessDenied(logger, source, errMessage(err))
}
if errors.Is(err, merrors.ErrDataConflict) {
return DataConflict(logger, source, errMessage(err))
}
if errors.Is(err, merrors.ErrInvalidArg) {
return BadRequest(logger, source, "invalid_argument", errMessage(err))
}
if errors.Is(err, merrors.ErrNoData) {
return NotFound(logger, source, errMessage(err))
}
if errors.Is(err, merrors.ErrUnauthorized) {
return Unauthorized(logger, source, errMessage(err))
}
return Internal(logger, source, err)
}
func Internal(logger mlogger.Logger, source mservice.Type, err error) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
errorf(logger, w, r, source, http.StatusInternalServerError, "internal_error", errMessage(err))
}
}
func NotImplemented(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
errorf(logger, w, r, source, http.StatusNotImplemented, "not_implemented", hint)
}
}
func BadRequest(logger mlogger.Logger, source mservice.Type, err, hint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
errorf(logger, w, r, source, http.StatusBadRequest, err, hint)
}
}
func BadQueryParam(logger mlogger.Logger, source mservice.Type, param string, err error) http.HandlerFunc {
return BadRequest(logger, source, "invalid_query_parameter", fmt.Sprintf("Failed to parse '%s': %v", param, err))
}
func BadReference(logger mlogger.Logger, source mservice.Type, refName, refVal string, err error) http.HandlerFunc {
return BadRequest(logger, source, "broken_reference",
fmt.Sprintf("broken object reference: %s = %s, error: %v", refName, refVal, err))
}
func BadPayload(logger mlogger.Logger, source mservice.Type, err error) http.HandlerFunc {
return BadRequest(logger, source, "broken_payload",
fmt.Sprintf("broken '%s' object payload, error: %v", source, err))
}
func DataConflict(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
errorf(logger, w, r, source, http.StatusConflict, "data_conflict", hint)
}
}
func Error(logger mlogger.Logger, source mservice.Type, code int, errType, hint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
errorf(logger, w, r, source, code, errType, hint)
}
}
func AccessDenied(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
return Error(logger, source, http.StatusForbidden, "access_denied", hint)
}
func Forbidden(logger mlogger.Logger, source mservice.Type, errType, hint string) http.HandlerFunc {
return Error(logger, source, http.StatusForbidden, errType, hint)
}
func LicenseRequired(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
errorf(logger, w, r, source, http.StatusPaymentRequired, "license_required", hint)
}
}
func Unauthorized(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
errorf(logger, w, r, source, http.StatusUnauthorized, "unauthorized", hint)
}
}
func NotFound(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
errorf(logger, w, r, source, http.StatusNotFound, "not_found", hint)
}
}

View File

@@ -0,0 +1,19 @@
package response
import (
"net/http"
"github.com/tech/sendico/pkg/mlogger"
)
type Result struct {
Result bool `json:"result"`
}
func Success(logger mlogger.Logger) http.HandlerFunc {
return Ok(logger, Result{Result: true})
}
func Failed(logger mlogger.Logger) http.HandlerFunc {
return Accepted(logger, Result{Result: false})
}

View File

@@ -0,0 +1,8 @@
package api
const (
MSSuccess string = "success"
MSProcessed string = "processed"
MSError string = "error"
MSRequest string = "request"
)

View File

@@ -0,0 +1,61 @@
package routers
import (
"context"
"net"
"github.com/tech/sendico/pkg/api/routers/internal/grpcimp"
"github.com/tech/sendico/pkg/mlogger"
"google.golang.org/grpc"
)
type (
GRPCServiceRegistration = func(grpc.ServiceRegistrar)
)
type GRPC interface {
Register(registration GRPCServiceRegistration) error
Start(ctx context.Context) error
Finish(ctx context.Context) error
Addr() net.Addr
Done() <-chan error
}
type (
GRPCConfig = grpcimp.Config
GRPCTLSConfig = grpcimp.TLSConfig
)
type GRPCOption func(*grpcimp.Options)
func WithUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) GRPCOption {
return func(o *grpcimp.Options) {
o.UnaryInterceptors = append(o.UnaryInterceptors, interceptors...)
}
}
func WithStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) GRPCOption {
return func(o *grpcimp.Options) {
o.StreamInterceptors = append(o.StreamInterceptors, interceptors...)
}
}
func WithListener(listener net.Listener) GRPCOption {
return func(o *grpcimp.Options) {
o.Listener = listener
}
}
func WithServerOptions(opts ...grpc.ServerOption) GRPCOption {
return func(o *grpcimp.Options) {
o.ServerOptions = append(o.ServerOptions, opts...)
}
}
func NewGRPCRouter(logger mlogger.Logger, config *GRPCConfig, opts ...GRPCOption) (GRPC, error) {
options := &grpcimp.Options{}
for _, opt := range opts {
opt(options)
}
return grpcimp.NewRouter(logger, config, options)
}

View File

@@ -0,0 +1,149 @@
package gsresponse
import (
"context"
"errors"
"fmt"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/mservice"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// Responder produces a response or a gRPC status error when executed.
type Responder[T any] func(ctx context.Context) (*T, error)
func message(err error) string {
if err == nil {
return ""
}
return err.Error()
}
func Success[T any](resp *T) Responder[T] {
return func(context.Context) (*T, error) {
return resp, nil
}
}
func Empty[T any]() Responder[T] {
return func(context.Context) (*T, error) {
return nil, nil
}
}
func Error[T any](logger mlogger.Logger, service mservice.Type, code codes.Code, hint string, err error) Responder[T] {
return func(ctx context.Context) (*T, error) {
fields := []zap.Field{
zap.String("service", string(service)),
zap.String("status_code", code.String()),
}
if hint != "" {
fields = append(fields, zap.String("error_hint", hint))
}
if err != nil {
fields = append(fields, zap.Error(err))
}
logFn := logger.Warn
switch code {
case codes.Internal, codes.DataLoss, codes.Unavailable:
logFn = logger.Error
}
logFn("gRPC request failed", fields...)
msg := message(err)
switch {
case hint == "" && msg == "":
return nil, status.Error(code, code.String())
case hint == "":
return nil, status.Error(code, msg)
case msg == "":
return nil, status.Error(code, hint)
default:
return nil, status.Error(code, fmt.Sprintf("%s: %s", hint, msg))
}
}
}
func Internal[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.Internal, "internal_error", err)
}
func InvalidArgument[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.InvalidArgument, "invalid_argument", err)
}
func NotFound[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.NotFound, "not_found", err)
}
func Unauthorized[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.Unauthenticated, "unauthorized", err)
}
func PermissionDenied[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.PermissionDenied, "access_denied", err)
}
func FailedPrecondition[T any](logger mlogger.Logger, service mservice.Type, hint string, err error) Responder[T] {
return Error[T](logger, service, codes.FailedPrecondition, hint, err)
}
func Conflict[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.Aborted, "conflict", err)
}
func DeadlineExceeded[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.DeadlineExceeded, "deadline_exceeded", err)
}
func Unavailable[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.Unavailable, "service_unavailable", err)
}
func Unimplemented[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.Unimplemented, "not_implemented", err)
}
func AlreadyExists[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
return Error[T](logger, service, codes.AlreadyExists, "already_exists", err)
}
func Auto[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
switch {
case err == nil:
return Empty[T]()
case errors.Is(err, merrors.ErrInvalidArg):
return InvalidArgument[T](logger, service, err)
case errors.Is(err, merrors.ErrAccessDenied):
return PermissionDenied[T](logger, service, err)
case errors.Is(err, merrors.ErrNoData):
return NotFound[T](logger, service, err)
case errors.Is(err, merrors.ErrUnauthorized):
return Unauthorized[T](logger, service, err)
case errors.Is(err, merrors.ErrDataConflict):
return Conflict[T](logger, service, err)
default:
return Internal[T](logger, service, err)
}
}
func Execute[T any](ctx context.Context, responder Responder[T]) (*T, error) {
if responder == nil {
return nil, status.Error(codes.Internal, "missing responder")
}
return responder(ctx)
}
func Unary[TReq any, TResp any](logger mlogger.Logger, service mservice.Type, handler func(context.Context, *TReq) Responder[TResp]) func(context.Context, *TReq) (*TResp, error) {
return func(ctx context.Context, req *TReq) (*TResp, error) {
if handler == nil {
return nil, status.Error(codes.Internal, "missing handler")
}
responder := handler(ctx, req)
return Execute(ctx, responder)
}
}

View File

@@ -0,0 +1,75 @@
package gsresponse
import (
"context"
"errors"
"fmt"
"testing"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mservice"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type testRequest struct {
Value string
}
type testResponse struct {
Result string
}
func TestUnarySuccess(t *testing.T) {
logger := zap.NewNop()
handler := func(ctx context.Context, req *testRequest) Responder[testResponse] {
require.NotNil(t, req)
require.Equal(t, "hello", req.Value)
resp := &testResponse{Result: "ok"}
return Success(resp)
}
unary := Unary[testRequest, testResponse](logger, mservice.Type("test"), handler)
resp, err := unary(context.Background(), &testRequest{Value: "hello"})
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, "ok", resp.Result)
}
func TestAutoMappings(t *testing.T) {
logger := zap.NewNop()
service := mservice.Type("test")
tests := []struct {
name string
err error
code codes.Code
}{
{"invalid_argument", merrors.InvalidArgument("bad"), codes.InvalidArgument},
{"access_denied", merrors.AccessDenied("object", "action", primitive.NilObjectID), codes.PermissionDenied},
{"not_found", merrors.NoData("missing"), codes.NotFound},
{"unauthorized", fmt.Errorf("%w: %s", merrors.ErrUnauthorized, "bad"), codes.Unauthenticated},
{"conflict", merrors.DataConflict("conflict"), codes.Aborted},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
responder := Auto[testResponse](logger, service, tc.err)
_, err := responder(context.Background())
require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, tc.code, st.Code())
})
}
responder := Auto[testResponse](logger, service, errors.New("boom"))
_, err := responder(context.Background())
require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, codes.Internal, st.Code())
}

View File

@@ -0,0 +1,17 @@
package routers
import (
"github.com/go-chi/chi/v5"
"github.com/tech/sendico/pkg/api/routers/health"
"github.com/tech/sendico/pkg/api/routers/internal/healthimp"
"github.com/tech/sendico/pkg/mlogger"
)
type Health interface {
SetStatus(status health.ServiceStatus)
Finish()
}
func NewHealthRouter(logger mlogger.Logger, router chi.Router, endpoint string) (Health, error) {
return healthimp.NewRouter(logger, router, endpoint), nil
}

View File

@@ -0,0 +1,10 @@
package health
type ServiceStatus string
const (
SSCreated ServiceStatus = "created"
SSStarting ServiceStatus = "starting"
SSRunning ServiceStatus = "ok"
SSTerminating ServiceStatus = "deactivating"
)

View File

@@ -0,0 +1,18 @@
package grpcimp
type Config struct {
Network string `yaml:"network"`
Address string `yaml:"address"`
EnableReflection bool `yaml:"enable_reflection"`
EnableHealth bool `yaml:"enable_health"`
MaxRecvMsgSize int `yaml:"max_recv_msg_size"`
MaxSendMsgSize int `yaml:"max_send_msg_size"`
TLS *TLSConfig `yaml:"tls"`
}
type TLSConfig struct {
CertFile string `yaml:"cert_file"`
KeyFile string `yaml:"key_file"`
CAFile string `yaml:"ca_file"`
RequireClientCert bool `yaml:"require_client_cert"`
}

View File

@@ -0,0 +1,103 @@
package grpcimp
import (
"context"
"strings"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
"google.golang.org/grpc/status"
)
var (
metricsOnce sync.Once
grpcServerRequestsTotal *prometheus.CounterVec
grpcServerLatency *prometheus.HistogramVec
)
func initPrometheusMetrics() {
metricsOnce.Do(func() {
grpcServerRequestsTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "grpc_server_requests_total",
Help: "Total number of gRPC requests handled by the server.",
},
[]string{"grpc_service", "grpc_method", "grpc_type", "grpc_code"},
)
grpcServerLatency = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "grpc_server_handling_seconds",
Help: "Duration of gRPC requests handled by the server.",
Buckets: prometheus.DefBuckets,
},
[]string{"grpc_service", "grpc_method", "grpc_type", "grpc_code"},
)
prometheus.MustRegister(grpcServerRequestsTotal, grpcServerLatency)
})
}
func prometheusUnaryInterceptor() grpc.UnaryServerInterceptor {
initPrometheusMetrics()
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
start := time.Now()
resp, err := handler(ctx, req)
recordMetrics(info.FullMethod, "unary", time.Since(start), err)
return resp, err
}
}
func prometheusStreamInterceptor() grpc.StreamServerInterceptor {
initPrometheusMetrics()
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
start := time.Now()
err := handler(srv, ss)
recordMetrics(info.FullMethod, streamType(info), time.Since(start), err)
return err
}
}
func streamType(info *grpc.StreamServerInfo) string {
if info == nil {
return "stream"
}
if info.IsServerStream && info.IsClientStream {
return "bidi"
}
if info.IsServerStream {
return "server_stream"
}
if info.IsClientStream {
return "client_stream"
}
return "stream"
}
func recordMetrics(fullMethod string, callType string, duration time.Duration, err error) {
service, method := splitMethod(fullMethod)
code := status.Code(err).String()
grpcServerRequestsTotal.WithLabelValues(service, method, callType, code).Inc()
grpcServerLatency.WithLabelValues(service, method, callType, code).Observe(duration.Seconds())
}
func splitMethod(fullMethod string) (string, string) {
if fullMethod == "" {
return "unknown", "unknown"
}
if fullMethod[0] == '/' {
fullMethod = fullMethod[1:]
}
parts := strings.Split(fullMethod, "/")
if len(parts) < 2 {
return fullMethod, "unknown"
}
return parts[0], parts[1]
}

View File

@@ -0,0 +1,14 @@
package grpcimp
import (
"net"
"google.golang.org/grpc"
)
type Options struct {
UnaryInterceptors []grpc.UnaryServerInterceptor
StreamInterceptors []grpc.StreamServerInterceptor
ServerOptions []grpc.ServerOption
Listener net.Listener
}

View File

@@ -0,0 +1,293 @@
package grpcimp
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"os"
"sync"
"github.com/tech/sendico/pkg/mlogger"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/health"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
)
type routerError string
func (e routerError) Error() string {
return string(e)
}
type routerErrorWithCause struct {
message string
cause error
}
func (e *routerErrorWithCause) Error() string {
if e == nil {
return ""
}
if e.cause == nil {
return e.message
}
return e.message + ": " + e.cause.Error()
}
func (e *routerErrorWithCause) Unwrap() error {
if e == nil {
return nil
}
return e.cause
}
func newRouterErrorWithCause(message string, cause error) error {
return &routerErrorWithCause{
message: message,
cause: cause,
}
}
const (
errMsgAlreadyStarted = "grpc router already started"
errMsgListenFailed = "failed to listen on requested address"
errMsgNilContext = "nil context"
errMsgTLSMissingCertAndKey = "tls configuration requires cert_file and key_file"
errMsgLoadServerCertificate = "failed to load server certificate"
errMsgReadCAFile = "failed to read CA file"
errMsgAppendCACertificates = "failed to append CA certificates"
errMsgClientCertRequiresCAFile = "client certificate verification requested but ca_file is empty"
)
var (
errAlreadyStarted = routerError(errMsgAlreadyStarted)
errNilContext = routerError(errMsgNilContext)
errTLSMissingCertAndKey = routerError(errMsgTLSMissingCertAndKey)
errAppendCACertificates = routerError(errMsgAppendCACertificates)
errClientCertRequiresCAFile = routerError(errMsgClientCertRequiresCAFile)
)
type Router struct {
logger mlogger.Logger
config Config
server *grpc.Server
listener net.Listener
options *Options
mu sync.RWMutex
started bool
serveErr chan error
healthSrv *health.Server
}
func NewRouter(logger mlogger.Logger, cfg *Config, opts *Options) (*Router, error) {
if cfg == nil {
cfg = &Config{}
}
if opts == nil {
opts = &Options{}
}
network := cfg.Network
if network == "" {
network = "tcp"
}
address := cfg.Address
if address == "" {
address = ":0"
}
listener := opts.Listener
var err error
if listener == nil {
listener, err = net.Listen(network, address)
if err != nil {
return nil, newRouterErrorWithCause(errMsgListenFailed, err)
}
}
serverOpts := make([]grpc.ServerOption, 0, len(opts.ServerOptions)+4)
serverOpts = append(serverOpts, opts.ServerOptions...)
if cfg.MaxRecvMsgSize > 0 {
serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(cfg.MaxRecvMsgSize))
}
if cfg.MaxSendMsgSize > 0 {
serverOpts = append(serverOpts, grpc.MaxSendMsgSize(cfg.MaxSendMsgSize))
}
if creds, err := configureTLS(cfg.TLS); err != nil {
return nil, err
} else if creds != nil {
serverOpts = append(serverOpts, grpc.Creds(creds))
}
unaryInterceptors := append([]grpc.UnaryServerInterceptor{prometheusUnaryInterceptor()}, opts.UnaryInterceptors...)
streamInterceptors := append([]grpc.StreamServerInterceptor{prometheusStreamInterceptor()}, opts.StreamInterceptors...)
if len(unaryInterceptors) > 0 {
serverOpts = append(serverOpts, grpc.ChainUnaryInterceptor(unaryInterceptors...))
}
if len(streamInterceptors) > 0 {
serverOpts = append(serverOpts, grpc.ChainStreamInterceptor(streamInterceptors...))
}
srv := grpc.NewServer(serverOpts...)
r := &Router{
logger: logger.Named("grpc"),
config: *cfg,
server: srv,
listener: listener,
options: opts,
serveErr: make(chan error, 1),
}
if cfg.EnableReflection {
reflection.Register(srv)
}
if cfg.EnableHealth {
r.healthSrv = health.NewServer()
r.healthSrv.SetServingStatus("", healthpb.HealthCheckResponse_NOT_SERVING)
healthpb.RegisterHealthServer(srv, r.healthSrv)
}
return r, nil
}
func (r *Router) Register(registration func(grpc.ServiceRegistrar)) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.started {
return errAlreadyStarted
}
registration(r.server)
return nil
}
func (r *Router) Start(ctx context.Context) error {
if ctx == nil {
return errNilContext
}
r.mu.Lock()
if r.started {
r.mu.Unlock()
return errAlreadyStarted
}
r.started = true
r.mu.Unlock()
if r.healthSrv != nil {
r.healthSrv.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)
}
go func() {
<-ctx.Done()
r.logger.Info("Context cancelled, stopping gRPC server")
r.server.GracefulStop()
}()
go func() {
err := r.server.Serve(r.listener)
if err != nil && !errors.Is(err, grpc.ErrServerStopped) {
select {
case r.serveErr <- err:
default:
r.logger.Error("Failed to report gRPC serve error", zap.Error(err))
}
}
close(r.serveErr)
}()
r.logger.Info("gRPC server started", zap.String("network", r.listener.Addr().Network()), zap.String("address", r.listener.Addr().String()))
return nil
}
func (r *Router) Finish(ctx context.Context) error {
if ctx == nil {
return errNilContext
}
r.mu.RLock()
started := r.started
r.mu.RUnlock()
if !started {
return nil
}
if r.healthSrv != nil {
r.healthSrv.SetServingStatus("", healthpb.HealthCheckResponse_NOT_SERVING)
}
done := make(chan struct{})
go func() {
r.server.GracefulStop()
close(done)
}()
select {
case <-done:
case <-ctx.Done():
r.logger.Warn("Graceful stop timed out, forcing stop", zap.Error(ctx.Err()))
r.server.Stop()
return ctx.Err()
}
if err, ok := <-r.serveErr; ok {
return err
}
return nil
}
func (r *Router) Addr() net.Addr {
return r.listener.Addr()
}
func (r *Router) Done() <-chan error {
return r.serveErr
}
func configureTLS(cfg *TLSConfig) (credentials.TransportCredentials, error) {
if cfg == nil {
return nil, nil
}
if cfg.CertFile == "" || cfg.KeyFile == "" {
return nil, errTLSMissingCertAndKey
}
certificate, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
if err != nil {
return nil, newRouterErrorWithCause(errMsgLoadServerCertificate, err)
}
tlsCfg := &tls.Config{
Certificates: []tls.Certificate{certificate},
MinVersion: tls.VersionTLS12,
}
if cfg.CAFile != "" {
caPem, err := os.ReadFile(cfg.CAFile)
if err != nil {
return nil, newRouterErrorWithCause(errMsgReadCAFile, err)
}
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(caPem); !ok {
return nil, errAppendCACertificates
}
tlsCfg.ClientCAs = certPool
if cfg.RequireClientCert {
tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert
}
} else if cfg.RequireClientCert {
return nil, errClientCertRequiresCAFile
}
return credentials.NewTLS(tlsCfg), nil
}

View File

@@ -0,0 +1,150 @@
package grpcimp
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
)
const bufconnSize = 1024 * 1024
func newBufferedListener(t *testing.T) *bufconn.Listener {
t.Helper()
listener := bufconn.Listen(bufconnSize)
t.Cleanup(func() {
listener.Close()
})
return listener
}
func newTestRouter(t *testing.T, cfg *Config) *Router {
t.Helper()
logger := zap.NewNop()
if cfg == nil {
cfg = &Config{}
}
router, err := NewRouter(logger, cfg, &Options{Listener: newBufferedListener(t)})
require.NoError(t, err)
return router
}
func TestRouterStartAndFinish(t *testing.T) {
router := newTestRouter(t, &Config{})
doneCh := router.Done()
require.NotNil(t, doneCh)
require.NoError(t, router.Register(func(grpc.ServiceRegistrar) {}))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
require.NoError(t, router.Start(ctx))
addr := router.Addr()
require.NotNil(t, addr)
require.NotEmpty(t, addr.String())
finishCtx, finishCancel := context.WithTimeout(context.Background(), time.Second)
defer finishCancel()
require.NoError(t, router.Finish(finishCtx))
select {
case err, ok := <-doneCh:
if ok {
require.NoError(t, err)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for done channel")
}
}
func TestRouterRejectsRegistrationAfterStart(t *testing.T) {
router := newTestRouter(t, &Config{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
require.NoError(t, router.Start(ctx))
doneCh := router.Done()
err := router.Register(func(grpc.ServiceRegistrar) {})
require.ErrorIs(t, err, errAlreadyStarted)
finishCtx, finishCancel := context.WithTimeout(context.Background(), time.Second)
defer finishCancel()
require.NoError(t, router.Finish(finishCtx))
select {
case <-doneCh:
case <-time.After(time.Second):
t.Fatal("timed out waiting for done channel")
}
}
func TestRouterStartOnlyOnce(t *testing.T) {
router := newTestRouter(t, &Config{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
require.NoError(t, router.Start(ctx))
require.ErrorIs(t, router.Start(ctx), errAlreadyStarted)
doneCh := router.Done()
finishCtx, finishCancel := context.WithTimeout(context.Background(), time.Second)
defer finishCancel()
require.NoError(t, router.Finish(finishCtx))
select {
case <-doneCh:
case <-time.After(time.Second):
t.Fatal("timed out waiting for done channel")
}
}
func TestRouterUsesProvidedListener(t *testing.T) {
logger := zap.NewNop()
listener := newBufferedListener(t)
cfg := &Config{}
router, err := NewRouter(logger, cfg, &Options{Listener: listener})
require.NoError(t, err)
actualListener, ok := router.listener.(*bufconn.Listener)
require.True(t, ok)
require.Same(t, listener, actualListener)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
require.NoError(t, router.Start(ctx))
doneCh := router.Done()
finishCtx, finishCancel := context.WithTimeout(context.Background(), time.Second)
defer finishCancel()
require.NoError(t, router.Finish(finishCtx))
select {
case <-doneCh:
case <-time.After(time.Second):
t.Fatal("timed out waiting for done channel")
}
}

View File

@@ -0,0 +1,45 @@
package healthimp
import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/tech/sendico/pkg/api/routers/health"
"github.com/tech/sendico/pkg/mlogger"
"go.uber.org/zap"
)
type Router struct {
logger mlogger.Logger
status *Status
}
func (hr *Router) SetStatus(status health.ServiceStatus) {
hr.status.setStatus(status)
hr.logger.Info("New status set", zap.String("status", string(status)))
}
func (hr *Router) Finish() {
hr.status.Finish()
hr.logger.Debug("Stopped")
}
func (hr *Router) handle(w http.ResponseWriter, r *http.Request) {
hr.status.healthHandler()(w, r)
}
func NewRouter(logger mlogger.Logger, router chi.Router, endpoint string) *Router {
hr := Router{
logger: logger.Named("health_check"),
}
hr.status = StatusHandler(hr.logger)
logger.Debug("Installing healthcheck middleware...")
router.Group(func(r chi.Router) {
ep := endpoint + "/health"
r.Get(ep, hr.handle)
logger.Info("Health handler installed", zap.String("endpoint", ep))
})
return &hr
}

View File

@@ -0,0 +1,38 @@
package healthimp
import (
"net/http"
"github.com/tech/sendico/pkg/api/http/response"
"github.com/tech/sendico/pkg/api/routers/health"
"github.com/tech/sendico/pkg/mlogger"
)
type Status struct {
logger mlogger.Logger
status health.ServiceStatus
}
func (hs *Status) healthHandler() http.HandlerFunc {
return response.Ok(hs.logger, struct {
Status health.ServiceStatus `json:"status"`
}{
hs.status,
})
}
func (hr *Status) Finish() {
hr.logger.Info("Finished")
}
func (hs *Status) setStatus(status health.ServiceStatus) {
hs.status = status
}
func StatusHandler(logger mlogger.Logger) *Status {
hs := Status{
status: health.SSCreated,
logger: logger.Named("status"),
}
return &hs
}

View File

@@ -0,0 +1,66 @@
package messagingimp
import (
"context"
"github.com/tech/sendico/pkg/messaging"
mb "github.com/tech/sendico/pkg/messaging/broker"
me "github.com/tech/sendico/pkg/messaging/envelope"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.uber.org/zap"
)
type ChannelConsumer struct {
logger mlogger.Logger
broker mb.Broker
event model.NotificationEvent
ch <-chan me.Envelope
ctx context.Context
cancel context.CancelFunc
}
func (c *ChannelConsumer) ConsumeMessages(handleFunc messaging.MessageHandlerT) error {
c.logger.Info("Message consumer is ready")
for {
select {
case msg := <-c.ch:
if msg == nil { // nil message indicates the channel was closed
c.logger.Info("Consumer shutting down")
return nil
}
if err := handleFunc(c.ctx, msg); err != nil {
c.logger.Warn("Error processing message", zap.Error(err))
}
case <-c.ctx.Done():
c.logger.Info("Context done, shutting down")
return c.ctx.Err()
}
}
}
func (c *ChannelConsumer) Close() {
c.logger.Info("Shutting down...")
c.cancel()
if err := c.broker.Unsubscribe(c.event, c.ch); err != nil {
c.logger.Warn("Failed to unsubscribe", zap.Error(err))
}
}
func NewConsumer(logger mlogger.Logger, broker mb.Broker, event model.NotificationEvent) (*ChannelConsumer, error) {
ctx, cancel := context.WithCancel(context.Background())
ch, err := broker.Subscribe(event)
if err != nil {
logger.Warn("Failed to create channel consumer", zap.Error(err), zap.String("topic", event.ToString()))
cancel() // Ensure resources are released properly
return nil, err
}
return &ChannelConsumer{
logger: logger.Named("consumer").Named(event.ToString()),
broker: broker,
event: event,
ch: ch,
ctx: ctx,
cancel: cancel,
}, nil
}

View File

@@ -0,0 +1,67 @@
package messagingimp
import (
"context"
"errors"
"github.com/tech/sendico/pkg/messaging"
mb "github.com/tech/sendico/pkg/messaging/broker"
notifications "github.com/tech/sendico/pkg/messaging/notifications/processor"
mip "github.com/tech/sendico/pkg/messaging/producer"
"github.com/tech/sendico/pkg/mlogger"
"go.uber.org/zap"
)
type MessagingRouter struct {
logger mlogger.Logger
messaging mb.Broker
consumers []messaging.Consumer
producer messaging.Producer
}
func (mr *MessagingRouter) consumeMessages(c messaging.Consumer, processor notifications.EnvelopeProcessor) {
if err := c.ConsumeMessages(processor.Process); err != nil {
if !errors.Is(err, context.Canceled) {
mr.logger.Warn("Error consuming messages", zap.Error(err), zap.String("event", processor.GetSubject().ToString()))
} else {
mr.logger.Info("Finishing as context has been cancelled", zap.String("event", processor.GetSubject().ToString()))
}
}
}
func (mr *MessagingRouter) Consumer(processor notifications.EnvelopeProcessor) error {
c, err := NewConsumer(mr.logger, mr.messaging, processor.GetSubject())
if err != nil {
mr.logger.Warn("Failed to register message consumer", zap.Error(err), zap.String("event", processor.GetSubject().ToString()))
return err
}
mr.consumers = append(mr.consumers, c)
go mr.consumeMessages(c, processor)
return nil
}
func (mr *MessagingRouter) Finish() {
mr.logger.Info("Closing consumer channels")
for _, consumer := range mr.consumers {
consumer.Close()
}
}
func (mr *MessagingRouter) Producer() messaging.Producer {
return mr.producer
}
func NewMessagingRouterImp(logger mlogger.Logger, config *messaging.Config) (*MessagingRouter, error) {
l := logger.Named("messaging")
broker, err := messaging.CreateMessagingBroker(l, config)
if err != nil {
l.Error("Failed to create messaging broker", zap.Error(err), zap.String("broker", string(config.Driver)))
return nil, err
}
return &MessagingRouter{
logger: l,
messaging: broker,
producer: mip.NewProducer(logger, broker),
consumers: make([]messaging.Consumer, 0),
}, nil
}

View File

@@ -0,0 +1,16 @@
package routers
import (
"github.com/tech/sendico/pkg/api/routers/internal/messagingimp"
"github.com/tech/sendico/pkg/messaging"
"github.com/tech/sendico/pkg/mlogger"
)
type Messaging interface {
messaging.Register
Finish()
}
func NewMessagingRouter(logger mlogger.Logger, config *messaging.Config) (Messaging, error) {
return messagingimp.NewMessagingRouterImp(logger, config)
}

202
api/pkg/auth/USAGE.md Normal file
View File

@@ -0,0 +1,202 @@
# Auth.Indexable Usage Guide
## Secure Reordering with Permission Checking
The `auth.Indexable` implementation adds **permission checking** to the generic reordering functionality using `EnforceBatch`.
- **Core Implementation**: `api/pkg/auth/indexable.go` - generic implementation with permission checking
- **Project Factory**: `api/pkg/auth/project_indexable.go` - convenient factory for projects
- **Key Feature**: Uses `EnforceBatch` to check permissions for all affected objects
## How It Works
### Permission Checking Flow
1. **Get current object** to find its index
2. **Determine affected objects** that will be shifted during reordering
3. **Check permissions** using `EnforceBatch` for all affected objects + target object
4. **Verify all permissions** - if any object lacks update permission, return error
5. **Proceed with reordering** only if all permissions are granted
### Key Differences from Basic Indexable
- **Additional parameter**: `accountRef` for permission checking
- **Permission validation**: All affected objects must have `ActionUpdate` permission
- **Security**: Prevents unauthorized reordering that could affect other users' data
## Usage
### 1. Using the Generic Auth.Indexable Implementation
```go
import "github.com/tech/sendico/pkg/auth"
// For any type that embeds model.Indexable, define helper functions:
createEmpty := func() *YourType {
return &YourType{}
}
getIndexable := func(obj *YourType) *model.Indexable {
return &obj.Indexable
}
// Create auth.IndexableDB with enforcer
indexableDB := auth.NewIndexableDB(repo, logger, enforcer, createEmpty, getIndexable)
// Use with account reference for permission checking
err := indexableDB.Reorder(ctx, accountRef, objectID, newIndex, filter)
```
### 2. Using the Project Factory (Recommended for Projects)
```go
import "github.com/tech/sendico/pkg/auth"
// Create auth.ProjectIndexableDB (automatically applies org filter)
projectDB := auth.NewProjectIndexableDB(repo, logger, enforcer, organizationRef)
// Reorder project with permission checking
err := projectDB.Reorder(ctx, accountRef, projectID, newIndex, repository.Query())
// Reorder with additional filters (combined with org filter)
additionalFilter := repository.Query().Comparison(repository.Field("state"), builder.Eq, "active")
err := projectDB.Reorder(ctx, accountRef, projectID, newIndex, additionalFilter)
```
## Examples for Different Types
### Project Auth.IndexableDB
```go
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
projectDB := auth.NewIndexableDB(repo, logger, enforcer, createEmpty, getIndexable)
orgFilter := repository.OrgFilter(organizationRef)
projectDB.Reorder(ctx, accountRef, projectID, 2, orgFilter)
```
### Status Auth.IndexableDB
```go
createEmpty := func() *model.Status {
return &model.Status{}
}
getIndexable := func(s *model.Status) *model.Indexable {
return &s.Indexable
}
statusDB := auth.NewIndexableDB(repo, logger, enforcer, createEmpty, getIndexable)
projectFilter := repository.Query().Comparison(repository.Field("projectRef"), builder.Eq, projectRef)
statusDB.Reorder(ctx, accountRef, statusID, 1, projectFilter)
```
### Task Auth.IndexableDB
```go
createEmpty := func() *model.Task {
return &model.Task{}
}
getIndexable := func(t *model.Task) *model.Indexable {
return &t.Indexable
}
taskDB := auth.NewIndexableDB(repo, logger, enforcer, createEmpty, getIndexable)
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
taskDB.Reorder(ctx, accountRef, taskID, 3, statusFilter)
```
## Permission Checking Details
### What Gets Checked
When reordering an object from index `A` to index `B`:
1. **Target object** - the object being moved
2. **Affected objects** - all objects whose indices will be shifted:
- Moving down: objects between `A+1` and `B` (shifted up by -1)
- Moving up: objects between `B` and `A-1` (shifted down by +1)
### Permission Requirements
- **Action**: `model.ActionUpdate`
- **Scope**: All affected objects must be `PermissionBoundStorable`
- **Result**: If any object lacks permission, the entire operation fails
### Error Handling
```go
// Permission denied error
if err != nil {
if strings.Contains(err.Error(), "accessDenied") {
// Handle permission denied
}
}
```
## Security Benefits
### ✅ **Comprehensive Permission Checking**
- Checks permissions for **all affected objects**, not just the target
- Prevents unauthorized reordering that could affect other users' data
- Uses efficient `EnforceBatch` for bulk permission checking
### ✅ **Type Safety**
- Generic implementation works with any `Indexable` struct
- Compile-time type checking
- No runtime type assertions
### ✅ **Flexible Filtering**
- Single `builder.Query` parameter for scoping
- Can combine organization filters with additional criteria
- Project factory automatically applies organization filtering
### ✅ **Clean Architecture**
- Separates permission logic from reordering logic
- Easy to test with mock enforcers
- Follows existing auth patterns
## Testing
### Mock Enforcer Setup
```go
mockEnforcer := &MockEnforcer{}
// Grant all permissions
permissions := map[primitive.ObjectID]bool{
objectID1: true,
objectID2: true,
}
mockEnforcer.On("EnforceBatch", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(permissions, nil)
// Deny specific permission
permissions[objectID2] = false
mockEnforcer.On("EnforceBatch", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(permissions, nil)
```
### Test Scenarios
-**Permission granted** - reordering succeeds
-**Permission denied** - reordering fails with access denied error
- 🔄 **No change needed** - early return, minimal permission checking
- 🏢 **Organization filtering** - automatic org scope for projects
## Comparison: Basic vs Auth.Indexable
| Feature | Basic Indexable | Auth.Indexable |
|---------|----------------|----------------|
| Permission checking | ❌ No | ✅ Yes |
| Account parameter | ❌ No | ✅ Required |
| Security | ❌ None | ✅ Comprehensive |
| Performance | ✅ Fast | ⚠️ Slower (permission checks) |
| Use case | Internal operations | User-facing operations |
## Best Practices
1. **Use Auth.Indexable** for user-facing reordering operations
2. **Use Basic Indexable** for internal/system operations
3. **Always provide account reference** for proper permission checking
4. **Test permission scenarios** thoroughly with mock enforcers
5. **Handle permission errors** gracefully in user interfaces
That's it! **Secure, type-safe reordering** with comprehensive permission checking using `EnforceBatch`.

View File

@@ -0,0 +1,3 @@
package anyobject
const ID = "*"

View File

@@ -0,0 +1,35 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ArchivableDB implements archive operations with permission checking
type ArchivableDB[T model.PermissionBoundStorable] interface {
// SetArchived sets the archived status of an entity with permission checking
SetArchived(ctx context.Context, accountRef, objectRef primitive.ObjectID, archived bool) error
// IsArchived checks if an entity is archived with permission checking
IsArchived(ctx context.Context, accountRef, objectRef primitive.ObjectID) (bool, error)
// Archive archives an entity with permission checking (sets archived to true)
Archive(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
// Unarchive unarchives an entity with permission checking (sets archived to false)
Unarchive(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
}
// NewArchivableDB creates a new auth.ArchivableDB instance
func NewArchivableDB[T model.PermissionBoundStorable](
dbImp *template.DBImp[T],
logger mlogger.Logger,
enforcer Enforcer,
createEmpty func() T,
getArchivable func(T) model.Archivable,
) ArchivableDB[T] {
return newArchivableDBImp(dbImp, logger, enforcer, createEmpty, getArchivable)
}

View File

@@ -0,0 +1,107 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// ArchivableDB implements archive operations with permission checking
type ArchivableDBImp[T model.PermissionBoundStorable] struct {
dbImp *template.DBImp[T]
logger mlogger.Logger
enforcer Enforcer
createEmpty func() T
getArchivable func(T) model.Archivable
}
// NewArchivableDB creates a new auth.ArchivableDB instance
func newArchivableDBImp[T model.PermissionBoundStorable](
dbImp *template.DBImp[T],
logger mlogger.Logger,
enforcer Enforcer,
createEmpty func() T,
getArchivable func(T) model.Archivable,
) ArchivableDB[T] {
return &ArchivableDBImp[T]{
dbImp: dbImp,
logger: logger.Named("archivable"),
enforcer: enforcer,
createEmpty: createEmpty,
getArchivable: getArchivable,
}
}
// SetArchived sets the archived status of an entity with permission checking
func (db *ArchivableDBImp[T]) SetArchived(ctx context.Context, accountRef, objectRef primitive.ObjectID, archived bool) error {
// Check permissions using enforceObject helper
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
db.logger.Warn("Failed to enforce object permission", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
return err
}
// Get the object to check current archived status
obj := db.createEmpty()
if err := db.dbImp.Get(ctx, objectRef, obj); err != nil {
db.logger.Warn("Failed to get object for setting archived status", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
return err
}
// Extract archivable from the object
archivable := db.getArchivable(obj)
currentArchived := archivable.IsArchived()
if currentArchived == archived {
db.logger.Debug("No change needed - same archived status", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
return nil // No change needed
}
// Set the archived status
patch := repository.Patch().Set(repository.IsArchivedField(), archived)
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
db.logger.Warn("Failed to set archived status on object", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
return err
}
db.logger.Debug("Successfully set archived status on object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
return nil
}
// IsArchived checks if an entity is archived with permission checking
func (db *ArchivableDBImp[T]) IsArchived(ctx context.Context, accountRef, objectRef primitive.ObjectID) (bool, error) {
// // Check permissions using single Enforce
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionRead, accountRef, objectRef); err != nil {
db.logger.Debug("Permission denied for checking archived status", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(model.ActionRead)))
return false, merrors.AccessDenied("read", "object", objectRef)
}
obj := db.createEmpty()
if err := db.dbImp.Get(ctx, objectRef, obj); err != nil {
db.logger.Warn("Failed to get object for checking archived status", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return false, err
}
archivable := db.getArchivable(obj)
return archivable.IsArchived(), nil
}
// Archive archives an entity with permission checking (sets archived to true)
func (db *ArchivableDBImp[T]) Archive(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
return db.SetArchived(ctx, accountRef, objectRef, true)
}
// Unarchive unarchives an entity with permission checking (sets archived to false)
func (db *ArchivableDBImp[T]) Unarchive(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
return db.SetArchived(ctx, accountRef, objectRef, false)
}

12
api/pkg/auth/config.go Normal file
View File

@@ -0,0 +1,12 @@
package auth
import "github.com/tech/sendico/pkg/model"
type EnforcerType string
const (
Casbin EnforcerType = "casbin"
Native EnforcerType = "native"
)
type Config = model.DriverConfig[EnforcerType]

View File

@@ -0,0 +1,8 @@
package customizable
import (
"github.com/tech/sendico/pkg/model"
)
type DB[T model.PermissionBoundStorable] interface {
}

View File

@@ -0,0 +1,8 @@
package customizable
import (
"github.com/tech/sendico/pkg/model"
)
type Manager[T model.PermissionBoundStorable] interface {
}

38
api/pkg/auth/db.go Normal file
View File

@@ -0,0 +1,38 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
type ProtectedDB[T model.PermissionBoundStorable] interface {
Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, object T) error
InsertMany(ctx context.Context, accountRef, organizationRef primitive.ObjectID, objects []T) error
Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error
Update(ctx context.Context, accountRef primitive.ObjectID, object T) error
Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
DeleteCascadeAuth(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error
PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error)
Unprotected() template.DB[T]
ListIDs(ctx context.Context, action model.Action, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error)
}
func CreateDB[T model.PermissionBoundStorable](
ctx context.Context,
l mlogger.Logger,
pdb policy.DB,
enforcer Enforcer,
collection mservice.Type,
db *mongo.Database,
) (ProtectedDB[T], error) {
return CreateDBImp[T](ctx, l, pdb, enforcer, collection, db)
}

51
api/pkg/auth/dbab.go Normal file
View File

@@ -0,0 +1,51 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type AccountBoundDB[T model.AccountBoundStorable] interface {
Create(ctx context.Context, accountRef primitive.ObjectID, object T) error
Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error
Update(ctx context.Context, accountRef primitive.ObjectID, object T) error
Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error
Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
DeleteMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query) error
FindOne(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, result T) error
ListIDs(ctx context.Context, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error)
ListAccountBound(ctx context.Context, accountRef, organizationRef primitive.ObjectID, query builder.Query) ([]model.AccountBoundStorable, error)
}
func CreateAccountBound[T model.AccountBoundStorable](
ctx context.Context,
logger mlogger.Logger,
pdb policy.DB,
enforcer Enforcer,
collection mservice.Type,
db *mongo.Database,
) (AccountBoundDB[T], error) {
logger = logger.Named("account_bound")
var policy model.PolicyDescription
if err := pdb.GetBuiltInPolicy(ctx, mservice.Organizations, &policy); err != nil {
logger.Warn("Failed to fetch organization policy description", zap.Error(err))
return nil, err
}
res := &AccountBoundDBImp[T]{
Logger: logger,
DBImp: template.Create[T](logger, collection, db),
Enforcer: enforcer,
PermissionRef: policy.ID,
Collection: collection,
}
return res, nil
}

319
api/pkg/auth/dbimp.go Normal file
View File

@@ -0,0 +1,319 @@
package auth
import (
"context"
"errors"
"fmt"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type ProtectedDBImp[T model.PermissionBoundStorable] struct {
DBImp *template.DBImp[T]
Enforcer Enforcer
PermissionRef primitive.ObjectID
Collection mservice.Type
}
func (db *ProtectedDBImp[T]) enforce(ctx context.Context, action model.Action, object model.PermissionBoundStorable, accountRef, objectRef primitive.ObjectID) error {
res, err := db.Enforcer.Enforce(ctx, object.GetPermissionRef(), accountRef, object.GetOrganizationRef(), objectRef, action)
if err != nil {
db.DBImp.Logger.Warn("Failed to enforce permission",
zap.Error(err), mzap.ObjRef("permission_ref", object.GetPermissionRef()),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
return err
}
if !res {
db.DBImp.Logger.Debug("Access denied", mzap.ObjRef("permission_ref", object.GetPermissionRef()),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
return merrors.AccessDenied(db.Collection, string(action), objectRef)
}
return nil
}
func (db *ProtectedDBImp[T]) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, object T) error {
db.DBImp.Logger.Debug("Attempting to create object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)))
if object.GetPermissionRef() == primitive.NilObjectID {
object.SetPermissionRef(db.PermissionRef)
}
object.SetOrganizationRef(organizationRef)
if err := db.enforce(ctx, model.ActionCreate, object, accountRef, primitive.NilObjectID); err != nil {
return err
}
if err := db.DBImp.Create(ctx, object); err != nil {
db.DBImp.Logger.Warn("Failed to create object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)))
return err
}
db.DBImp.Logger.Debug("Successfully created object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)))
return nil
}
func (db *ProtectedDBImp[T]) InsertMany(ctx context.Context, accountRef, organizationRef primitive.ObjectID, objects []T) error {
if len(objects) == 0 {
return nil
}
db.DBImp.Logger.Debug("Attempting to insert many objects", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)),
zap.Int("count", len(objects)))
// Set permission and organization refs for all objects and enforce permissions
for _, object := range objects {
if object.GetPermissionRef() == primitive.NilObjectID {
object.SetPermissionRef(db.PermissionRef)
}
object.SetOrganizationRef(organizationRef)
if err := db.enforce(ctx, model.ActionCreate, object, accountRef, primitive.NilObjectID); err != nil {
return err
}
}
if err := db.DBImp.InsertMany(ctx, objects); err != nil {
db.DBImp.Logger.Warn("Failed to insert many objects", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)),
zap.Int("count", len(objects)))
return err
}
db.DBImp.Logger.Debug("Successfully inserted many objects", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)),
zap.Int("count", len(objects)))
return nil
}
func (db *ProtectedDBImp[T]) enforceObject(ctx context.Context, action model.Action, accountRef, objectRef primitive.ObjectID) error {
l, err := db.ListIDs(ctx, action, accountRef, repository.IDFilter(objectRef))
if err != nil {
db.DBImp.Logger.Warn("Error occured while checking access rights", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
return err
}
if len(l) == 0 {
db.DBImp.Logger.Debug("Access denied", zap.String("action", string(action)), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return merrors.AccessDenied(db.Collection, string(action), objectRef)
}
return nil
}
func (db *ProtectedDBImp[T]) Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error {
db.DBImp.Logger.Debug("Attempting to get object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
if err := db.enforceObject(ctx, model.ActionRead, accountRef, objectRef); err != nil {
return err
}
if err := db.DBImp.Get(ctx, objectRef, result); err != nil {
db.DBImp.Logger.Warn("Failed to get object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection)))
return err
}
db.DBImp.Logger.Debug("Successfully retrieved object",
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", result.GetOrganizationRef()),
mzap.StorableRef(result), mzap.ObjRef("permission_ref", result.GetPermissionRef()))
return nil
}
func (db *ProtectedDBImp[T]) Update(ctx context.Context, accountRef primitive.ObjectID, object T) error {
db.DBImp.Logger.Debug("Attempting to update object", mzap.ObjRef("account_ref", accountRef), mzap.StorableRef(object))
if err := db.enforceObject(ctx, model.ActionUpdate, accountRef, *object.GetID()); err != nil {
return err
}
if err := db.DBImp.Update(ctx, object); err != nil {
db.DBImp.Logger.Warn("Failed to update object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.StorableRef(object))
return err
}
db.DBImp.Logger.Debug("Successfully updated object",
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()),
mzap.StorableRef(object), mzap.ObjRef("permission_ref", object.GetPermissionRef()))
return nil
}
func (db *ProtectedDBImp[T]) Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
db.DBImp.Logger.Debug("Attempting to delete object",
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
if err := db.enforceObject(ctx, model.ActionDelete, accountRef, objectRef); err != nil {
return err
}
if err := db.DBImp.Delete(ctx, objectRef); err != nil {
db.DBImp.Logger.Warn("Failed to delete object", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return err
}
db.DBImp.Logger.Debug("Successfully deleted object",
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return nil
}
func (db *ProtectedDBImp[T]) ListIDs(
ctx context.Context,
action model.Action,
accountRef primitive.ObjectID,
query builder.Query,
) ([]primitive.ObjectID, error) {
db.DBImp.Logger.Debug("Attempting to list object IDs",
mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)), zap.Any("filter", query.BuildQuery()))
// 1. Fetch all candidate IDs from the underlying DB
allIDs, err := db.DBImp.ListPermissionBound(ctx, query)
if err != nil {
db.DBImp.Logger.Warn("Failed to list object IDs", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
zap.String("collection", string(db.Collection)), zap.String("action", string(action)))
return nil, err
}
if len(allIDs) == 0 {
db.DBImp.Logger.Debug("No objects found matching filter", mzap.ObjRef("account_ref", accountRef),
zap.String("collection", string(db.Collection)), zap.Any("filter", query.BuildQuery()))
return []primitive.ObjectID{}, merrors.NoData(fmt.Sprintf("no %s found", db.Collection))
}
// 2. Check read permission for each ID
var allowedIDs []primitive.ObjectID
for _, desc := range allIDs {
enforceErr := db.enforce(ctx, action, desc, accountRef, *desc.GetID())
if enforceErr == nil {
allowedIDs = append(allowedIDs, *desc.GetID())
} else if !errors.Is(err, merrors.ErrAccessDenied) {
// If the error is something other than AccessDenied, we want to fail
db.DBImp.Logger.Warn("Error while enforcing read permission", zap.Error(enforceErr),
mzap.ObjRef("permission_ref", desc.GetPermissionRef()), zap.String("action", string(action)),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", desc.GetOrganizationRef()),
mzap.ObjRef("object_ref", *desc.GetID()), zap.String("collection", string(db.Collection)),
)
return nil, enforceErr
}
// If AccessDenied, we simply skip that ID.
}
db.DBImp.Logger.Debug("Successfully enforced read permission on IDs", zap.Int("fetched_count", len(allIDs)),
zap.Int("allowed_count", len(allowedIDs)), mzap.ObjRef("account_ref", accountRef),
zap.String("collection", string(db.Collection)), zap.String("action", string(action)))
// 3. Return only the IDs that passed permission checks
return allowedIDs, nil
}
func (db *ProtectedDBImp[T]) Unprotected() template.DB[T] {
return db.DBImp
}
func (db *ProtectedDBImp[T]) DeleteCascadeAuth(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
if err := db.enforceObject(ctx, model.ActionDelete, accountRef, objectRef); err != nil {
return err
}
if err := db.DBImp.DeleteCascade(ctx, objectRef); err != nil {
db.DBImp.Logger.Warn("Failed to delete dependent object", zap.Error(err))
return err
}
return nil
}
func CreateDBImp[T model.PermissionBoundStorable](
ctx context.Context,
l mlogger.Logger,
pdb policy.DB,
enforcer Enforcer,
collection mservice.Type,
db *mongo.Database,
) (*ProtectedDBImp[T], error) {
logger := l.Named("protected")
var policy model.PolicyDescription
if err := pdb.GetBuiltInPolicy(ctx, collection, &policy); err != nil {
logger.Warn("Failed to fetch policy description", zap.Error(err), zap.String("resource_type", string(collection)))
return nil, err
}
p := &ProtectedDBImp[T]{
DBImp: template.Create[T](logger, collection, db),
PermissionRef: policy.ID,
Collection: collection,
Enforcer: enforcer,
}
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: storable.OrganizationRefField, Sort: ri.Asc}},
}); err != nil {
logger.Warn("Failed to create index", zap.Error(err), zap.String("resource_type", string(collection)))
return nil, err
}
return p, nil
}
func (db *ProtectedDBImp[T]) Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error {
db.DBImp.Logger.Debug("Attempting to patch object",
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
if err := db.enforceObject(ctx, model.ActionUpdate, accountRef, objectRef); err != nil {
return err
}
if err := db.DBImp.Repository.Patch(ctx, objectRef, patch); err != nil {
db.DBImp.Logger.Warn("Failed to patch object", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return err
}
db.DBImp.Logger.Debug("Successfully patched object",
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return nil
}
func (db *ProtectedDBImp[T]) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) {
db.DBImp.Logger.Debug("Attempting to patch many objects",
mzap.ObjRef("account_ref", accountRef), zap.Any("filter", query.BuildQuery()))
ids, err := db.ListIDs(ctx, model.ActionUpdate, accountRef, query)
if err != nil {
return 0, err
}
if len(ids) == 0 {
return 0, nil
}
values := make([]any, len(ids))
for i, id := range ids {
values[i] = id
}
idFilter := repository.Query().In(repository.IDField(), values...)
finalQuery := query.And(idFilter)
modified, err := db.DBImp.Repository.PatchMany(ctx, finalQuery, patch)
if err != nil {
db.DBImp.Logger.Warn("Failed to patch many objects", zap.Error(err),
mzap.ObjRef("account_ref", accountRef))
return 0, err
}
db.DBImp.Logger.Debug("Successfully patched many objects",
mzap.ObjRef("account_ref", accountRef), zap.Int("modified_count", modified))
return modified, nil
}

420
api/pkg/auth/dbimpab.go Normal file
View File

@@ -0,0 +1,420 @@
package auth
import (
"context"
"errors"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type AccountBoundDBImp[T model.AccountBoundStorable] struct {
Logger mlogger.Logger
DBImp *template.DBImp[T]
Enforcer Enforcer
PermissionRef primitive.ObjectID
Collection mservice.Type
}
func (db *AccountBoundDBImp[T]) enforce(ctx context.Context, action model.Action, object model.AccountBoundStorable, accountRef primitive.ObjectID) error {
// FIRST: Check if the object's AccountRef equals the calling accountRef - if so, ALLOW
objectAccountRef := object.GetAccountRef()
if objectAccountRef != nil && *objectAccountRef == accountRef {
db.Logger.Debug("Access granted - object belongs to calling account",
mzap.ObjRef("object_account_ref", *objectAccountRef),
mzap.ObjRef("calling_account_ref", accountRef),
zap.String("action", string(action)))
return nil
}
// SECOND: If not owned by calling account, check organization-level permissions
organizationRef := object.GetOrganizationRef()
res, err := db.Enforcer.Enforce(ctx, db.PermissionRef, accountRef, organizationRef, organizationRef, action)
if err != nil {
db.Logger.Warn("Failed to enforce permission",
zap.Error(err), mzap.ObjRef("permission_ref", db.PermissionRef),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
zap.String("action", string(action)))
return err
}
if !res {
db.Logger.Debug("Access denied", mzap.ObjRef("permission_ref", db.PermissionRef),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
zap.String("action", string(action)))
return merrors.AccessDenied(db.Collection, string(action), primitive.NilObjectID)
}
return nil
}
func (db *AccountBoundDBImp[T]) enforceInterface(ctx context.Context, action model.Action, object model.AccountBoundStorable, accountRef primitive.ObjectID) error {
// FIRST: Check if the object's AccountRef equals the calling accountRef - if so, ALLOW
objectAccountRef := object.GetAccountRef()
if objectAccountRef != nil && *objectAccountRef == accountRef {
db.Logger.Debug("Access granted - object belongs to calling account",
mzap.ObjRef("object_account_ref", *objectAccountRef),
mzap.ObjRef("calling_account_ref", accountRef),
zap.String("action", string(action)))
return nil
}
// SECOND: If not owned by calling account, check organization-level permissions
organizationRef := object.GetOrganizationRef()
res, err := db.Enforcer.Enforce(ctx, db.PermissionRef, accountRef, organizationRef, organizationRef, action)
if err != nil {
db.Logger.Warn("Failed to enforce permission",
zap.Error(err), mzap.ObjRef("permission_ref", db.PermissionRef),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
zap.String("action", string(action)))
return err
}
if !res {
db.Logger.Debug("Access denied", mzap.ObjRef("permission_ref", db.PermissionRef),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
zap.String("action", string(action)))
return merrors.AccessDenied(db.Collection, string(action), primitive.NilObjectID)
}
return nil
}
func (db *AccountBoundDBImp[T]) Create(ctx context.Context, accountRef primitive.ObjectID, object T) error {
orgRef := object.GetOrganizationRef()
db.Logger.Debug("Attempting to create object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", orgRef), zap.String("collection", string(db.Collection)))
// Check organization update permission for create operations
if err := db.enforce(ctx, model.ActionUpdate, object, accountRef); err != nil {
return err
}
if err := db.DBImp.Create(ctx, object); err != nil {
db.Logger.Warn("Failed to create object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", orgRef), zap.String("collection", string(db.Collection)))
return err
}
db.Logger.Debug("Successfully created object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", orgRef), zap.String("collection", string(db.Collection)))
return nil
}
func (db *AccountBoundDBImp[T]) Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error {
db.Logger.Debug("Attempting to get object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
// First get the object to check its organization
if err := db.DBImp.Get(ctx, objectRef, result); err != nil {
db.Logger.Warn("Failed to get object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection)))
return err
}
// Check organization read permission
if err := db.enforce(ctx, model.ActionRead, result, accountRef); err != nil {
return err
}
db.Logger.Debug("Successfully retrieved object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", result.GetOrganizationRef()), zap.String("collection", string(db.Collection)))
return nil
}
func (db *AccountBoundDBImp[T]) Update(ctx context.Context, accountRef primitive.ObjectID, object T) error {
db.Logger.Debug("Attempting to update object", mzap.ObjRef("account_ref", accountRef), mzap.StorableRef(object))
// Check organization update permission
if err := db.enforce(ctx, model.ActionUpdate, object, accountRef); err != nil {
return err
}
if err := db.DBImp.Update(ctx, object); err != nil {
db.Logger.Warn("Failed to update object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.StorableRef(object))
return err
}
db.Logger.Debug("Successfully updated object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.StorableRef(object))
return nil
}
func (db *AccountBoundDBImp[T]) Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error {
db.Logger.Debug("Attempting to patch object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
// First get the object to check its organization
objs, err := db.DBImp.Repository.ListAccountBound(ctx, repository.IDFilter(objectRef))
if err != nil {
db.Logger.Warn("Failed to get object for permission check when deleting", zap.Error(err), mzap.ObjRef("object_ref", objectRef))
return err
}
if len(objs) == 0 {
db.Logger.Debug("Permission denied for deletion", mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("account_ref", accountRef))
return merrors.AccessDenied(db.Collection, string(model.ActionDelete), objectRef)
}
// Check organization update permission
if err := db.enforce(ctx, model.ActionUpdate, objs[0], accountRef); err != nil {
return err
}
if err := db.DBImp.Patch(ctx, objectRef, patch); err != nil {
db.Logger.Warn("Failed to patch object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection)))
return err
}
db.Logger.Debug("Successfully patched object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return nil
}
func (db *AccountBoundDBImp[T]) Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
db.Logger.Debug("Attempting to delete object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
// First get the object to check its organization
objs, err := db.DBImp.Repository.ListAccountBound(ctx, repository.IDFilter(objectRef))
if err != nil {
db.Logger.Warn("Failed to get object for permission check when deleting", zap.Error(err), mzap.ObjRef("object_ref", objectRef))
return err
}
if len(objs) == 0 {
db.Logger.Debug("Permission denied for deletion", mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("account_ref", accountRef))
return merrors.AccessDenied(db.Collection, string(model.ActionDelete), objectRef)
}
// Check organization update permission for delete operations
if err := db.enforce(ctx, model.ActionUpdate, objs[0], accountRef); err != nil {
return err
}
if err := db.DBImp.Delete(ctx, objectRef); err != nil {
db.Logger.Warn("Failed to delete object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection)))
return err
}
db.Logger.Debug("Successfully deleted object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return nil
}
func (db *AccountBoundDBImp[T]) DeleteMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query) error {
db.Logger.Debug("Attempting to delete many objects", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)))
// Get all candidate objects for batch permission checking
allObjects, err := db.DBImp.Repository.ListPermissionBound(ctx, query)
if err != nil {
db.Logger.Warn("Failed to list objects for delete many", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return err
}
// Use batch enforcement for efficiency
allowedResults, err := db.Enforcer.EnforceBatch(ctx, allObjects, accountRef, model.ActionUpdate)
if err != nil {
db.Logger.Warn("Failed to enforce batch permissions for delete many", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return err
}
// Build query for objects that passed permission check
var allowedIDs []primitive.ObjectID
for _, obj := range allObjects {
if allowedResults[*obj.GetID()] {
allowedIDs = append(allowedIDs, *obj.GetID())
}
}
if len(allowedIDs) == 0 {
db.Logger.Debug("No objects allowed for deletion", mzap.ObjRef("account_ref", accountRef))
return nil
}
// Delete only the allowed objects
allowedQuery := query.And(repository.Query().In(repository.IDField(), allowedIDs))
if err := db.DBImp.DeleteMany(ctx, allowedQuery); err != nil {
db.Logger.Warn("Failed to delete many objects", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return err
}
db.Logger.Debug("Successfully deleted many objects", mzap.ObjRef("account_ref", accountRef), zap.Int("count", len(allowedIDs)))
return nil
}
func (db *AccountBoundDBImp[T]) FindOne(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, result T) error {
db.Logger.Debug("Attempting to find one object", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)))
// For FindOne, we need to check read permission after finding the object
if err := db.DBImp.FindOne(ctx, query, result); err != nil {
db.Logger.Warn("Failed to find one object", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return err
}
// Check organization read permission for the found object
if err := db.enforce(ctx, model.ActionRead, result, accountRef); err != nil {
return err
}
db.Logger.Debug("Successfully found one object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", result.GetOrganizationRef()))
return nil
}
func (db *AccountBoundDBImp[T]) ListIDs(ctx context.Context, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error) {
db.Logger.Debug("Attempting to list object IDs", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)))
// Get all candidate objects for batch permission checking
allObjects, err := db.DBImp.Repository.ListPermissionBound(ctx, query)
if err != nil {
db.Logger.Warn("Failed to list objects for ID filtering", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return nil, err
}
// Use batch enforcement for efficiency
allowedResults, err := db.Enforcer.EnforceBatch(ctx, allObjects, accountRef, model.ActionRead)
if err != nil {
db.Logger.Warn("Failed to enforce batch permissions for ID listing", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return nil, err
}
// Filter to only allowed object IDs
var allowedIDs []primitive.ObjectID
for _, obj := range allObjects {
if allowedResults[*obj.GetID()] {
allowedIDs = append(allowedIDs, *obj.GetID())
}
}
db.Logger.Debug("Successfully filtered object IDs", zap.Int("total_count", len(allObjects)),
zap.Int("allowed_count", len(allowedIDs)), mzap.ObjRef("account_ref", accountRef))
return allowedIDs, nil
}
func (db *AccountBoundDBImp[T]) ListAccountBound(ctx context.Context, accountRef, organizationRef primitive.ObjectID, query builder.Query) ([]model.AccountBoundStorable, error) {
db.Logger.Debug("Attempting to list account bound objects", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)))
// Build query to find objects where accountRef matches OR is null/absent
accountQuery := repository.WithOrg(accountRef, organizationRef)
// Combine with the provided query
finalQuery := query.And(accountQuery)
// Get all candidate objects
allObjects, err := db.DBImp.Repository.ListAccountBound(ctx, finalQuery)
if err != nil {
db.Logger.Warn("Failed to list account bound objects", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return nil, err
}
// Filter objects based on read permissions (AccountBoundStorable doesn't have permission info, so we check organization level)
var allowedObjects []model.AccountBoundStorable
for _, obj := range allObjects {
if err := db.enforceInterface(ctx, model.ActionRead, obj, accountRef); err == nil {
allowedObjects = append(allowedObjects, obj)
} else if !errors.Is(err, merrors.ErrAccessDenied) {
// If the error is something other than AccessDenied, we want to fail
db.Logger.Warn("Error while enforcing read permission", zap.Error(err), mzap.ObjRef("object_ref", *obj.GetID()))
return nil, err
}
// If AccessDenied, we simply skip that object
}
db.Logger.Debug("Successfully filtered account bound objects", zap.Int("total_count", len(allObjects)),
zap.Int("allowed_count", len(allowedObjects)), mzap.ObjRef("account_ref", accountRef))
return allowedObjects, nil
}
func (db *AccountBoundDBImp[T]) GetByAccountRef(ctx context.Context, accountRef primitive.ObjectID, result T) error {
db.Logger.Debug("Attempting to get object by account ref", mzap.ObjRef("account_ref", accountRef))
// Build query to find objects where accountRef matches OR is null/absent
query := repository.WithoutOrg(accountRef)
if err := db.DBImp.FindOne(ctx, query, result); err != nil {
db.Logger.Warn("Failed to get object by account ref", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return err
}
// Check organization read permission for the found object
if err := db.enforce(ctx, model.ActionRead, result, accountRef); err != nil {
return err
}
db.Logger.Debug("Successfully retrieved object by account ref", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", result.GetOrganizationRef()))
return nil
}
func (db *AccountBoundDBImp[T]) DeleteByAccountRef(ctx context.Context, accountRef primitive.ObjectID) error {
db.Logger.Debug("Attempting to delete objects by account ref", mzap.ObjRef("account_ref", accountRef))
// Build query to find objects where accountRef matches OR is null/absent
query := repository.WithoutOrg(accountRef)
// Get all candidate objects for individual permission checking
allObjects, err := db.DBImp.Repository.ListAccountBound(ctx, query)
if err != nil {
db.Logger.Warn("Failed to list objects for delete by account ref", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return err
}
// Check permissions for each object individually (AccountBoundStorable doesn't have permission info)
var allowedIDs []primitive.ObjectID
for _, obj := range allObjects {
if err := db.enforceInterface(ctx, model.ActionUpdate, obj, accountRef); err == nil {
allowedIDs = append(allowedIDs, *obj.GetID())
} else if !errors.Is(err, merrors.ErrAccessDenied) {
// If the error is something other than AccessDenied, we want to fail
db.Logger.Warn("Error while enforcing update permission", zap.Error(err), mzap.ObjRef("object_ref", *obj.GetID()))
return err
}
// If AccessDenied, we simply skip that object
}
if len(allowedIDs) == 0 {
db.Logger.Debug("No objects allowed for deletion by account ref", mzap.ObjRef("account_ref", accountRef))
return nil
}
// Delete only the allowed objects
allowedQuery := query.And(repository.Query().In(repository.IDField(), allowedIDs))
if err := db.DBImp.DeleteMany(ctx, allowedQuery); err != nil {
db.Logger.Warn("Failed to delete objects by account ref", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
return err
}
db.Logger.Debug("Successfully deleted objects by account ref", mzap.ObjRef("account_ref", accountRef), zap.Int("count", len(allowedIDs)))
return nil
}
func (db *AccountBoundDBImp[T]) DeleteCascade(ctx context.Context, objectRef primitive.ObjectID) error {
return db.DBImp.DeleteCascade(ctx, objectRef)
}
// CreateAccountBoundImp creates a concrete AccountBoundDBImp instance for internal use
func CreateAccountBoundImp[T model.AccountBoundStorable](
ctx context.Context,
logger mlogger.Logger,
pdb policy.DB,
enforcer Enforcer,
collection mservice.Type,
db *mongo.Database,
) (*AccountBoundDBImp[T], error) {
logger = logger.Named("account_bound")
var policy model.PolicyDescription
if err := pdb.GetBuiltInPolicy(ctx, mservice.Organizations, &policy); err != nil {
logger.Warn("Failed to fetch organization policy description", zap.Error(err))
return nil, err
}
res := &AccountBoundDBImp[T]{
Logger: logger,
DBImp: template.Create[T](logger, collection, db),
Enforcer: enforcer,
PermissionRef: policy.ID,
Collection: collection,
}
return res, nil
}

View File

@@ -0,0 +1,81 @@
package auth
import (
"errors"
"testing"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// TestAccountBoundDBImp_Enforce tests the enforce method
func TestAccountBoundDBImp_Enforce(t *testing.T) {
logger := mlogger.Logger(zap.NewNop())
db := &AccountBoundDBImp[model.AccountBoundStorable]{
Logger: logger,
PermissionRef: primitive.NewObjectID(),
Collection: "test_collection",
}
t.Run("EnforceMethodExists", func(t *testing.T) {
// Test that the enforce method exists and can be called
// This is a basic test to ensure the method signature is correct
assert.NotNil(t, db.enforce)
})
t.Run("PermissionRefSet", func(t *testing.T) {
// Test that PermissionRef is properly set
assert.NotEqual(t, primitive.NilObjectID, db.PermissionRef)
})
t.Run("CollectionSet", func(t *testing.T) {
// Test that Collection is properly set
assert.Equal(t, "test_collection", string(db.Collection))
})
}
// TestAccountBoundDBImp_InterfaceCompliance tests that the struct implements required interfaces
func TestAccountBoundDBImp_InterfaceCompliance(t *testing.T) {
logger := mlogger.Logger(zap.NewNop())
db := &AccountBoundDBImp[model.AccountBoundStorable]{
Logger: logger,
PermissionRef: primitive.NewObjectID(),
Collection: "test_collection",
}
t.Run("StructInitialization", func(t *testing.T) {
// Test that the struct can be initialized
assert.NotNil(t, db)
assert.NotNil(t, db.Logger)
assert.NotEqual(t, primitive.NilObjectID, db.PermissionRef)
assert.NotEmpty(t, db.Collection)
})
t.Run("LoggerInitialization", func(t *testing.T) {
// Test that logger is properly initialized
assert.NotNil(t, db.Logger)
})
}
// TestAccountBoundDBImp_ErrorHandling tests error handling patterns
func TestAccountBoundDBImp_ErrorHandling(t *testing.T) {
t.Run("AccessDeniedError", func(t *testing.T) {
// Test that AccessDenied error is properly created
err := merrors.AccessDenied("test_collection", "read", primitive.NilObjectID)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrAccessDenied))
})
t.Run("ErrorTypeChecking", func(t *testing.T) {
// Test error type checking
accessDeniedErr := merrors.AccessDenied("test", "read", primitive.NilObjectID)
otherErr := errors.New("other error")
assert.True(t, errors.Is(accessDeniedErr, merrors.ErrAccessDenied))
assert.False(t, errors.Is(otherErr, merrors.ErrAccessDenied))
})
}

32
api/pkg/auth/enforcer.go Normal file
View File

@@ -0,0 +1,32 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type Enforcer interface {
// Enforce checks if accountRef can do `action` on objectRef in an org (domainRef).
Enforce(
ctx context.Context,
permissionRef, accountRef, orgRef, objectRef primitive.ObjectID,
action model.Action,
) (bool, error)
// Enforce batch of objects
EnforceBatch(
ctx context.Context,
objectRefs []model.PermissionBoundStorable,
accountRef primitive.ObjectID,
action model.Action,
) (map[primitive.ObjectID]bool, error)
// GetRoles returns the user's roles in a given org domain, plus any partial scopes if relevant.
GetRoles(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, error)
// GetPermissions returns all effective permissions (with effect, object scoping) for a user in org domain.
// Merges from all roles the user holds, plus any denies/exceptions.
GetPermissions(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, []model.Permission, error)
}

52
api/pkg/auth/factory.go Normal file
View File

@@ -0,0 +1,52 @@
package auth
import (
"github.com/tech/sendico/pkg/auth/internal/casbin"
"github.com/tech/sendico/pkg/auth/internal/native"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/role"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
func CreateAuth(
logger mlogger.Logger,
client *mongo.Client,
db *mongo.Database,
pdb policy.DB,
rdb role.DB,
config *Config,
) (Enforcer, Manager, error) {
lg := logger.Named("auth")
lg.Debug("Creating enforcer...", zap.String("driver", string(config.Driver)))
l := lg.Named(string(config.Driver))
if config.Driver == Casbin {
enforcer, err := casbin.NewEnforcer(l, client, config.Settings)
if err != nil {
lg.Warn("Failed to create enforcer", zap.Error(err))
return nil, nil, err
}
manager, err := casbin.NewManager(l, pdb, rdb, enforcer, config.Settings)
if err != nil {
lg.Warn("Failed to create managment interface", zap.Error(err))
return nil, nil, err
}
return enforcer, manager, nil
}
if config.Driver == Native {
enforcer, err := native.NewEnforcer(l, db)
if err != nil {
lg.Warn("Failed to create enforcer", zap.Error(err))
return nil, nil, err
}
manager, err := native.NewManager(l, pdb, rdb, enforcer)
if err != nil {
lg.Warn("Failed to create managment interface", zap.Error(err))
return nil, nil, err
}
return enforcer, manager, nil
}
return nil, nil, merrors.InvalidArgument("Unknown enforcer type: " + string(config.Driver))
}

61
api/pkg/auth/helper.go Normal file
View File

@@ -0,0 +1,61 @@
package auth
import (
"context"
"errors"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
func enforceObject[T model.PermissionBoundStorable](ctx context.Context, db *template.DBImp[T], enforcer Enforcer, action model.Action, accountRef primitive.ObjectID, query builder.Query) error {
l, err := db.ListPermissionBound(ctx, query)
if err != nil {
db.Logger.Warn("Error occured while checking access rights", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), zap.String("action", string(action)))
return err
}
if len(l) == 0 {
db.Logger.Debug("Access denied", mzap.ObjRef("account_ref", accountRef), zap.String("action", string(action)))
return merrors.AccessDenied(db.Repository.Collection(), string(action), primitive.NilObjectID)
}
for _, item := range l {
db.Logger.Debug("Object found", mzap.ObjRef("object_ref", *item.GetID()),
mzap.ObjRef("organization_ref", item.GetOrganizationRef()),
mzap.ObjRef("permission_ref", item.GetPermissionRef()),
zap.String("collection", item.Collection()))
}
res, err := enforcer.EnforceBatch(ctx, l, accountRef, action)
if err != nil {
db.Logger.Warn("Failed to enforce permission", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), zap.String("action", string(action)))
}
for objectRef, hasPermission := range res {
if !hasPermission {
db.Logger.Info("Permission denied for object during reordering", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(model.ActionUpdate)))
return merrors.AccessDenied(db.Repository.Collection(), string(action), objectRef)
}
}
return nil
}
func enforceObjectByRef[T model.PermissionBoundStorable](ctx context.Context, db *template.DBImp[T], enforcer Enforcer, action model.Action, accountRef, objectRef primitive.ObjectID) error {
err := enforceObject(ctx, db, enforcer, action, accountRef, repository.IDFilter(objectRef))
if err != nil {
if errors.Is(err, merrors.ErrAccessDenied) {
db.Logger.Debug("Access denied", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
return merrors.AccessDenied(db.Repository.Collection(), string(action), objectRef)
} else {
db.Logger.Warn("Error occurred while checking permissions", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
}
}
return err
}

29
api/pkg/auth/indexable.go Normal file
View File

@@ -0,0 +1,29 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// IndexableDB implements reordering with permission checking
type IndexableDB[T storable.Storable] interface {
// Reorder implements reordering with permission checking using EnforceBatch
Reorder(ctx context.Context, accountRef, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
}
// NewIndexableDB creates a new auth.IndexableDB instance
func NewIndexableDB[T storable.Storable](
repo repository.Repository,
logger mlogger.Logger,
enforcer Enforcer,
createEmpty func() T,
getIndexable func(T) *model.Indexable,
) IndexableDB[T] {
return newIndexableDBImp(repo, logger, enforcer, createEmpty, getIndexable)
}

View File

@@ -0,0 +1,182 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// IndexableDB implements reordering with permission checking
type indexableDBImp[T storable.Storable] struct {
repo repository.Repository
logger mlogger.Logger
enforcer Enforcer
createEmpty func() T
getIndexable func(T) *model.Indexable
}
// NewIndexableDB creates a new auth.IndexableDB instance
func newIndexableDBImp[T storable.Storable](
repo repository.Repository,
logger mlogger.Logger,
enforcer Enforcer,
createEmpty func() T,
getIndexable func(T) *model.Indexable,
) IndexableDB[T] {
return &indexableDBImp[T]{
repo: repo,
logger: logger.Named("indexable"),
enforcer: enforcer,
createEmpty: createEmpty,
getIndexable: getIndexable,
}
}
// Reorder implements reordering with permission checking using EnforceBatch
func (db *indexableDBImp[T]) Reorder(ctx context.Context, accountRef, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
// Get current object to find its index
obj := db.createEmpty()
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
db.logger.Warn("Failed to get object for reordering", zap.Error(err), zap.Int("new_index", newIndex),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return err
}
// Extract index from the object
indexable := db.getIndexable(obj)
currentIndex := indexable.Index
if currentIndex == newIndex {
db.logger.Debug("No reordering needed - same index", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
return nil // No change needed
}
// Determine which objects will be affected by the reordering
var affectedObjects []model.PermissionBoundStorable
if currentIndex < newIndex {
// Moving down: items between currentIndex+1 and newIndex will be shifted up by -1
reorderFilter := filter.
And(repository.IndexOpFilter(currentIndex+1, builder.Gte)).
And(repository.IndexOpFilter(newIndex, builder.Lte))
// Get all affected objects using ListPermissionBound
objects, err := db.repo.ListPermissionBound(ctx, reorderFilter)
if err != nil {
db.logger.Warn("Failed to get affected objects for reordering (moving down)",
zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef),
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
return err
}
affectedObjects = append(affectedObjects, objects...)
db.logger.Debug("Found affected objects for moving down",
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Int("affected_count", len(objects)))
} else {
// Moving up: items between newIndex and currentIndex-1 will be shifted down by +1
reorderFilter := filter.
And(repository.IndexOpFilter(newIndex, builder.Gte)).
And(repository.IndexOpFilter(currentIndex-1, builder.Lte))
// Get all affected objects using ListPermissionBound
objects, err := db.repo.ListPermissionBound(ctx, reorderFilter)
if err != nil {
db.logger.Warn("Failed to get affected objects for reordering (moving up)", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef),
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
return err
}
affectedObjects = append(affectedObjects, objects...)
db.logger.Debug("Found affected objects for moving up", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("affected_count", len(objects)))
}
// Add the target object to the list of objects that need permission checking
targetObjects, err := db.repo.ListPermissionBound(ctx, repository.IDFilter(objectRef))
if err != nil {
db.logger.Warn("Failed to get target object for permission checking", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return err
}
if len(targetObjects) > 0 {
affectedObjects = append(affectedObjects, targetObjects[0])
}
// Check permissions for all affected objects using EnforceBatch
db.logger.Debug("Checking permissions for reordering", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("affected_count", len(affectedObjects)),
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
permissions, err := db.enforcer.EnforceBatch(ctx, affectedObjects, accountRef, model.ActionUpdate)
if err != nil {
db.logger.Warn("Failed to check permissions for reordering", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Int("affected_count", len(affectedObjects)))
return merrors.Internal("failed to check permissions for reordering")
}
// Verify all objects have update permission
for resObjectRef, hasPermission := range permissions {
if !hasPermission {
db.logger.Info("Permission denied for object during reordering", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(model.ActionUpdate)))
return merrors.AccessDenied(db.repo.Collection(), string(model.ActionUpdate), resObjectRef)
}
}
db.logger.Debug("All permissions granted, proceeding with reordering", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("permission_count", len(permissions)))
// All permissions checked, proceed with reordering
if currentIndex < newIndex {
// Moving down: shift items between currentIndex+1 and newIndex up by -1
patch := repository.Patch().Inc(repository.IndexField(), -1)
reorderFilter := filter.
And(repository.IndexOpFilter(currentIndex+1, builder.Gte)).
And(repository.IndexOpFilter(newIndex, builder.Lte))
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
if err != nil {
db.logger.Warn("Failed to shift objects during reordering (moving down)", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef),
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex), zap.Int("updated_count", updatedCount))
return err
}
db.logger.Debug("Successfully shifted objects (moving down)", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("updated_count", updatedCount))
} else {
// Moving up: shift items between newIndex and currentIndex-1 down by +1
patch := repository.Patch().Inc(repository.IndexField(), 1)
reorderFilter := filter.
And(repository.IndexOpFilter(newIndex, builder.Gte)).
And(repository.IndexOpFilter(currentIndex-1, builder.Lte))
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
if err != nil {
db.logger.Warn("Failed to shift objects during reordering (moving up)", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef),
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex), zap.Int("updated_count", updatedCount))
return err
}
db.logger.Debug("Successfully shifted objects (moving up)", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("updated_count", updatedCount))
}
// Update the target object to new index
if err := db.repo.Patch(ctx, objectRef, repository.Patch().Set(repository.IndexField(), newIndex)); err != nil {
db.logger.Warn("Failed to update target object index", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
return err
}
db.logger.Debug("Successfully reordered object with permission checking",
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Int("old_index", currentIndex),
zap.Int("new_index", newIndex), zap.Int("affected_count", len(affectedObjects)))
return nil
}

View File

@@ -0,0 +1,23 @@
package casbin
import (
"fmt"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
)
func stringToAction(actionStr string) (model.Action, error) {
switch actionStr {
case string(model.ActionCreate):
return model.ActionCreate, nil
case string(model.ActionRead):
return model.ActionRead, nil
case string(model.ActionUpdate):
return model.ActionUpdate, nil
case string(model.ActionDelete):
return model.ActionDelete, nil
default:
return "", merrors.InvalidArgument(fmt.Sprintf("invalid action: %s", actionStr))
}
}

View File

@@ -0,0 +1,126 @@
package casbin
import (
"os"
"time"
mongodbadapter "github.com/casbin/mongodb-adapter/v3"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"go.uber.org/zap"
)
type AdapterConfig struct {
DatabaseName *string `mapstructure:"database_name"`
DatabaseNameEnv *string `mapstructure:"database_name_env"`
CollectionName *string `mapstructure:"collection_name"`
CollectionNameEnv *string `mapstructure:"collection_name_env"`
TimeoutSeconds *int `mapstructure:"timeout_seconds"`
TimeoutSecondsEnv *string `mapstructure:"timeout_seconds_env"`
IsFiltered *bool `mapstructure:"is_filtered"`
IsFilteredEnv *string `mapstructure:"is_filtered_env"`
}
type Config struct {
ModelPath *string `mapstructure:"model_path"`
ModelPathEnv *string `mapstructure:"model_path_env"`
Adapter *AdapterConfig `mapstructure:"adapter"`
}
type EnforcerConfig struct {
ModelPath string
Adapter *mongodbadapter.AdapterConfig
}
func getEnvValue(logger mlogger.Logger, varName, envVarName string, value, envValue *string) string {
if value != nil && envValue != nil {
logger.Warn("Both variable and environment variable are set, using environment variable value",
zap.String("variable", varName), zap.String("environment_variable", envVarName), zap.String("value", *value), zap.String("env_value", os.Getenv(*envValue)))
}
if envValue != nil {
return os.Getenv(*envValue)
}
if value != nil {
return *value
}
return ""
}
func getEnvIntValue(logger mlogger.Logger, varName, envVarName string, value *int, envValue *string) int {
if value != nil && envValue != nil {
logger.Warn("Both variable and environment variable are set, using environment variable value",
zap.String("variable", varName), zap.String("environment_variable", envVarName), zap.Int("value", *value), zap.String("env_value", os.Getenv(*envValue)))
}
if envValue != nil {
envStr := os.Getenv(*envValue)
if envStr != "" {
if parsed, err := time.ParseDuration(envStr + "s"); err == nil {
return int(parsed.Seconds())
}
logger.Warn("Invalid environment variable value for timeout", zap.String("environment_variable", envVarName), zap.String("value", envStr))
}
}
if value != nil {
return *value
}
return 30 // Default timeout in seconds
}
func getEnvBoolValue(logger mlogger.Logger, varName, envVarName string, value *bool, envValue *string) bool {
if value != nil && envValue != nil {
logger.Warn("Both variable and environment variable are set, using environment variable value",
zap.String("variable", varName), zap.String("environment_variable", envVarName), zap.Bool("value", *value), zap.String("env_value", os.Getenv(*envValue)))
}
if envValue != nil {
envStr := os.Getenv(*envValue)
if envStr == "true" || envStr == "1" {
return true
} else if envStr == "false" || envStr == "0" {
return false
}
logger.Warn("Invalid environment variable value for boolean", zap.String("environment_variable", envVarName), zap.String("value", envStr))
}
if value != nil {
return *value
}
return false // Default for boolean
}
func PrepareConfig(logger mlogger.Logger, config *Config) (*EnforcerConfig, error) {
if config == nil {
return nil, merrors.Internal("No configuration provided")
}
adapter := &mongodbadapter.AdapterConfig{
DatabaseName: getEnvValue(logger, "database_name", "database_name_env", config.Adapter.DatabaseName, config.Adapter.DatabaseNameEnv),
CollectionName: getEnvValue(logger, "collection_name", "collection_name_env", config.Adapter.CollectionName, config.Adapter.CollectionNameEnv),
Timeout: time.Duration(getEnvIntValue(logger, "timeout_seconds", "timeout_seconds_env", config.Adapter.TimeoutSeconds, config.Adapter.TimeoutSecondsEnv)) * time.Second,
IsFiltered: getEnvBoolValue(logger, "is_filtered", "is_filtered_env", config.Adapter.IsFiltered, config.Adapter.IsFilteredEnv),
}
if len(adapter.DatabaseName) == 0 {
logger.Error("Database name is not set")
return nil, merrors.InvalidArgument("database name must be provided")
}
path := getEnvValue(logger, "model_path", "model_path_env", config.ModelPath, config.ModelPathEnv)
logger.Info("Configuration prepared",
zap.String("model_path", path),
zap.String("database_name", adapter.DatabaseName),
zap.String("collection_name", adapter.CollectionName),
zap.Duration("timeout", adapter.Timeout),
zap.Bool("is_filtered", adapter.IsFiltered),
)
return &EnforcerConfig{ModelPath: path, Adapter: adapter}, nil
}

View File

@@ -0,0 +1,206 @@
// casbin_enforcer.go
package casbin
import (
"context"
"github.com/casbin/casbin/v2"
"github.com/tech/sendico/pkg/auth/anyobject"
cc "github.com/tech/sendico/pkg/auth/internal/casbin/config"
"github.com/tech/sendico/pkg/auth/internal/casbin/serialization"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"github.com/mitchellh/mapstructure"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
// CasbinEnforcer implements the Enforcer interface using Casbin.
type CasbinEnforcer struct {
logger mlogger.Logger
enforcer *casbin.Enforcer
roleSerializer serialization.Role
permissionSerializer serialization.Policy
}
// NewCasbinEnforcer initializes a new CasbinEnforcer with a MongoDB adapter, logger, and PolicySerializer.
// The 'domain' parameter is no longer stored internally, as the interface requires passing a domainRef per method call.
func NewEnforcer(
logger mlogger.Logger,
client *mongo.Client,
settings model.SettingsT,
) (*CasbinEnforcer, error) {
var config cc.Config
if err := mapstructure.Decode(settings, &config); err != nil {
logger.Warn("Failed to decode Casbin configuration", zap.Error(err), zap.Any("settings", settings))
return nil, merrors.Internal("failed to decode Casbin configuration")
}
// Create a Casbin adapter + enforcer from your config and client.
l := logger.Named("enforcer")
e, err := createAdapter(l, &config, client)
if err != nil {
logger.Warn("Failed to create Casbin enforcer", zap.Error(err))
return nil, merrors.Internal("failed to create Casbin enforcer")
}
logger.Info("Casbin enforcer created")
return &CasbinEnforcer{
logger: l,
enforcer: e,
permissionSerializer: serialization.NewPolicySerializer(),
roleSerializer: serialization.NewRoleSerializer(),
}, nil
}
// Enforce checks if a user has the specified action permission on an object within a domain.
func (c *CasbinEnforcer) Enforce(
_ context.Context,
permissionRef, accountRef, organizationRef, objectRef primitive.ObjectID,
action model.Action,
) (bool, error) {
// Convert ObjectIDs to strings for Casbin
account := accountRef.Hex()
organization := organizationRef.Hex()
permission := permissionRef.Hex()
object := anyobject.ID
if objectRef != primitive.NilObjectID {
object = objectRef.Hex()
}
act := string(action)
c.logger.Debug("Enforcing policy",
zap.String("account", account), zap.String("organization", organization),
zap.String("permission", permission), zap.String("object", object),
zap.String("action", act))
// Perform the enforcement
result, err := c.enforcer.Enforce(account, organization, permission, object, act)
if err != nil {
c.logger.Warn("Failed to enforce policy", zap.Error(err),
zap.String("account", account), zap.String("organization", organization),
zap.String("permission", permission), zap.String("object", object),
zap.String("action", act))
return false, err
}
c.logger.Debug("Policy enforcement result", zap.Bool("result", result))
return result, nil
}
// EnforceBatch checks a users permission for multiple objects at once.
// It returns a map from objectRef -> boolean indicating whether access is granted.
func (c *CasbinEnforcer) EnforceBatch(
ctx context.Context,
objectRefs []model.PermissionBoundStorable,
accountRef primitive.ObjectID,
action model.Action,
) (map[primitive.ObjectID]bool, error) {
results := make(map[primitive.ObjectID]bool, len(objectRefs))
for _, desc := range objectRefs {
ok, err := c.Enforce(ctx, desc.GetPermissionRef(), accountRef, desc.GetOrganizationRef(), *desc.GetID(), action)
if err != nil {
c.logger.Warn("Failed to enforce", zap.Error(err), mzap.ObjRef("permission_ref", desc.GetPermissionRef()),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", desc.GetOrganizationRef()),
mzap.ObjRef("object_ref", *desc.GetID()), zap.String("action", string(action)))
return nil, err
}
results[*desc.GetID()] = ok
}
return results, nil
}
// GetRoles retrieves all roles assigned to the user within the domain.
func (c *CasbinEnforcer) GetRoles(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, error) {
sub := accountRef.Hex()
dom := orgRef.Hex()
c.logger.Debug("Fetching roles for user", zap.String("subject", sub), zap.String("domain", dom))
// Get all roles for the user in the domain
sroles, err := c.enforcer.GetFilteredGroupingPolicy(0, sub, "", dom)
if err != nil {
c.logger.Warn("Failed to get roles from policies", zap.Error(err),
zap.String("account_ref", sub), zap.String("organization_ref", dom),
)
return nil, merrors.Internal("failed to fetch roles from policies")
}
roles := make([]model.Role, 0, len(sroles))
for _, srole := range sroles {
role, err := c.roleSerializer.Deserialize(srole)
if err != nil {
c.logger.Warn("Failed to deserialize role", zap.Error(err))
return nil, err
}
roles = append(roles, *role)
}
c.logger.Debug("Roles fetched successfully", zap.Int("count", len(roles)))
return roles, nil
}
// GetPermissions retrieves all effective policies for the user within the domain.
func (c *CasbinEnforcer) GetPermissions(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, []model.Permission, error) {
c.logger.Debug("Fetching policies for user", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", orgRef))
// Step 1: Retrieve all roles assigned to the user within the domain
roles, err := c.GetRoles(ctx, accountRef, orgRef)
if err != nil {
c.logger.Warn("Failed to get roles", zap.Error(err))
return nil, nil, err
}
// Map to hold unique policies
permissionsMap := make(map[string]*model.Permission)
for _, role := range roles {
// Step 2a: Retrieve all policies associated with the role within the domain
policies, err := c.enforcer.GetFilteredPolicy(0, role.DescriptionRef.Hex())
if err != nil {
c.logger.Warn("Failed to get policies for role", zap.Error(err), mzap.ObjRef("role_ref", role.DescriptionRef))
continue
}
// Step 2b: Process each policy to extract Permission, Action, and Effect
for _, policy := range policies {
if len(policy) < 5 {
c.logger.Warn("Incomplete policy encountered", zap.Strings("policy", policy))
continue // Ensure the policy line has enough fields
}
// Deserialize the policy using
deserializedPolicy, err := c.permissionSerializer.Deserialize(policy)
if err != nil {
c.logger.Warn("Failed to deserialize policy", zap.Error(err), zap.Strings("policy", policy))
continue
}
// Construct a unique key combining Permission ID and Action to prevent duplicates
policyKey := deserializedPolicy.DescriptionRef.Hex() + ":" + string(deserializedPolicy.Effect.Action)
if _, exists := permissionsMap[policyKey]; exists {
continue // Policy-action pair already accounted for
}
// Add the Policy to the map
permissionsMap[policyKey] = &model.Permission{
RolePolicy: *deserializedPolicy,
AccountRef: accountRef,
}
c.logger.Debug("Policy added to policyMap", zap.Any("policy_key", policyKey))
}
}
// Convert the map to a slice
permissions := make([]model.Permission, 0, len(permissionsMap))
for _, permission := range permissionsMap {
permissions = append(permissions, *permission)
}
c.logger.Debug("Permissions fetched successfully", zap.Int("count", len(permissions)))
return roles, permissions, nil
}

View File

@@ -0,0 +1,34 @@
package casbin
import (
"github.com/casbin/casbin/v2"
mongodbadapter "github.com/casbin/mongodb-adapter/v3"
cc "github.com/tech/sendico/pkg/auth/internal/casbin/config"
"github.com/tech/sendico/pkg/mlogger"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
func createAdapter(logger mlogger.Logger, config *cc.Config, client *mongo.Client) (*casbin.Enforcer, error) {
dbc, err := cc.PrepareConfig(logger, config)
if err != nil {
logger.Warn("Failed to prepare database configuration", zap.Error(err))
return nil, err
}
adapter, err := mongodbadapter.NewAdapterByDB(client, dbc.Adapter)
if err != nil {
logger.Warn("Failed to create DB adapter", zap.Error(err))
return nil, err
}
e, err := casbin.NewEnforcer(dbc.ModelPath, adapter, NewCasbinLogger(logger))
if err != nil {
logger.Warn("Failed to create permissions enforcer", zap.Error(err))
return nil, err
}
e.EnableAutoSave(true)
// No need to manually load policy. Casbin does it for us
return e, nil
}

View File

@@ -0,0 +1,61 @@
package casbin
import (
"strings"
"github.com/tech/sendico/pkg/mlogger"
"go.uber.org/zap"
)
// CasbinZapLogger wraps a zap.Logger to implement Casbin's Logger interface.
type CasbinZapLogger struct {
logger mlogger.Logger
}
// NewCasbinLogger constructs a new CasbinZapLogger.
func NewCasbinLogger(logger mlogger.Logger) *CasbinZapLogger {
return &CasbinZapLogger{
logger: logger.Named("driver"),
}
}
// EnableLog enables or disables logging.
func (l *CasbinZapLogger) EnableLog(_ bool) {
// ignore
}
// IsEnabled returns whether logging is currently enabled.
func (l *CasbinZapLogger) IsEnabled() bool {
return true
}
// LogModel is called by Casbin when loading model settings (you can customize if you want).
func (l *CasbinZapLogger) LogModel(m [][]string) {
l.logger.Info("Model loaded", zap.Any("model", m))
}
func (l *CasbinZapLogger) LogPolicy(m map[string][][]string) {
l.logger.Info("Policy loaded", zap.Int("entries", len(m)))
}
func (l *CasbinZapLogger) LogError(err error, msg ...string) {
// If no custom message was passed, log a generic one
if len(msg) == 0 {
l.logger.Warn("Error occurred", zap.Error(err))
return
}
// Otherwise, join any provided messages and include them
l.logger.Warn(strings.Join(msg, " "), zap.Error(err))
}
// LogEnforce is called by Casbin to log each Enforce() call if logging is enabled.
func (l *CasbinZapLogger) LogEnforce(matcher string, request []any, result bool, explains [][]string) {
l.logger.Debug("Enforcing policy...", zap.String("matcher", matcher), zap.Any("request", request),
zap.Bool("result", result), zap.Any("explains", explains))
}
// LogRole is called by Casbin when role manager adds or deletes a role.
func (l *CasbinZapLogger) LogRole(roles []string) {
l.logger.Debug("Changing roles...", zap.Strings("roles", roles))
}

View File

@@ -0,0 +1,54 @@
// package casbin
package casbin
import (
"context"
"github.com/tech/sendico/pkg/auth/management"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/role"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.uber.org/zap"
)
// CasbinManager implements the auth.Manager interface by aggregating Role and Permission managers.
type CasbinManager struct {
logger mlogger.Logger
roleManager management.Role
permManager management.Permission
}
// NewManager creates a new CasbinManager with specified domains and role-domain mappings.
func NewManager(
l mlogger.Logger,
pdb policy.DB,
rdb role.DB,
enforcer *CasbinEnforcer,
settings model.SettingsT,
) (*CasbinManager, error) {
logger := l.Named("manager")
var pdesc model.PolicyDescription
if err := pdb.GetBuiltInPolicy(context.Background(), "roles", &pdesc); err != nil {
logger.Warn("Failed to fetch roles permission reference", zap.Error(err))
return nil, err
}
return &CasbinManager{
logger: logger,
roleManager: NewRoleManager(logger, enforcer, pdesc.ID, rdb),
permManager: NewPermissionManager(logger, enforcer),
}, nil
}
// Permission returns the Permission manager.
func (m *CasbinManager) Permission() management.Permission {
return m.permManager
}
// Role returns the Role manager.
func (m *CasbinManager) Role() management.Role {
return m.roleManager
}

View File

@@ -0,0 +1,54 @@
######################################################
# Request Definition
######################################################
[request_definition]
# Explanation:
# - `accountRef`: The account (user) making the request.
# - `organizationRef`: The organization in which the role applies.
# - `permissionRef`: The specific permission being requested.
# - `objectRef`: The object/resource being accessed (specific object or all objects).
# - `action`: The action being requested (CRUD: read, write, update, delete).
r = accountRef, organizationRef, permissionRef, objectRef, action
######################################################
# Policy Definition
######################################################
[policy_definition]
# Explanation:
# - `roleRef`: The role to which the policy is assigned.
# - `organizationRef`: The organization in which the role applies.
# - `permissionRef`: The permission associated with the policy.
# - `objectRef`: The specific object/resource the policy applies to (or all objects).
# - `action`: The CRUD action permitted or denied.
# - `eft`: Effect of the policy (`allow` or `deny`).
p = roleRef, organizationRef, permissionRef, objectRef, action, eft
######################################################
# Role Definition
######################################################
[role_definition]
# Explanation:
# - Maps `accountRef` (user) to `roleRef` (role) within `organizationRef` (scope).
# Casbin requires underscores for placeholders, so we do not literally use accountRef, roleRef, etc. here.
g = _, _, _
######################################################
# Policy Effect
######################################################
[policy_effect]
# Explanation:
# - Grants access if any `allow` policy matches and no `deny` policies match.
e = some(where (p.eft == allow)) && !some(where (p.eft == deny))
######################################################
# Matchers
######################################################
[matchers]
# Explanation:
# - Checks if the user (accountRef) belongs to the roleRef within an organizationRef via `g()`.
# - Ensures the organizationRef, permissionRef, objectRef, and action match the policy.
m = g(r.accountRef, p.roleRef, r.organizationRef) && r.organizationRef == p.organizationRef && r.permissionRef == p.permissionRef && (p.objectRef == r.objectRef || p.objectRef == "*") && r.action == p.action

View File

@@ -0,0 +1,167 @@
package casbin
import (
"context"
"github.com/tech/sendico/pkg/auth/anyobject"
"github.com/tech/sendico/pkg/auth/internal/casbin/serialization"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// CasbinPermissionManager manages permissions using Casbin.
type CasbinPermissionManager struct {
logger mlogger.Logger // Logger for logging operations
enforcer *CasbinEnforcer // Casbin enforcer for managing policies
serializer serialization.Policy // Serializer for converting policies to/from Casbin
}
// GrantToRole adds a permission to a role in Casbin.
func (m *CasbinPermissionManager) GrantToRole(ctx context.Context, policy *model.RolePolicy) error {
objRef := anyobject.ID
if (policy.ObjectRef != nil) && (*policy.ObjectRef != primitive.NilObjectID) {
objRef = policy.ObjectRef.Hex()
}
m.logger.Debug("Granting permission to role",
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef),
zap.String("object_ref", objRef),
zap.String("action", string(policy.Effect.Action)),
zap.String("effect", string(policy.Effect.Effect)),
)
// Serialize permission
serializedPolicy, err := m.serializer.Serialize(policy)
if err != nil {
m.logger.Error("Failed to serialize permission while granting permission", zap.Error(err),
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef),
mzap.ObjRef("organization_ref", policy.OrganizationRef),
)
return err
}
// Add policy to Casbin
added, err := m.enforcer.enforcer.AddPolicy(serializedPolicy...)
if err != nil {
m.logger.Error("Failed to add policy to Casbin", zap.Error(err))
return err
}
if added {
m.logger.Info("Policy added to Casbin",
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef),
zap.String("object_ref", objRef),
)
} else {
m.logger.Warn("Policy already exists in Casbin",
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef),
zap.String("object_ref", objRef),
)
}
return nil
}
// RevokeFromRole removes a permission from a role in Casbin.
func (m *CasbinPermissionManager) RevokeFromRole(ctx context.Context, policy *model.RolePolicy) error {
objRef := anyobject.ID
if policy.ObjectRef != nil {
objRef = policy.ObjectRef.Hex()
}
m.logger.Debug("Revoking permission from role",
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef),
zap.String("object_ref", objRef),
zap.String("action", string(policy.Effect.Action)),
zap.String("effect", string(policy.Effect.Effect)),
)
// Serialize policy
serializedPolicy, err := m.serializer.Serialize(policy)
if err != nil {
m.logger.Error("Failed to serialize policy while revoking permission from role",
zap.Error(err), mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("policy_ref", policy.DescriptionRef))
return err
}
// Remove policy from Casbin
removed, err := m.enforcer.enforcer.RemovePolicy(serializedPolicy...)
if err != nil {
m.logger.Error("Failed to remove policy from Casbin", zap.Error(err))
return err
}
if removed {
m.logger.Info("Policy removed from Casbin",
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef),
zap.String("object_ref", objRef),
)
} else {
m.logger.Warn("Policy does not exist in Casbin",
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef),
zap.String("object_ref", objRef),
)
}
return nil
}
// GetPolicies retrieves all policies for a specific role.
func (m *CasbinPermissionManager) GetPolicies(
ctx context.Context,
roleRef primitive.ObjectID,
) ([]model.RolePolicy, error) {
m.logger.Debug("Fetching policies for role", mzap.ObjRef("role_ref", roleRef))
// Retrieve Casbin policies for the role
policies, err := m.enforcer.enforcer.GetFilteredPolicy(0, roleRef.Hex())
if err != nil {
m.logger.Warn("Failed to get policies", zap.Error(err), mzap.ObjRef("role_ref", roleRef))
return nil, err
}
if len(policies) == 0 {
m.logger.Info("No policies found for role", mzap.ObjRef("role_ref", roleRef))
return nil, merrors.NoData("no policies")
}
// Deserialize policies
var result []model.RolePolicy
for _, policy := range policies {
permission, err := m.serializer.Deserialize(policy)
if err != nil {
m.logger.Warn("Failed to deserialize policy", zap.Error(err), zap.String("policy", policy[0]))
continue
}
result = append(result, *permission)
}
m.logger.Debug("Policies fetched successfully", mzap.ObjRef("role_ref", roleRef), zap.Int("count", len(result)))
return result, nil
}
// Save persists changes to the Casbin policy store.
func (m *CasbinPermissionManager) Save() error {
if err := m.enforcer.enforcer.SavePolicy(); err != nil {
m.logger.Error("Failed to save policies in Casbin", zap.Error(err))
return err
}
m.logger.Info("Policies successfully saved in Casbin")
return nil
}
func NewPermissionManager(logger mlogger.Logger, enforcer *CasbinEnforcer) *CasbinPermissionManager {
return &CasbinPermissionManager{
logger: logger.Named("permission"),
enforcer: enforcer,
serializer: serialization.NewPolicySerializer(),
}
}

View File

@@ -0,0 +1,209 @@
package casbin
import (
"context"
"github.com/tech/sendico/pkg/db/role"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// RoleManager manages roles using Casbin.
type RoleManager struct {
logger mlogger.Logger
enforcer *CasbinEnforcer
rdb role.DB
rolePermissionRef primitive.ObjectID
}
// NewRoleManager creates a new RoleManager.
func NewRoleManager(logger mlogger.Logger, enforcer *CasbinEnforcer, rolePermissionRef primitive.ObjectID, rdb role.DB) *RoleManager {
return &RoleManager{
logger: logger.Named("role"),
enforcer: enforcer,
rdb: rdb,
rolePermissionRef: rolePermissionRef,
}
}
// validateObjectIDs ensures that all provided ObjectIDs are non-zero.
func (rm *RoleManager) validateObjectIDs(ids ...primitive.ObjectID) error {
for _, id := range ids {
if id.IsZero() {
return merrors.InvalidArgument("Object references cannot be zero")
}
}
return nil
}
// removePolicies removes policies based on the provided filter and logs the results.
func (rm *RoleManager) removePolicies(policyType, role string, roleRef primitive.ObjectID) error {
filterIndex := 1
if policyType == "permission" {
filterIndex = 0
}
policies, err := rm.enforcer.enforcer.GetFilteredPolicy(filterIndex, role)
if err != nil {
rm.logger.Warn("Failed to fetch "+policyType+" policies", zap.Error(err), mzap.ObjRef("role_ref", roleRef))
return err
}
for _, policy := range policies {
args := make([]any, len(policy))
for i, v := range policy {
args[i] = v
}
var removed bool
var removeErr error
if policyType == "grouping" {
removed, removeErr = rm.enforcer.enforcer.RemoveGroupingPolicy(args...)
} else {
removed, removeErr = rm.enforcer.enforcer.RemovePolicy(args...)
}
if removeErr != nil {
rm.logger.Warn("Failed to remove "+policyType+" policy for role", zap.Error(removeErr), mzap.ObjRef("role_ref", roleRef), zap.Strings("policy", policy))
return removeErr
}
if removed {
rm.logger.Info("Removed "+policyType+" policy for role", mzap.ObjRef("role_ref", roleRef), zap.Strings("policy", policy))
}
}
return nil
}
// fetchRolesFromPolicies retrieves and converts policies to roles.
func (rm *RoleManager) fetchRolesFromPolicies(policies [][]string, orgRef primitive.ObjectID) []model.RoleDescription {
roles := make([]model.RoleDescription, 0, len(policies))
for _, policy := range policies {
if len(policy) < 2 {
continue
}
roleID, err := primitive.ObjectIDFromHex(policy[1])
if err != nil {
rm.logger.Warn("Invalid role ID", zap.String("roleID", policy[1]))
continue
}
roles = append(roles, model.RoleDescription{Base: storable.Base{ID: roleID}, OrganizationRef: orgRef})
}
return roles
}
// Create creates a new role in an organization.
func (rm *RoleManager) Create(ctx context.Context, orgRef primitive.ObjectID, description *model.Describable) (*model.RoleDescription, error) {
if err := rm.validateObjectIDs(orgRef); err != nil {
return nil, err
}
role := &model.RoleDescription{
Describable: *description,
OrganizationRef: orgRef,
}
if err := rm.rdb.Create(ctx, role); err != nil {
rm.logger.Warn("Failed to create role", zap.Error(err), mzap.ObjRef("organiztion_ref", orgRef))
return nil, err
}
rm.logger.Info("Role created successfully", mzap.StorableRef(role), mzap.ObjRef("organization_ref", orgRef))
return role, nil
}
// Assign assigns a role to a user in the given organization.
func (rm *RoleManager) Assign(ctx context.Context, role *model.Role) error {
if err := rm.validateObjectIDs(role.DescriptionRef, role.AccountRef, role.OrganizationRef); err != nil {
return err
}
sub := role.AccountRef.Hex()
roleID := role.DescriptionRef.Hex()
domain := role.OrganizationRef.Hex()
added, err := rm.enforcer.enforcer.AddGroupingPolicy(sub, roleID, domain)
return rm.logPolicyResult("assign", added, err, role.DescriptionRef, role.AccountRef, role.OrganizationRef)
}
// Delete removes a role entirely and cleans up associated Casbin policies.
func (rm *RoleManager) Delete(ctx context.Context, roleRef primitive.ObjectID) error {
if err := rm.validateObjectIDs(roleRef); err != nil {
rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef))
return err
}
if err := rm.rdb.Delete(ctx, roleRef); err != nil {
rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef))
return err
}
role := roleRef.Hex()
// Remove grouping policies
if err := rm.removePolicies("grouping", role, roleRef); err != nil {
return err
}
// Remove permission policies
if err := rm.removePolicies("permission", role, roleRef); err != nil {
return err
}
// // Save changes
// if err := rm.enforcer.enforcer.SavePolicy(); err != nil {
// rm.logger.Warn("Failed to save Casbin policies after role deletion",
// zap.Error(err),
// mzap.ObjRef("role_ref", roleRef),
// )
// return err
// }
rm.logger.Info("Role deleted successfully along with associated policies", mzap.ObjRef("role_ref", roleRef))
return nil
}
// Revoke removes a role from a user.
func (rm *RoleManager) Revoke(ctx context.Context, roleRef, accountRef, orgRef primitive.ObjectID) error {
if err := rm.validateObjectIDs(roleRef, accountRef, orgRef); err != nil {
return err
}
sub := accountRef.Hex()
role := roleRef.Hex()
domain := orgRef.Hex()
removed, err := rm.enforcer.enforcer.RemoveGroupingPolicy(sub, role, domain)
return rm.logPolicyResult("revoke", removed, err, roleRef, accountRef, orgRef)
}
// logPolicyResult logs results for Assign and Revoke.
func (rm *RoleManager) logPolicyResult(action string, result bool, err error, roleRef, accountRef, orgRef primitive.ObjectID) error {
if err != nil {
rm.logger.Warn("Failed to "+action+" role", zap.Error(err), mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", orgRef))
return err
}
msg := "Role " + action + "ed successfully"
if !result {
msg = "Role already " + action + "ed"
}
rm.logger.Info(msg, mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", orgRef))
return nil
}
// List retrieves all roles in an organization or all roles if orgRef is zero.
func (rm *RoleManager) List(ctx context.Context, orgRef primitive.ObjectID) ([]model.RoleDescription, error) {
domain := orgRef.Hex()
groupingPolicies, err := rm.enforcer.enforcer.GetFilteredGroupingPolicy(2, domain)
if err != nil {
rm.logger.Warn("Failed to fetch grouping policies", zap.Error(err), mzap.ObjRef("organization_ref", orgRef))
return nil, err
}
roles := rm.fetchRolesFromPolicies(groupingPolicies, orgRef)
rm.logger.Info("Retrieved roles for organization", mzap.ObjRef("organization_ref", orgRef), zap.Int("count", len(roles)))
return roles, nil
}

View File

@@ -0,0 +1,81 @@
package serializationimp
import (
"github.com/tech/sendico/pkg/auth/anyobject"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// PolicySerializer implements CasbinSerializer for Permission.
type PolicySerializer struct{}
// Serialize converts a Permission object into a Casbin policy.
func (s *PolicySerializer) Serialize(entity *model.RolePolicy) ([]any, error) {
if entity.RoleDescriptionRef.IsZero() ||
entity.OrganizationRef.IsZero() ||
entity.DescriptionRef.IsZero() || // Ensure permissionRef is valid
entity.Effect.Action == "" || // Ensure action is not empty
entity.Effect.Effect == "" { // Ensure effect (eft) is not empty
return nil, merrors.InvalidArgument("permission contains invalid object references or missing fields")
}
objectRef := anyobject.ID
if entity.ObjectRef != nil {
objectRef = entity.ObjectRef.Hex()
}
return []any{
entity.RoleDescriptionRef.Hex(), // Maps to p.roleRef
entity.OrganizationRef.Hex(), // Maps to p.organizationRef
entity.DescriptionRef.Hex(), // Maps to p.permissionRef
objectRef, // Maps to p.objectRef (wildcard if empty)
string(entity.Effect.Action), // Maps to p.action
string(entity.Effect.Effect), // Maps to p.eft
}, nil
}
// Deserialize converts a Casbin policy into a Permission object.
func (s *PolicySerializer) Deserialize(policy []string) (*model.RolePolicy, error) {
if len(policy) != 6 { // Ensure policy has the correct number of fields
return nil, merrors.Internal("invalid policy format")
}
roleRef, err := primitive.ObjectIDFromHex(policy[0])
if err != nil {
return nil, merrors.InvalidArgument("invalid roleRef in policy")
}
organizationRef, err := primitive.ObjectIDFromHex(policy[1])
if err != nil {
return nil, merrors.InvalidArgument("invalid organizationRef in policy")
}
permissionRef, err := primitive.ObjectIDFromHex(policy[2])
if err != nil {
return nil, merrors.InvalidArgument("invalid permissionRef in policy")
}
// Handle wildcard for ObjectRef
var objectRef *primitive.ObjectID
if policy[3] != anyobject.ID {
ref, err := primitive.ObjectIDFromHex(policy[3])
if err != nil {
return nil, merrors.InvalidArgument("invalid objectRef in policy")
}
objectRef = &ref
}
return &model.RolePolicy{
RoleDescriptionRef: roleRef,
Policy: model.Policy{
OrganizationRef: organizationRef,
DescriptionRef: permissionRef,
ObjectRef: objectRef,
Effect: model.ActionEffect{
Action: model.Action(policy[4]),
Effect: model.Effect(policy[5]),
},
},
}, nil
}

View File

@@ -0,0 +1,57 @@
package serializationimp
import (
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// RoleSerializer implements CasbinSerializer for Role.
type RoleSerializer struct{}
// Serialize converts a Role object into a Casbin grouping policy.
func (s *RoleSerializer) Serialize(entity *model.Role) ([]any, error) {
// Validate required fields
if entity.AccountRef.IsZero() || entity.DescriptionRef.IsZero() || entity.OrganizationRef.IsZero() {
return nil, merrors.InvalidArgument("role contains invalid object references")
}
return []any{
entity.AccountRef.Hex(), // Maps to g(_, _, _) accountRef
entity.DescriptionRef.Hex(), // Maps to g(_, _, _) roleRef
entity.OrganizationRef.Hex(), // Maps to g(_, _, _) organizationRef
}, nil
}
// Deserialize converts a Casbin grouping policy into a Role object.
func (s *RoleSerializer) Deserialize(policy []string) (*model.Role, error) {
// Ensure the policy has exactly 3 fields
if len(policy) != 3 {
return nil, merrors.Internal("invalid grouping policy format")
}
// Parse accountRef
accountRef, err := primitive.ObjectIDFromHex(policy[0])
if err != nil {
return nil, merrors.InvalidArgument("invalid accountRef in grouping policy")
}
// Parse roleDescriptionRef (roleRef)
roleDescriptionRef, err := primitive.ObjectIDFromHex(policy[1])
if err != nil {
return nil, merrors.InvalidArgument("invalid roleRef in grouping policy")
}
// Parse organizationRef
organizationRef, err := primitive.ObjectIDFromHex(policy[2])
if err != nil {
return nil, merrors.InvalidArgument("invalid organizationRef in grouping policy")
}
// Return the constructed Role object
return &model.Role{
AccountRef: accountRef,
DescriptionRef: roleDescriptionRef,
OrganizationRef: organizationRef,
}, nil
}

View File

@@ -0,0 +1,12 @@
package serialization
import (
serializationimp "github.com/tech/sendico/pkg/auth/internal/casbin/serialization/internal"
"github.com/tech/sendico/pkg/model"
)
type Policy = CasbinSerializer[model.RolePolicy]
func NewPolicySerializer() Policy {
return &serializationimp.PolicySerializer{}
}

View File

@@ -0,0 +1,12 @@
package serialization
import (
serializationimp "github.com/tech/sendico/pkg/auth/internal/casbin/serialization/internal"
"github.com/tech/sendico/pkg/model"
)
type Role = CasbinSerializer[model.Role]
func NewRoleSerializer() Role {
return &serializationimp.RoleSerializer{}
}

View File

@@ -0,0 +1,10 @@
package serialization
// CasbinSerializer defines methods for serializing and deserializing any Casbin-compatible entity.
type CasbinSerializer[T any] interface {
// Serialize converts an entity (Role or Permission) into a Casbin policy.
Serialize(entity *T) ([]any, error)
// Deserialize converts a Casbin policy into an entity (Role or Permission).
Deserialize(policy []string) (*T, error)
}

View File

@@ -0,0 +1,151 @@
package db
import (
"context"
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
"github.com/tech/sendico/pkg/db/repository"
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type PermissionsDBImp struct {
template.DBImp[*nstructures.PolicyAssignment]
}
func (db *PermissionsDBImp) Policies(ctx context.Context, object model.PermissionBoundStorable, action model.Action) ([]nstructures.PolicyAssignment, error) {
return mutil.GetObjects[nstructures.PolicyAssignment](
ctx,
db.Logger,
repository.Query().And(
repository.Filter("policy.organizationRef", object.GetOrganizationRef()),
repository.Filter("policy.descriptionRef", object.GetPermissionRef()),
repository.Filter("policy.effect.action", action),
repository.Query().Or(
repository.Filter("policy.objectRef", *object.GetID()),
repository.Filter("policy.objectRef", nil),
),
),
nil,
db.Repository,
)
}
func (db *PermissionsDBImp) PoliciesForPermissionAction(ctx context.Context, roleRef, permissionRef primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error) {
return mutil.GetObjects[nstructures.PolicyAssignment](
ctx,
db.Logger,
repository.Query().And(
repository.Filter("roleRef", roleRef),
repository.Filter("policy.descriptionRef", permissionRef),
repository.Filter("policy.effect.action", action),
),
nil,
db.Repository,
)
}
func (db *PermissionsDBImp) Remove(ctx context.Context, policy *model.RolePolicy) error {
objRefFilter := repository.Query().Or(
repository.Filter("policy.objectRef", nil),
repository.Filter("policy.objectRef", primitive.NilObjectID),
)
if policy.ObjectRef != nil {
objRefFilter = repository.Filter("policy.objectRef", *policy.ObjectRef)
}
return db.Repository.DeleteMany(
ctx,
repository.Query().And(
repository.Filter("roleRef", policy.RoleDescriptionRef),
repository.Filter("policy.organizationRef", policy.OrganizationRef),
repository.Filter("policy.descriptionRef", policy.DescriptionRef),
objRefFilter,
repository.Filter("policy.effect.action", policy.Effect.Action),
repository.Filter("policy.effect.effect", policy.Effect.Effect),
),
)
}
func (db *PermissionsDBImp) PoliciesForRole(ctx context.Context, roleRef primitive.ObjectID) ([]nstructures.PolicyAssignment, error) {
return mutil.GetObjects[nstructures.PolicyAssignment](
ctx,
db.Logger,
repository.Filter("roleRef", roleRef),
nil,
db.Repository,
)
}
func (db *PermissionsDBImp) PoliciesForRoles(ctx context.Context, roleRefs []primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error) {
if len(roleRefs) == 0 {
db.Logger.Debug("Empty role references list provided, returning empty resposnse")
return []nstructures.PolicyAssignment{}, nil
}
return mutil.GetObjects[nstructures.PolicyAssignment](
ctx,
db.Logger,
repository.Query().And(
repository.Query().In(repository.Field("roleRef"), roleRefs),
repository.Filter("policy.effect.action", action),
),
nil,
db.Repository,
)
}
func NewPoliciesDB(logger mlogger.Logger, db *mongo.Database) (*PermissionsDBImp, error) {
p := &PermissionsDBImp{
DBImp: *template.Create[*nstructures.PolicyAssignment](logger, mservice.PolicyAssignements, db),
}
// faster
// harder
// index
policiesQueryIndex := &ri.Definition{
Keys: []ri.Key{
{Field: "policy.organizationRef", Sort: ri.Asc},
{Field: "policy.descriptionRef", Sort: ri.Asc},
{Field: "policy.effect.action", Sort: ri.Asc},
{Field: "policy.objectRef", Sort: ri.Asc},
},
}
if err := p.DBImp.Repository.CreateIndex(policiesQueryIndex); err != nil {
p.Logger.Warn("Failed to prepare policies query index", zap.Error(err))
return nil, err
}
roleBasedQueriesIndex := &ri.Definition{
Keys: []ri.Key{
{Field: "roleRef", Sort: ri.Asc},
{Field: "policy.effect.action", Sort: ri.Asc},
},
}
if err := p.DBImp.Repository.CreateIndex(roleBasedQueriesIndex); err != nil {
p.Logger.Warn("Failed to prepare role based query index", zap.Error(err))
return nil, err
}
uniquePolicyConstaint := &ri.Definition{
Keys: []ri.Key{
{Field: "policy.organizationRef", Sort: ri.Asc},
{Field: "roleRef", Sort: ri.Asc},
{Field: "policy.descriptionRef", Sort: ri.Asc},
{Field: "policy.effect.action", Sort: ri.Asc},
{Field: "policy.objectRef", Sort: ri.Asc},
},
Unique: true,
}
if err := p.DBImp.Repository.CreateIndex(uniquePolicyConstaint); err != nil {
p.Logger.Warn("Failed to unique policy assignment index", zap.Error(err))
return nil, err
}
return p, nil
}

View File

@@ -0,0 +1,99 @@
package db
import (
"context"
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
"github.com/tech/sendico/pkg/db/repository"
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type RolesDBImp struct {
template.DBImp[*nstructures.RoleAssignment]
}
func (db *RolesDBImp) Roles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
return mutil.GetObjects[nstructures.RoleAssignment](
ctx,
db.Logger,
repository.Query().And(
repository.Filter("role.accountRef", accountRef),
repository.Filter("role.organizationRef", organizationRef),
),
nil,
db.Repository,
)
}
func (db *RolesDBImp) RolesForVenue(ctx context.Context, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
return mutil.GetObjects[nstructures.RoleAssignment](
ctx,
db.Logger,
repository.Query().And(
repository.Filter("role.organizationRef", organizationRef),
),
nil,
db.Repository,
)
}
func (db *RolesDBImp) DeleteRole(ctx context.Context, roleRef primitive.ObjectID) error {
return db.DeleteMany(
ctx,
repository.Query().And(
repository.Filter("role.descriptionRef", roleRef),
),
)
}
func (db *RolesDBImp) RemoveRole(ctx context.Context, roleRef, organizationRef, accountRef primitive.ObjectID) error {
return db.DeleteMany(
ctx,
repository.Query().And(
repository.Filter("role.accountRef", accountRef),
repository.Filter("role.organizationRef", organizationRef),
repository.Filter("role.descriptionRef", roleRef),
),
)
}
func NewRolesDB(logger mlogger.Logger, db *mongo.Database) (*RolesDBImp, error) {
p := &RolesDBImp{
DBImp: *template.Create[*nstructures.RoleAssignment](logger, "role_assignments", db),
}
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: "role.organizationRef", Sort: ri.Asc}},
}); err != nil {
p.Logger.Warn("Failed to prepare venue index", zap.Error(err))
return nil, err
}
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: "role.descriptionRef", Sort: ri.Asc}},
}); err != nil {
p.Logger.Warn("Failed to prepare role description index", zap.Error(err))
return nil, err
}
uniqueRoleConstaint := &ri.Definition{
Keys: []ri.Key{
{Field: "role.organizationRef", Sort: ri.Asc},
{Field: "role.accountRef", Sort: ri.Asc},
{Field: "role.descriptionRef", Sort: ri.Asc},
},
Unique: true,
}
if err := p.DBImp.Repository.CreateIndex(uniqueRoleConstaint); err != nil {
p.Logger.Warn("Failed to prepare role assignment index", zap.Error(err))
return nil, err
}
return p, nil
}

View File

@@ -0,0 +1,27 @@
package native
import (
"context"
"github.com/tech/sendico/pkg/auth/internal/native/db"
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
type PoliciesDB interface {
template.DB[*nstructures.PolicyAssignment]
// plenty of interfaces for performance reasons
Policies(ctx context.Context, object model.PermissionBoundStorable, action model.Action) ([]nstructures.PolicyAssignment, error)
PoliciesForPermissionAction(ctx context.Context, roleRef, permissionRef primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error)
PoliciesForRole(ctx context.Context, roleRef primitive.ObjectID) ([]nstructures.PolicyAssignment, error)
PoliciesForRoles(ctx context.Context, roleRefs []primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error)
Remove(ctx context.Context, policy *model.RolePolicy) error
}
func NewPoliciesDBDB(logger mlogger.Logger, conn *mongo.Database) (PoliciesDB, error) {
return db.NewPoliciesDB(logger, conn)
}

View File

@@ -0,0 +1,24 @@
package native
import (
"context"
"github.com/tech/sendico/pkg/auth/internal/native/db"
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)
type RolesDB interface {
template.DB[*nstructures.RoleAssignment]
Roles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error)
RolesForVenue(ctx context.Context, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error)
RemoveRole(ctx context.Context, roleRef, organizationRef, accountRef primitive.ObjectID) error
DeleteRole(ctx context.Context, roleRef primitive.ObjectID) error
}
func NewRolesDB(logger mlogger.Logger, conn *mongo.Database) (RolesDB, error) {
return db.NewRolesDB(logger, conn)
}

View File

@@ -0,0 +1,256 @@
package native
import (
"context"
"errors"
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type Enforcer struct {
logger mlogger.Logger
pdb PoliciesDB
rdb RolesDB
}
func NewEnforcer(
logger mlogger.Logger,
db *mongo.Database,
) (*Enforcer, error) {
e := &Enforcer{logger: logger.Named("enforcer")}
var err error
if e.pdb, err = NewPoliciesDBDB(e.logger, db); err != nil {
e.logger.Warn("Failed to create permission assignments database", zap.Error(err))
return nil, err
}
if e.rdb, err = NewRolesDB(e.logger, db); err != nil {
e.logger.Warn("Failed to create role assignments database", zap.Error(err))
return nil, err
}
logger.Info("Native enforcer created")
return e, nil
}
// Enforce checks if a user has the specified action permission on an object within a domain.
func (n *Enforcer) Enforce(
ctx context.Context,
permissionRef, accountRef, organizationRef, objectRef primitive.ObjectID,
action model.Action,
) (bool, error) {
roleAssignments, err := n.rdb.Roles(ctx, accountRef, organizationRef)
if errors.Is(err, merrors.ErrNoData) {
n.logger.Debug("No roles defined for account", mzap.ObjRef("account_ref", accountRef))
return false, nil
}
if err != nil {
n.logger.Warn("Failed to fetch roles while checking permissions", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
mzap.ObjRef("object", objectRef), zap.String("action", string(action)))
return false, err
}
if len(roleAssignments) == 0 {
n.logger.Warn("No roles found for account", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
return false, merrors.Internal("No roles found for account " + accountRef.Hex())
}
allowFound := false // Track if any allow is found across roles
for _, roleAssignment := range roleAssignments {
policies, err := n.pdb.PoliciesForPermissionAction(ctx, roleAssignment.DescriptionRef, permissionRef, action)
if err != nil && !errors.Is(err, merrors.ErrNoData) {
n.logger.Warn("Failed to fetch permissions", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
return false, err
}
for _, permission := range policies {
if permission.Effect.Effect == model.EffectDeny {
n.logger.Debug("Found denying policy", mzap.ObjRef("account", accountRef),
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
return false, nil // Deny takes precedence immediately
}
if permission.Effect.Effect == model.EffectAllow {
n.logger.Debug("Allowing policy found", mzap.ObjRef("account", accountRef),
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
allowFound = true // At least one allow found
} else {
n.logger.Warn("Corrupted policy", mzap.StorableRef(&permission))
return false, merrors.Internal("Corrupted action effect data for permissions entry " + permission.ID.Hex() + ": " + string(permission.Effect.Effect))
}
}
}
// Final decision based on whether any allow was found
if allowFound {
return true, nil // At least one allow and no deny
}
n.logger.Debug("No allowing policy found", mzap.ObjRef("account", accountRef),
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
return false, nil // No allow found, default deny
}
// EnforceBatch checks a users permission for multiple objects at once.
// It returns a map from objectRef -> boolean indicating whether access is granted.
func (n *Enforcer) EnforceBatch(
ctx context.Context,
objectRefs []model.PermissionBoundStorable,
accountRef primitive.ObjectID,
action model.Action,
) (map[primitive.ObjectID]bool, error) {
results := make(map[primitive.ObjectID]bool, len(objectRefs))
// Group objectRefs by organizationRef.
objectsByVenue := make(map[primitive.ObjectID][]model.PermissionBoundStorable)
for _, obj := range objectRefs {
organizationRef := obj.GetOrganizationRef()
objectsByVenue[organizationRef] = append(objectsByVenue[organizationRef], obj)
}
// Process each venue group separately.
for organizationRef, objs := range objectsByVenue {
// 1. Fetch roles once for this account and venue.
roles, err := n.rdb.Roles(ctx, accountRef, organizationRef)
if err != nil {
if errors.Is(err, merrors.ErrNoData) {
n.logger.Debug("No roles defined for account", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
// With no roles, mark all objects in this venue as denied.
for _, obj := range objs {
results[*obj.GetID()] = false
}
// Continue to next venue
continue
}
n.logger.Warn("Failed to fetch roles", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
return nil, err
}
// 2. Extract role description references
var roleRefs []primitive.ObjectID
for _, role := range roles {
roleRefs = append(roleRefs, role.DescriptionRef)
}
// 3. Fetch all policies for these roles and the given action in one call.
allPolicies, err := n.pdb.PoliciesForRoles(ctx, roleRefs, action)
if err != nil {
n.logger.Warn("Failed to fetch policies", zap.Error(err))
return nil, err
}
// 4. Build a lookup map keyed by PermissionRef.
policyMap := make(map[primitive.ObjectID][]nstructures.PolicyAssignment)
for _, policy := range allPolicies {
policyMap[policy.DescriptionRef] = append(policyMap[policy.DescriptionRef], policy)
}
// 5. Evaluate permissions for each object in this venue group.
for _, obj := range objs {
permRef := obj.GetPermissionRef()
allow := false
if policies, ok := policyMap[permRef]; ok {
for _, policy := range policies {
// Deny takes precedence.
if policy.Effect.Effect == model.EffectDeny {
allow = false
break
}
if policy.Effect.Effect == model.EffectAllow {
allow = true
// Continue checking in case a deny exists among policies.
} else {
// should never get here
return nil, merrors.Internal("Corrupted permissions effect in policy assignment '" + policy.GetID().Hex() + "': " + string(policy.Effect.Effect))
}
}
}
results[*obj.GetID()] = allow
}
}
return results, nil
}
// GetRoles retrieves all roles assigned to the user within the domain.
func (n *Enforcer) GetRoles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, error) {
n.logger.Debug("Fetching roles for user", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
ra, err := n.rdb.Roles(ctx, accountRef, organizationRef)
if errors.Is(err, merrors.ErrNoData) {
n.logger.Debug("No roles assigned to user", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
return []model.Role{}, nil
}
if err != nil {
n.logger.Warn("Failed to fetch roles", zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
return nil, err
}
roles := make([]model.Role, len(ra))
for i, roleAssignement := range ra {
roles[i] = roleAssignement.Role
}
n.logger.Debug("Fetched roles", zap.Int("roles_count", len(roles)))
return roles, nil
}
func (n *Enforcer) Reload() error {
n.logger.Info("Policies reloaded") // do nothing actually
return nil
}
// GetPermissions retrieves all effective policies for the user within the domain.
func (n *Enforcer) GetPermissions(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, []model.Permission, error) {
n.logger.Debug("Fetching policies for user", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
roles, err := n.GetRoles(ctx, accountRef, organizationRef)
if err != nil {
n.logger.Warn("Failed to get roles", zap.Error(err))
return nil, nil, err
}
uniquePermissions := make(map[primitive.ObjectID]model.Permission)
for _, role := range roles {
perms, err := n.pdb.PoliciesForRole(ctx, role.DescriptionRef)
if err != nil {
n.logger.Warn("Failed to get policies for role", zap.Error(err), mzap.ObjRef("role_ref", role.DescriptionRef))
continue
}
n.logger.Debug("Policies fetched for role", mzap.ObjRef("role_ref", role.DescriptionRef), zap.Int("count", len(perms)))
for _, p := range perms {
uniquePermissions[*p.GetID()] = model.Permission{
RolePolicy: model.RolePolicy{
Policy: p.Policy,
RoleDescriptionRef: p.RoleRef,
},
AccountRef: accountRef,
}
}
}
permissionsSlice := make([]model.Permission, 0, len(uniquePermissions))
for _, permission := range uniquePermissions {
permissionsSlice = append(permissionsSlice, permission)
}
n.logger.Debug("Policies fetched successfully", zap.Int("count", len(permissionsSlice)))
return roles, permissionsSlice, nil
}

View File

@@ -0,0 +1,747 @@
package native
import (
"context"
"errors"
"testing"
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/merrors"
factory "github.com/tech/sendico/pkg/mlogger/factory"
"github.com/tech/sendico/pkg/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Mock implementations for testing
type MockPoliciesDB struct {
mock.Mock
}
func (m *MockPoliciesDB) PoliciesForPermissionAction(ctx context.Context, roleRef, permissionRef primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error) {
args := m.Called(ctx, roleRef, permissionRef, action)
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
}
func (m *MockPoliciesDB) PoliciesForRole(ctx context.Context, roleRef primitive.ObjectID) ([]nstructures.PolicyAssignment, error) {
args := m.Called(ctx, roleRef)
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
}
func (m *MockPoliciesDB) PoliciesForRoles(ctx context.Context, roleRefs []primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error) {
args := m.Called(ctx, roleRefs, action)
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
}
func (m *MockPoliciesDB) Policies(ctx context.Context, object model.PermissionBoundStorable, action model.Action) ([]nstructures.PolicyAssignment, error) {
args := m.Called(ctx, object, action)
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
}
func (m *MockPoliciesDB) Remove(ctx context.Context, policy *model.RolePolicy) error {
args := m.Called(ctx, policy)
return args.Error(0)
}
// Template DB methods - implement as needed for testing
func (m *MockPoliciesDB) Create(ctx context.Context, assignment *nstructures.PolicyAssignment) error {
args := m.Called(ctx, assignment)
return args.Error(0)
}
func (m *MockPoliciesDB) Get(ctx context.Context, id primitive.ObjectID, assignment *nstructures.PolicyAssignment) error {
args := m.Called(ctx, id, assignment)
return args.Error(0)
}
func (m *MockPoliciesDB) Update(ctx context.Context, assignment *nstructures.PolicyAssignment) error {
args := m.Called(ctx, assignment)
return args.Error(0)
}
func (m *MockPoliciesDB) Patch(ctx context.Context, objectRef primitive.ObjectID, patch builder.Patch) error {
args := m.Called(ctx, objectRef, patch)
return args.Error(0)
}
func (m *MockPoliciesDB) Delete(ctx context.Context, id primitive.ObjectID) error {
args := m.Called(ctx, id)
return args.Error(0)
}
func (m *MockPoliciesDB) DeleteMany(ctx context.Context, query builder.Query) error {
args := m.Called(ctx, query)
return args.Error(0)
}
func (m *MockPoliciesDB) ListPermissionBound(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.PolicyAssignment, error) {
args := m.Called(ctx, accountRef, organizationRef)
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
}
func (m *MockPoliciesDB) ListIDs(ctx context.Context, query interface{}) ([]primitive.ObjectID, error) {
args := m.Called(ctx, query)
return args.Get(0).([]primitive.ObjectID), args.Error(1)
}
func (m *MockPoliciesDB) FindOne(ctx context.Context, query builder.Query, assignment *nstructures.PolicyAssignment) error {
args := m.Called(ctx, query, assignment)
return args.Error(0)
}
func (m *MockPoliciesDB) List(ctx context.Context, query builder.Query) ([]nstructures.PolicyAssignment, error) {
args := m.Called(ctx, query)
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
}
func (m *MockPoliciesDB) Name() string {
return "mock_policies"
}
func (m *MockPoliciesDB) DeleteCascade(ctx context.Context, id primitive.ObjectID) error {
args := m.Called(ctx, id)
return args.Error(0)
}
func (m *MockPoliciesDB) InsertMany(ctx context.Context, objects []*nstructures.PolicyAssignment) error {
args := m.Called(ctx, objects)
return args.Error(0)
}
type MockRolesDB struct {
mock.Mock
}
func (m *MockRolesDB) Roles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
args := m.Called(ctx, accountRef, organizationRef)
return args.Get(0).([]nstructures.RoleAssignment), args.Error(1)
}
func (m *MockRolesDB) RolesForVenue(ctx context.Context, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
args := m.Called(ctx, organizationRef)
return args.Get(0).([]nstructures.RoleAssignment), args.Error(1)
}
func (m *MockRolesDB) RemoveRole(ctx context.Context, roleRef, organizationRef, accountRef primitive.ObjectID) error {
args := m.Called(ctx, roleRef, organizationRef, accountRef)
return args.Error(0)
}
func (m *MockRolesDB) DeleteRole(ctx context.Context, roleRef primitive.ObjectID) error {
args := m.Called(ctx, roleRef)
return args.Error(0)
}
// Template DB methods - implement as needed for testing
func (m *MockRolesDB) Create(ctx context.Context, assignment *nstructures.RoleAssignment) error {
args := m.Called(ctx, assignment)
return args.Error(0)
}
func (m *MockRolesDB) Get(ctx context.Context, id primitive.ObjectID, assignment *nstructures.RoleAssignment) error {
args := m.Called(ctx, id, assignment)
return args.Error(0)
}
func (m *MockRolesDB) Update(ctx context.Context, assignment *nstructures.RoleAssignment) error {
args := m.Called(ctx, assignment)
return args.Error(0)
}
func (m *MockRolesDB) Patch(ctx context.Context, objectRef primitive.ObjectID, patch builder.Patch) error {
args := m.Called(ctx, objectRef, patch)
return args.Error(0)
}
func (m *MockRolesDB) Delete(ctx context.Context, id primitive.ObjectID) error {
args := m.Called(ctx, id)
return args.Error(0)
}
func (m *MockRolesDB) DeleteMany(ctx context.Context, query builder.Query) error {
args := m.Called(ctx, query)
return args.Error(0)
}
func (m *MockRolesDB) ListPermissionBound(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
args := m.Called(ctx, accountRef, organizationRef)
return args.Get(0).([]nstructures.RoleAssignment), args.Error(1)
}
func (m *MockRolesDB) ListIDs(ctx context.Context, query interface{}) ([]primitive.ObjectID, error) {
args := m.Called(ctx, query)
return args.Get(0).([]primitive.ObjectID), args.Error(1)
}
func (m *MockRolesDB) FindOne(ctx context.Context, query builder.Query, assignment *nstructures.RoleAssignment) error {
args := m.Called(ctx, query, assignment)
return args.Error(0)
}
func (m *MockRolesDB) List(ctx context.Context, query builder.Query) ([]nstructures.RoleAssignment, error) {
args := m.Called(ctx, query)
return args.Get(0).([]nstructures.RoleAssignment), args.Error(1)
}
func (m *MockRolesDB) Name() string {
return "mock_roles"
}
func (m *MockRolesDB) DeleteCascade(ctx context.Context, id primitive.ObjectID) error {
args := m.Called(ctx, id)
return args.Error(0)
}
func (m *MockRolesDB) InsertMany(ctx context.Context, objects []*nstructures.RoleAssignment) error {
args := m.Called(ctx, objects)
return args.Error(0)
}
// Test helper functions
func createTestObjectID() primitive.ObjectID {
return primitive.NewObjectID()
}
func createTestRoleAssignment(roleRef, accountRef, organizationRef primitive.ObjectID) nstructures.RoleAssignment {
return nstructures.RoleAssignment{
Role: model.Role{
AccountRef: accountRef,
DescriptionRef: roleRef,
OrganizationRef: organizationRef,
},
}
}
func createTestPolicyAssignment(roleRef primitive.ObjectID, action model.Action, effect model.Effect, organizationRef, descriptionRef primitive.ObjectID, objectRef *primitive.ObjectID) nstructures.PolicyAssignment {
return nstructures.PolicyAssignment{
Policy: model.Policy{
OrganizationRef: organizationRef,
DescriptionRef: descriptionRef,
ObjectRef: objectRef,
Effect: model.ActionEffect{
Action: action,
Effect: effect,
},
},
RoleRef: roleRef,
}
}
func createTestEnforcer(pdb PoliciesDB, rdb RolesDB) *Enforcer {
logger := factory.NewLogger(true)
enforcer := &Enforcer{
logger: logger.Named("test"),
pdb: pdb,
rdb: rdb,
}
return enforcer
}
func TestEnforcer_Enforce(t *testing.T) {
ctx := context.Background()
// Test data
accountRef := createTestObjectID()
organizationRef := createTestObjectID()
permissionRef := createTestObjectID()
objectRef := createTestObjectID()
roleRef := createTestObjectID()
t.Run("Allow_SingleRole_SinglePolicy", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock role assignment
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
// Mock policy assignment with ALLOW effect
policyAssignment := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectAllow, organizationRef, permissionRef, &objectRef)
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{policyAssignment}, nil)
// Create enforcer
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify
require.NoError(t, err)
assert.True(t, allowed)
mockRDB.AssertExpectations(t)
mockPDB.AssertExpectations(t)
})
t.Run("Deny_SingleRole_SinglePolicy", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock role assignment
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
// Mock policy assignment with DENY effect
policyAssignment := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectDeny, organizationRef, permissionRef, &objectRef)
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{policyAssignment}, nil)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify
require.NoError(t, err)
assert.False(t, allowed)
mockRDB.AssertExpectations(t)
mockPDB.AssertExpectations(t)
})
t.Run("DenyTakesPrecedence_MultipleRoles", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
role1Ref := createTestObjectID()
role2Ref := createTestObjectID()
// Mock multiple role assignments
roleAssignment1 := createTestRoleAssignment(role1Ref, accountRef, organizationRef)
roleAssignment2 := createTestRoleAssignment(role2Ref, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment1, roleAssignment2}, nil)
// First role has ALLOW policy
allowPolicy := createTestPolicyAssignment(role1Ref, model.ActionRead, model.EffectAllow, organizationRef, permissionRef, &objectRef)
mockPDB.On("PoliciesForPermissionAction", ctx, role1Ref, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{allowPolicy}, nil)
// Second role has DENY policy - should take precedence
denyPolicy := createTestPolicyAssignment(role2Ref, model.ActionRead, model.EffectDeny, organizationRef, permissionRef, &objectRef)
mockPDB.On("PoliciesForPermissionAction", ctx, role2Ref, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{denyPolicy}, nil)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify - DENY should take precedence
require.NoError(t, err)
assert.False(t, allowed)
mockRDB.AssertExpectations(t)
mockPDB.AssertExpectations(t)
})
t.Run("NoRoles_ReturnsFalse", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock no roles found
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, merrors.ErrNoData)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify
require.NoError(t, err)
assert.False(t, allowed)
mockRDB.AssertExpectations(t)
})
t.Run("EmptyRoles_ReturnsError", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock empty roles list (not NoData error)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, nil)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify
require.Error(t, err)
assert.False(t, allowed)
assert.Contains(t, err.Error(), "No roles found for account")
mockRDB.AssertExpectations(t)
})
t.Run("DatabaseError_RolesDB", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock database error
dbError := errors.New("database connection failed")
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, dbError)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify
require.Error(t, err)
assert.False(t, allowed)
assert.Equal(t, dbError, err)
mockRDB.AssertExpectations(t)
})
t.Run("DatabaseError_PoliciesDB", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock role assignment
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
// Mock database error in policies
dbError := errors.New("policies database error")
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{}, dbError)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify
require.Error(t, err)
assert.False(t, allowed)
assert.Equal(t, dbError, err)
mockRDB.AssertExpectations(t)
mockPDB.AssertExpectations(t)
})
t.Run("NoPolicies_ReturnsFalse", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock role assignment
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
// Mock no policies found
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{}, merrors.ErrNoData)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify
require.NoError(t, err)
assert.False(t, allowed)
mockRDB.AssertExpectations(t)
mockPDB.AssertExpectations(t)
})
t.Run("CorruptedPolicy_ReturnsError", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock role assignment
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
// Mock corrupted policy with invalid effect
corruptedPolicy := createTestPolicyAssignment(roleRef, model.ActionRead, "invalid_effect", organizationRef, permissionRef, &objectRef)
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{corruptedPolicy}, nil)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify
require.Error(t, err)
assert.False(t, allowed)
assert.Contains(t, err.Error(), "Corrupted action effect data")
mockRDB.AssertExpectations(t)
mockPDB.AssertExpectations(t)
})
}
// Mock implementation for PermissionBoundStorable
type MockPermissionBoundStorable struct {
id primitive.ObjectID
permissionRef primitive.ObjectID
organizationRef primitive.ObjectID
}
func (m *MockPermissionBoundStorable) GetID() *primitive.ObjectID {
return &m.id
}
func (m *MockPermissionBoundStorable) GetPermissionRef() primitive.ObjectID {
return m.permissionRef
}
func (m *MockPermissionBoundStorable) GetOrganizationRef() primitive.ObjectID {
return m.organizationRef
}
func (m *MockPermissionBoundStorable) Collection() string {
return "test_objects"
}
func (m *MockPermissionBoundStorable) SetID(objID primitive.ObjectID) {
m.id = objID
}
func (m *MockPermissionBoundStorable) Update() {
// Do nothing for mock
}
func (m *MockPermissionBoundStorable) SetPermissionRef(permissionRef primitive.ObjectID) {
m.permissionRef = permissionRef
}
func (m *MockPermissionBoundStorable) SetOrganizationRef(organizationRef primitive.ObjectID) {
m.organizationRef = organizationRef
}
func (m *MockPermissionBoundStorable) IsArchived() bool {
return false // Default to not archived for testing
}
func (m *MockPermissionBoundStorable) SetArchived(archived bool) {
// No-op for testing
}
func TestEnforcer_EnforceBatch(t *testing.T) {
ctx := context.Background()
// Test data
accountRef := createTestObjectID()
organizationRef := createTestObjectID()
permissionRef := createTestObjectID()
roleRef := createTestObjectID()
// Create test objects
object1 := &MockPermissionBoundStorable{
id: createTestObjectID(),
permissionRef: permissionRef,
organizationRef: organizationRef,
}
object2 := &MockPermissionBoundStorable{
id: createTestObjectID(),
permissionRef: permissionRef,
organizationRef: organizationRef,
}
t.Run("BatchEnforce_MultipleObjects_SameVenue", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock role assignment
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
// Mock policy assignment with ALLOW effect
policyAssignment := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectAllow, organizationRef, permissionRef, nil)
mockPDB.On("PoliciesForRoles", ctx, []primitive.ObjectID{roleRef}, model.ActionRead).Return([]nstructures.PolicyAssignment{policyAssignment}, nil)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute batch enforcement
objects := []model.PermissionBoundStorable{object1, object2}
results, err := enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
// Verify
require.NoError(t, err)
assert.Len(t, results, 2)
assert.True(t, results[object1.id])
assert.True(t, results[object2.id])
mockRDB.AssertExpectations(t)
mockPDB.AssertExpectations(t)
})
t.Run("BatchEnforce_NoRoles_AllObjectsDenied", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock no roles found
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, merrors.ErrNoData)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute batch enforcement
objects := []model.PermissionBoundStorable{object1, object2}
results, err := enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
// Verify
require.NoError(t, err)
assert.Len(t, results, 2)
assert.False(t, results[object1.id])
assert.False(t, results[object2.id])
mockRDB.AssertExpectations(t)
})
t.Run("BatchEnforce_DatabaseError", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock database error
dbError := errors.New("database connection failed")
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, dbError)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute batch enforcement
objects := []model.PermissionBoundStorable{object1, object2}
results, err := enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
// Verify
require.Error(t, err)
assert.Nil(t, results)
assert.Equal(t, dbError, err)
mockRDB.AssertExpectations(t)
})
}
func TestEnforcer_GetRoles(t *testing.T) {
ctx := context.Background()
// Test data
accountRef := createTestObjectID()
organizationRef := createTestObjectID()
roleRef := createTestObjectID()
t.Run("GetRoles_Success", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock role assignment
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
roles, err := enforcer.GetRoles(ctx, accountRef, organizationRef)
// Verify
require.NoError(t, err)
assert.Len(t, roles, 1)
assert.Equal(t, roleRef, roles[0].DescriptionRef)
mockRDB.AssertExpectations(t)
})
t.Run("GetRoles_NoRoles", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock no roles found
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, merrors.ErrNoData)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
roles, err := enforcer.GetRoles(ctx, accountRef, organizationRef)
// Verify
require.NoError(t, err)
assert.Len(t, roles, 0)
mockRDB.AssertExpectations(t)
})
}
func TestEnforcer_GetPermissions(t *testing.T) {
ctx := context.Background()
// Test data
accountRef := createTestObjectID()
organizationRef := createTestObjectID()
roleRef := createTestObjectID()
t.Run("GetPermissions_Success", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock role assignment
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
// Mock policy assignment
policyAssignment := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectAllow, organizationRef, createTestObjectID(), nil)
mockPDB.On("PoliciesForRole", ctx, roleRef).Return([]nstructures.PolicyAssignment{policyAssignment}, nil)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
roles, permissions, err := enforcer.GetPermissions(ctx, accountRef, organizationRef)
// Verify
require.NoError(t, err)
assert.Len(t, roles, 1)
assert.Len(t, permissions, 1)
assert.Equal(t, accountRef, permissions[0].AccountRef)
mockRDB.AssertExpectations(t)
mockPDB.AssertExpectations(t)
})
}
// Security-focused test scenarios
func TestEnforcer_SecurityScenarios(t *testing.T) {
ctx := context.Background()
// Test data
accountRef := createTestObjectID()
organizationRef := createTestObjectID()
permissionRef := createTestObjectID()
objectRef := createTestObjectID()
roleRef := createTestObjectID()
t.Run("Security_DenyAlwaysWins", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock role assignment
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
// Mock multiple policies: both ALLOW and DENY
allowPolicy := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectAllow, organizationRef, permissionRef, &objectRef)
denyPolicy := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectDeny, organizationRef, permissionRef, &objectRef)
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{allowPolicy, denyPolicy}, nil)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify - DENY should always win
require.NoError(t, err)
assert.False(t, allowed)
mockRDB.AssertExpectations(t)
mockPDB.AssertExpectations(t)
})
t.Run("Security_InvalidObjectID", func(t *testing.T) {
mockPDB := &MockPoliciesDB{}
mockRDB := &MockRolesDB{}
// Mock database error for invalid ObjectID
dbError := errors.New("invalid ObjectID")
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, dbError)
enforcer := createTestEnforcer(mockPDB, mockRDB)
// Execute with invalid ObjectID
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
// Verify - should fail securely
require.Error(t, err)
assert.False(t, allowed)
mockRDB.AssertExpectations(t)
})
}
// Note: This test provides comprehensive coverage of the native enforcer including:
// 1. Basic enforcement logic with deny-takes-precedence
// 2. Batch operations for performance
// 3. Role and permission retrieval
// 4. Security scenarios and edge cases
// 5. Error handling and database failures
// 6. All critical security paths are tested

View File

@@ -0,0 +1,51 @@
package native
import (
"context"
"github.com/tech/sendico/pkg/auth/management"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/role"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.uber.org/zap"
)
// NativeManager implements the auth.Manager interface by aggregating Role and Permission managers.
type NativeManager struct {
logger mlogger.Logger
roleManager management.Role
permManager management.Permission
}
// NewManager creates a new CasbinManager with specified domains and role-domain mappings.
func NewManager(
l mlogger.Logger,
pdb policy.DB,
rdb role.DB,
enforcer *Enforcer,
) (*NativeManager, error) {
logger := l.Named("manager")
var pdesc model.PolicyDescription
if err := pdb.GetBuiltInPolicy(context.Background(), "roles", &pdesc); err != nil {
logger.Warn("Failed to fetch roles permission reference", zap.Error(err))
return nil, err
}
return &NativeManager{
logger: logger,
roleManager: NewRoleManager(logger, enforcer, pdesc.ID, rdb),
permManager: NewPermissionManager(logger, enforcer),
}, nil
}
// Permission returns the Permission manager.
func (m *NativeManager) Permission() management.Permission {
return m.permManager
}
// Role returns the Role manager.
func (m *NativeManager) Role() management.Role {
return m.roleManager
}

Binary file not shown.

View File

@@ -0,0 +1,17 @@
package nstructures
import (
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type PolicyAssignment struct {
storable.Base `bson:",inline" json:",inline"`
model.Policy `bson:"policy" json:"policy"`
RoleRef primitive.ObjectID `bson:"roleRef" json:"roleRef"`
}
func (*PolicyAssignment) Collection() string {
return "permission_assignments"
}

View File

@@ -0,0 +1,15 @@
package nstructures
import (
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/model"
)
type RoleAssignment struct {
storable.Base `bson:",inline" json:",inline"`
model.Role `bson:"role" json:"role"`
}
func (*RoleAssignment) Collection() string {
return "role_assignments"
}

View File

@@ -0,0 +1,101 @@
package native
import (
"context"
"errors"
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// PermissionManager manages permissions using Casbin.
type PermissionManager struct {
logger mlogger.Logger
enforcer *Enforcer
}
// GrantToRole adds a permission to a role in Casbin.
func (m *PermissionManager) GrantToRole(ctx context.Context, policy *model.RolePolicy) error {
objRef := "any"
if (policy.ObjectRef != nil) && (*policy.ObjectRef != primitive.NilObjectID) {
objRef = policy.ObjectRef.Hex()
}
m.logger.Debug("Granting permission to role", mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef), zap.String("object_ref", objRef),
zap.String("action", string(policy.Effect.Action)), zap.String("effect", string(policy.Effect.Effect)),
)
assignment := nstructures.PolicyAssignment{
Policy: policy.Policy,
RoleRef: policy.RoleDescriptionRef,
}
if err := m.enforcer.pdb.Create(ctx, &assignment); err != nil {
m.logger.Warn("Failed to grant policy", zap.Error(err), mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef), zap.String("object_ref", objRef),
zap.String("action", string(policy.Effect.Action)), zap.String("effect", string(policy.Effect.Effect)))
return err
}
return nil
}
// RevokeFromRole removes a permission from a role in Casbin.
func (m *PermissionManager) RevokeFromRole(ctx context.Context, policy *model.RolePolicy) error {
objRef := "*"
if policy.ObjectRef != nil {
objRef = policy.ObjectRef.Hex()
}
m.logger.Debug("Revoking permission from role", mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef), zap.String("object_ref", objRef),
zap.String("action", string(policy.Effect.Action)), zap.String("effect", string(policy.Effect.Effect)),
)
if err := m.enforcer.pdb.Remove(ctx, policy); err != nil {
m.logger.Warn("Failed to revoke policy", zap.Error(err), mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
mzap.ObjRef("permission_ref", policy.DescriptionRef), zap.String("object_ref", objRef),
zap.String("action", string(policy.Effect.Action)), zap.String("effect", string(policy.Effect.Effect)))
return err
}
return nil
}
// GetPolicies retrieves all policies for a specific role.
func (m *PermissionManager) GetPolicies(
ctx context.Context,
roleRef primitive.ObjectID,
) ([]model.RolePolicy, error) {
m.logger.Debug("Fetching policies for role", mzap.ObjRef("role_ref", roleRef))
assinments, err := m.enforcer.pdb.PoliciesForRole(ctx, roleRef)
if errors.Is(err, merrors.ErrNoData) {
m.logger.Debug("No policies found", mzap.ObjRef("role_ref", roleRef))
return []model.RolePolicy{}, nil
}
policies := make([]model.RolePolicy, len(assinments))
for i, assinment := range assinments {
policies[i] = model.RolePolicy{
Policy: assinment.Policy,
RoleDescriptionRef: assinment.RoleRef,
}
}
m.logger.Debug("Policies fetched successfully", mzap.ObjRef("role_ref", roleRef), zap.Int("count", len(policies)))
return policies, nil
}
// Save persists changes to the Casbin policy store.
func (m *PermissionManager) Save() error {
m.logger.Info("Policies successfully saved") // do nothing
return nil
}
func NewPermissionManager(logger mlogger.Logger, enforcer *Enforcer) *PermissionManager {
return &PermissionManager{
logger: logger.Named("permission"),
enforcer: enforcer,
}
}

View File

@@ -0,0 +1,142 @@
package native
import (
"context"
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
"github.com/tech/sendico/pkg/db/role"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// RoleManager manages roles using Casbin.
type RoleManager struct {
logger mlogger.Logger
enforcer *Enforcer
rdb role.DB
rolePermissionRef primitive.ObjectID
}
// NewRoleManager creates a new RoleManager.
func NewRoleManager(logger mlogger.Logger, enforcer *Enforcer, rolePermissionRef primitive.ObjectID, rdb role.DB) *RoleManager {
return &RoleManager{
logger: logger.Named("role"),
enforcer: enforcer,
rdb: rdb,
rolePermissionRef: rolePermissionRef,
}
}
// validateObjectIDs ensures that all provided ObjectIDs are non-zero.
func (rm *RoleManager) validateObjectIDs(ids ...primitive.ObjectID) error {
for _, id := range ids {
if id.IsZero() {
return merrors.InvalidArgument("Object references cannot be zero")
}
}
return nil
}
// fetchRolesFromPolicies retrieves and converts policies to roles.
func (rm *RoleManager) fetchRolesFromPolicies(roles []nstructures.RoleAssignment, organizationRef primitive.ObjectID) []model.RoleDescription {
result := make([]model.RoleDescription, len(roles))
for i, role := range roles {
result[i] = model.RoleDescription{
Base: storable.Base{ID: *role.GetID()},
OrganizationRef: organizationRef,
}
}
return result
}
// Create creates a new role in an organization.
func (rm *RoleManager) Create(ctx context.Context, organizationRef primitive.ObjectID, description *model.Describable) (*model.RoleDescription, error) {
if err := rm.validateObjectIDs(organizationRef); err != nil {
return nil, err
}
role := &model.RoleDescription{
OrganizationRef: organizationRef,
Describable: *description,
}
if err := rm.rdb.Create(ctx, role); err != nil {
rm.logger.Warn("Failed to create role", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
return nil, err
}
rm.logger.Info("Role created successfully", mzap.StorableRef(role), mzap.ObjRef("organization_ref", organizationRef))
return role, nil
}
// Assign assigns a role to a user in the given organization.
func (rm *RoleManager) Assign(ctx context.Context, role *model.Role) error {
if err := rm.validateObjectIDs(role.DescriptionRef, role.AccountRef, role.OrganizationRef); err != nil {
return err
}
assogment := nstructures.RoleAssignment{Role: *role}
err := rm.enforcer.rdb.Create(ctx, &assogment)
return rm.logPolicyResult("assign", err == nil, err, role.DescriptionRef, role.AccountRef, role.OrganizationRef)
}
// Delete removes a role entirely and cleans up associated Casbin policies.
func (rm *RoleManager) Delete(ctx context.Context, roleRef primitive.ObjectID) error {
if err := rm.validateObjectIDs(roleRef); err != nil {
rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef))
return err
}
if err := rm.rdb.Delete(ctx, roleRef); err != nil {
rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef))
return err
}
if err := rm.enforcer.rdb.DeleteRole(ctx, roleRef); err != nil {
rm.logger.Warn("Failed to remove role", zap.Error(err), mzap.ObjRef("role_ref", roleRef))
return err
}
rm.logger.Info("Role deleted successfully along with associated policies", mzap.ObjRef("role_ref", roleRef))
return nil
}
// Revoke removes a role from a user.
func (rm *RoleManager) Revoke(ctx context.Context, roleRef, accountRef, organizationRef primitive.ObjectID) error {
if err := rm.validateObjectIDs(roleRef, accountRef, organizationRef); err != nil {
return err
}
err := rm.enforcer.rdb.RemoveRole(ctx, roleRef, organizationRef, accountRef)
return rm.logPolicyResult("revoke", err == nil, err, roleRef, accountRef, organizationRef)
}
// logPolicyResult logs results for Assign and Revoke.
func (rm *RoleManager) logPolicyResult(action string, result bool, err error, roleRef, accountRef, organizationRef primitive.ObjectID) error {
if err != nil {
rm.logger.Warn("Failed to "+action+" role", zap.Error(err), mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
return err
}
msg := "Role " + action + "ed successfully"
if !result {
msg = "Role already " + action + "ed"
}
rm.logger.Info(msg, mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
return nil
}
// List retrieves all roles in an organization or all roles if organizationRef is zero.
func (rm *RoleManager) List(ctx context.Context, organizationRef primitive.ObjectID) ([]model.RoleDescription, error) {
roles4Venues, err := rm.enforcer.rdb.RolesForVenue(ctx, organizationRef)
if err != nil {
rm.logger.Warn("Failed to fetch grouping policies", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
return nil, err
}
roles := rm.fetchRolesFromPolicies(roles4Venues, organizationRef)
rm.logger.Info("Retrieved roles for organization", mzap.ObjRef("organization_ref", organizationRef), zap.Int("count", len(roles)))
return roles, nil
}

View File

@@ -0,0 +1,27 @@
package management
import (
"context"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type Permission interface {
// Grant a permission to a role with an optional object scope and specified effect.
// Use primitive.NilObjectID for 'any' objectRef.
GrantToRole(ctx context.Context, policy *model.RolePolicy) error
// Revoke a permission from a role with an optional object scope and specified effect.
// Use primitive.NilObjectID for 'any' objectRef.
RevokeFromRole(ctx context.Context, policy *model.RolePolicy) error
// Retrieve all policies assigned to a specific role, including scope and effects.
GetPolicies(
ctx context.Context,
roleRef primitive.ObjectID,
) ([]model.RolePolicy, error)
// Persist any changes made to permissions.
Save() error
}

View File

@@ -0,0 +1,41 @@
package management
import (
"context"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type Role interface {
// Create a new role in an organization (returns the created Role with its ID).
Create(
ctx context.Context,
orgRef primitive.ObjectID,
description *model.Describable,
) (*model.RoleDescription, error)
// Delete a role entirely. This will cascade and remove all associated
Delete(
ctx context.Context,
roleRef primitive.ObjectID,
) error
// Assign a role to a user in a specific organization.
Assign(
ctx context.Context,
role *model.Role,
) error
// Revoke a role from a user in a specific organization.
Revoke(
ctx context.Context,
roleRef, accountRef, orgRef primitive.ObjectID,
) error
// List all roles in an organization or globally if orgRef is primitive.NilObjectID.
List(
ctx context.Context,
orgRef primitive.ObjectID,
) ([]model.RoleDescription, error)
}

15
api/pkg/auth/manager.go Normal file
View File

@@ -0,0 +1,15 @@
package auth
import (
"github.com/tech/sendico/pkg/auth/management"
)
// Manager provides access to domain-aware Permission and Role managers.
type Manager interface {
// Permission returns a manager that handles permission grants/revokes
// for a specific resource type. (You might add domainRef here if desired.)
Permission() management.Permission
// Role returns the domain-aware Role manager.
Role() management.Role
}

14
api/pkg/auth/provider.go Normal file
View File

@@ -0,0 +1,14 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
)
type Provider interface {
Enforcer() Enforcer
Manager() Manager
GetPolicyDescription(ctx context.Context, resource mservice.Type) (*model.PolicyDescription, error)
}

43
api/pkg/auth/taggable.go Normal file
View File

@@ -0,0 +1,43 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// TaggableDB implements tag operations with permission checking
type TaggableDB[T model.PermissionBoundStorable] interface {
// AddTag adds a tag to an entity with permission checking
AddTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error
// RemoveTagd removes a tags from the collection using organizationRef with permission checking
RemoveTags(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error
// RemoveTag removes a tag from an entity with permission checking
RemoveTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error
// AddTags adds multiple tags to an entity with permission checking
AddTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error
// SetTags sets the tags for an entity with permission checking
SetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error
// RemoveAllTags removes all tags from an entity with permission checking
RemoveAllTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
// GetTags gets the tags for an entity with permission checking
GetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) ([]primitive.ObjectID, error)
// HasTag checks if an entity has a specific tag with permission checking
HasTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) (bool, error)
// FindByTag finds all entities that have a specific tag with permission checking
FindByTag(ctx context.Context, accountRef, tagRef primitive.ObjectID) ([]T, error)
// FindByTags finds all entities that have any of the specified tags with permission checking
FindByTags(ctx context.Context, accountRef primitive.ObjectID, tagRefs []primitive.ObjectID) ([]T, error)
}
// NewTaggableDBImp creates a new auth.TaggableDB instance
func NewTaggableDB[T model.PermissionBoundStorable](
dbImp *template.DBImp[T],
enforcer Enforcer,
createEmpty func() T,
getTaggable func(T) *model.Taggable,
) TaggableDB[T] {
return newTaggableDBImp(dbImp, enforcer, createEmpty, getTaggable)
}

302
api/pkg/auth/taggableimp.go Normal file
View File

@@ -0,0 +1,302 @@
package auth
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// taggableDBImp implements tag operations with permission checking
type taggableDBImp[T model.PermissionBoundStorable] struct {
dbImp *template.DBImp[T]
logger mlogger.Logger
enforcer Enforcer
createEmpty func() T
getTaggable func(T) *model.Taggable
}
func newTaggableDBImp[T model.PermissionBoundStorable](
dbImp *template.DBImp[T],
enforcer Enforcer,
createEmpty func() T,
getTaggable func(T) *model.Taggable,
) TaggableDB[T] {
return &taggableDBImp[T]{
dbImp: dbImp,
logger: dbImp.Logger.Named("taggable"),
enforcer: enforcer,
createEmpty: createEmpty,
getTaggable: getTaggable,
}
}
func (db *taggableDBImp[T]) AddTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
// Check permissions using enforceObject helper
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
return err
}
// Add the tag
patch := repository.Patch().AddToSet(repository.TagRefsField(), tagRef)
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
db.logger.Warn("Failed to add tag to object", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
return err
}
db.logger.Debug("Successfully added tag to object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
return nil
}
func (db *taggableDBImp[T]) removeTag(ctx context.Context, accountRef, targetRef, tagRef primitive.ObjectID, query builder.Query) error {
// Check permissions using enforceObject helper
if err := enforceObject(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, query); err != nil {
db.logger.Debug("Error enforcing permissions for removing tag", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("target_ref", targetRef), mzap.ObjRef("tag_ref", tagRef))
return err
}
// Remove the tag
patch := repository.Patch().Pull(repository.TagRefsField(), tagRef)
patched, err := db.dbImp.PatchMany(ctx, query, patch)
if err != nil {
db.logger.Warn("Failed to remove tag from object", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("target_ref", targetRef), mzap.ObjRef("tag_ref", tagRef))
return err
}
db.logger.Debug("Successfully removed tag from object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("target_ref", targetRef), mzap.ObjRef("tag_ref", tagRef), zap.Int("patched_count", patched))
return nil
}
func (db *taggableDBImp[T]) RemoveTags(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error {
return db.removeTag(ctx, accountRef, primitive.NilObjectID, tagRef, repository.OrgFilter(organizationRef))
}
func (db *taggableDBImp[T]) RemoveTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
return db.removeTag(ctx, accountRef, objectRef, tagRef, repository.IDFilter(objectRef))
}
// AddTags adds multiple tags to an entity with permission checking
func (db *taggableDBImp[T]) AddTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
// Check permissions using enforceObject helper
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
return err
}
// Add the tags one by one using $addToSet to avoid duplicates
for _, tagRef := range tagRefs {
patch := repository.Patch().AddToSet(repository.TagRefsField(), tagRef)
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
db.logger.Warn("Failed to add tag to object", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
return err
}
}
db.logger.Debug("Successfully added tags to object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("tag_count", len(tagRefs)))
return nil
}
// SetTags sets the tags for an entity with permission checking
func (db *taggableDBImp[T]) SetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
// Check permissions using enforceObject helper
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
return err
}
// Set the tags
patch := repository.Patch().Set(repository.TagRefsField(), tagRefs)
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
db.logger.Warn("Failed to set tags for object", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return err
}
db.logger.Debug("Successfully set tags for object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("tag_count", len(tagRefs)))
return nil
}
// RemoveAllTags removes all tags from an entity with permission checking
func (db *taggableDBImp[T]) RemoveAllTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
// Check permissions using enforceObject helper
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
return err
}
// Remove all tags by setting to empty array
patch := repository.Patch().Set(repository.TagRefsField(), []primitive.ObjectID{})
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
db.logger.Warn("Failed to remove all tags from object", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return err
}
db.logger.Debug("Successfully removed all tags from object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef))
return nil
}
// GetTags gets the tags for an entity with permission checking
func (db *taggableDBImp[T]) GetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) ([]primitive.ObjectID, error) {
// Check permissions using enforceObject helper
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionRead, accountRef, objectRef); err != nil {
return nil, err
}
// Get the object and extract tags
obj := db.createEmpty()
if err := db.dbImp.Get(ctx, objectRef, obj); err != nil {
db.logger.Warn("Failed to get object for retrieving tags", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
return nil, err
}
// Get the tags
taggable := db.getTaggable(obj)
db.logger.Debug("Successfully retrieved tags for object", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), zap.Int("tag_count", len(taggable.TagRefs)))
return taggable.TagRefs, nil
}
// HasTag checks if an entity has a specific tag with permission checking
func (db *taggableDBImp[T]) HasTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) (bool, error) {
// Check permissions using enforceObject helper
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionRead, accountRef, objectRef); err != nil {
return false, err
}
// Get the object and check if the tag exists
obj := db.createEmpty()
if err := db.dbImp.Get(ctx, objectRef, obj); err != nil {
db.logger.Warn("Failed to get object for checking tag", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
return false, err
}
// Check if the tag exists
taggable := db.getTaggable(obj)
for _, existingTag := range taggable.TagRefs {
if existingTag == tagRef {
db.logger.Debug("Object has tag", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
return true, nil
}
}
db.logger.Debug("Object does not have tag", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
return false, nil
}
// FindByTag finds all entities that have a specific tag with permission checking
func (db *taggableDBImp[T]) FindByTag(ctx context.Context, accountRef, tagRef primitive.ObjectID) ([]T, error) {
// Create filter to find objects with the tag
filter := repository.Filter(model.TagRefsField, tagRef)
// Get all objects with the tag using ListPermissionBound
objects, err := db.dbImp.ListPermissionBound(ctx, filter)
if err != nil {
db.logger.Warn("Failed to get objects with tag", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("tag_ref", tagRef))
return nil, err
}
// Check permissions for all objects using EnforceBatch
db.logger.Debug("Checking permissions for objects with tag", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("tag_ref", tagRef), zap.Int("object_count", len(objects)))
permissions, err := db.enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
if err != nil {
db.logger.Warn("Failed to check permissions for objects with tag", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("tag_ref", tagRef), zap.Int("object_count", len(objects)))
return nil, merrors.Internal("failed to check permissions for objects with tag")
}
// Filter objects based on permissions and decode them
var results []T
for _, obj := range objects {
objID := *obj.GetID()
if hasPermission, exists := permissions[objID]; exists && hasPermission {
// Decode the object
decodedObj := db.createEmpty()
if err := db.dbImp.Get(ctx, objID, decodedObj); err != nil {
db.logger.Warn("Failed to decode object with tag", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objID), mzap.ObjRef("tag_ref", tagRef))
continue
}
results = append(results, decodedObj)
}
}
db.logger.Debug("Successfully found objects with tag", mzap.ObjRef("account_ref", accountRef),
mzap.ObjRef("tag_ref", tagRef), zap.Int("total_objects", len(objects)), zap.Int("accessible_objects", len(results)))
return results, nil
}
// FindByTags finds all entities that have any of the specified tags with permission checking
func (db *taggableDBImp[T]) FindByTags(ctx context.Context, accountRef primitive.ObjectID, tagRefs []primitive.ObjectID) ([]T, error) {
if len(tagRefs) == 0 {
return []T{}, nil
}
// Convert []primitive.ObjectID to []any for the In method
values := make([]any, len(tagRefs))
for i, tagRef := range tagRefs {
values[i] = tagRef
}
// Create filter to find objects with any of the tags
filter := repository.Query().In(repository.TagRefsField(), values...)
// Get all objects with any of the tags using ListPermissionBound
objects, err := db.dbImp.ListPermissionBound(ctx, filter)
if err != nil {
db.logger.Warn("Failed to get objects with tags", zap.Error(err),
mzap.ObjRef("account_ref", accountRef))
return nil, err
}
// Check permissions for all objects using EnforceBatch
db.logger.Debug("Checking permissions for objects with tags", mzap.ObjRef("account_ref", accountRef),
zap.Int("object_count", len(objects)), zap.Int("tag_count", len(tagRefs)))
permissions, err := db.enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
if err != nil {
db.logger.Warn("Failed to check permissions for objects with tags", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), zap.Int("object_count", len(objects)))
return nil, merrors.Internal("failed to check permissions for objects with tags")
}
// Filter objects based on permissions and decode them
var results []T
for _, obj := range objects {
objID := *obj.GetID()
if hasPermission, exists := permissions[objID]; exists && hasPermission {
// Decode the object
decodedObj := db.createEmpty()
if err := db.dbImp.Get(ctx, objID, decodedObj); err != nil {
db.logger.Warn("Failed to decode object with tags", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objID))
continue
}
results = append(results, decodedObj)
}
}
db.logger.Debug("Successfully found objects with tags", mzap.ObjRef("account_ref", accountRef),
zap.Int("total_objects", len(objects)), zap.Int("accessible_objects", len(results)), zap.Int("tag_count", len(tagRefs)))
return results, nil
}

21
api/pkg/clock/clock.go Normal file
View File

@@ -0,0 +1,21 @@
package clock
import "time"
// Clock exposes basic time operations, primarily for test overrides.
type Clock interface {
Now() time.Time
}
// System implements Clock using the system wall clock.
type System struct{}
// Now returns the current UTC time.
func (System) Now() time.Time {
return time.Now().UTC()
}
// NewSystem returns a system-backed clock instance.
func NewSystem() Clock {
return System{}
}

17
api/pkg/db/account/account.go Executable file
View File

@@ -0,0 +1,17 @@
package account
import (
"context"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// DB is the interface which must be implemented by all db drivers
type DB interface {
template.DB[*model.Account]
GetByEmail(ctx context.Context, email string) (*model.Account, error)
GetByToken(ctx context.Context, email string) (*model.Account, error)
GetAccountsByRefs(ctx context.Context, orgRef primitive.ObjectID, refs []primitive.ObjectID) ([]model.Account, error)
}

11
api/pkg/db/config.go Normal file
View File

@@ -0,0 +1,11 @@
package db
import "github.com/tech/sendico/pkg/model"
type DBDriver string
const (
Mongo DBDriver = "mongodb"
)
type Config = model.DriverConfig[DBDriver]

65
api/pkg/db/connection.go Normal file
View File

@@ -0,0 +1,65 @@
package db
import (
"context"
mongoimpl "github.com/tech/sendico/pkg/db/internal/mongo"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
mongoDriver "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
// Connection represents a low-level database connection lifecycle.
type Connection interface {
Disconnect(ctx context.Context) error
Ping(ctx context.Context) error
}
// MongoConnection provides direct access to the underlying mongo client.
type MongoConnection struct {
client *mongoDriver.Client
database string
}
func (c *MongoConnection) Client() *mongoDriver.Client {
return c.client
}
func (c *MongoConnection) Database() *mongoDriver.Database {
return c.client.Database(c.database)
}
func (c *MongoConnection) Disconnect(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
return c.client.Disconnect(ctx)
}
func (c *MongoConnection) Ping(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
return c.client.Ping(ctx, readpref.Primary())
}
// ConnectMongo returns a low-level MongoDB connection without constructing repositories.
func ConnectMongo(logger mlogger.Logger, config *Config) (*MongoConnection, error) {
if config == nil {
return nil, merrors.InvalidArgument("database configuration is nil")
}
if config.Driver != Mongo {
return nil, merrors.InvalidArgument("unsupported database driver: " + string(config.Driver))
}
client, _, settings, err := mongoimpl.ConnectClient(logger, config.Settings)
if err != nil {
return nil, err
}
return &MongoConnection{
client: client,
database: settings.Database,
}, nil
}

41
api/pkg/db/factory.go Normal file
View File

@@ -0,0 +1,41 @@
package db
import (
"github.com/tech/sendico/pkg/auth"
"github.com/tech/sendico/pkg/db/account"
mongoimpl "github.com/tech/sendico/pkg/db/internal/mongo"
"github.com/tech/sendico/pkg/db/invitation"
"github.com/tech/sendico/pkg/db/organization"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/refreshtokens"
"github.com/tech/sendico/pkg/db/role"
"github.com/tech/sendico/pkg/db/transaction"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger"
)
// Factory exposes high-level repositories used by application services.
type Factory interface {
NewRefreshTokensDB() (refreshtokens.DB, error)
NewAccountDB() (account.DB, error)
NewOrganizationDB() (organization.DB, error)
NewInvitationsDB() (invitation.DB, error)
NewRolesDB() (role.DB, error)
NewPoliciesDB() (policy.DB, error)
TransactionFactory() transaction.Factory
Permissions() auth.Provider
CloseConnection()
}
// NewConnection builds a Factory backed by the configured driver.
func NewConnection(logger mlogger.Logger, config *Config) (Factory, error) {
if config.Driver == Mongo {
return mongoimpl.NewConnection(logger, config.Settings)
}
return nil, merrors.InvalidArgument("unknown database driver: " + string(config.Driver))
}

View File

@@ -0,0 +1,12 @@
package indexable
import (
"context"
"github.com/tech/sendico/pkg/db/repository/builder"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type DB interface {
Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
}

View File

@@ -0,0 +1,30 @@
package accountdb
import (
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type AccountDB struct {
template.DBImp[*model.Account]
}
func Create(logger mlogger.Logger, db *mongo.Database) (*AccountDB, error) {
p := &AccountDB{
DBImp: *template.Create[*model.Account](logger, mservice.Accounts, db),
}
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: "login", Sort: ri.Asc}},
Unique: true,
}); err != nil {
p.Logger.Error("Failed to create account database", zap.Error(err))
return nil, err
}
return p, nil
}

View File

@@ -0,0 +1,13 @@
package accountdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
)
func (db *AccountDB) GetByToken(ctx context.Context, email string) (*model.Account, error) {
var account model.Account
return &account, db.FindOne(ctx, repository.Query().Filter(repository.Field("verifyToken"), email), &account)
}

View File

@@ -0,0 +1,21 @@
package accountdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *AccountDB) GetAccountsByRefs(ctx context.Context, orgRef primitive.ObjectID, refs []primitive.ObjectID) ([]model.Account, error) {
filter := repository.Query().Comparison(repository.IDField(), builder.In, refs)
return mutil.GetObjects[model.Account](ctx, db.Logger, filter, nil, db.Repository)
}
func (db *AccountDB) GetByEmail(ctx context.Context, email string) (*model.Account, error) {
var account model.Account
return &account, db.FindOne(ctx, repository.Filter("login", email), &account)
}

View File

@@ -0,0 +1,99 @@
package archivable
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// ArchivableDB implements archive management for entities with model.Archivable embedded
type ArchivableDB[T storable.Storable] struct {
repo repository.Repository
logger mlogger.Logger
createEmpty func() T
getArchivable func(T) model.Archivable
}
// NewArchivableDB creates a new ArchivableDB instance
func NewArchivableDB[T storable.Storable](
repo repository.Repository,
logger mlogger.Logger,
createEmpty func() T,
getArchivable func(T) model.Archivable,
) *ArchivableDB[T] {
return &ArchivableDB[T]{
repo: repo,
logger: logger,
createEmpty: createEmpty,
getArchivable: getArchivable,
}
}
// SetArchived sets the archived status of an entity
func (db *ArchivableDB[T]) SetArchived(ctx context.Context, objectRef primitive.ObjectID, archived bool) error {
// Get current object to check current archived status
obj := db.createEmpty()
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
db.logger.Warn("Failed to get object for setting archived status",
zap.Error(err),
mzap.ObjRef("object_ref", objectRef),
zap.Bool("archived", archived))
return err
}
// Extract archivable from the object
archivable := db.getArchivable(obj)
currentArchived := archivable.IsArchived()
if currentArchived == archived {
db.logger.Debug("No change needed - same archived status",
mzap.ObjRef("object_ref", objectRef),
zap.Bool("archived", archived))
return nil // No change needed
}
// Set the archived status
patch := repository.Patch().Set(repository.IsArchivedField(), archived)
if err := db.repo.Patch(ctx, objectRef, patch); err != nil {
db.logger.Warn("Failed to set archived status on object",
zap.Error(err),
mzap.ObjRef("object_ref", objectRef),
zap.Bool("archived", archived))
return err
}
db.logger.Debug("Successfully set archived status on object",
mzap.ObjRef("object_ref", objectRef),
zap.Bool("archived", archived))
return nil
}
// IsArchived checks if an entity is archived
func (db *ArchivableDB[T]) IsArchived(ctx context.Context, objectRef primitive.ObjectID) (bool, error) {
obj := db.createEmpty()
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
db.logger.Warn("Failed to get object for checking archived status",
zap.Error(err),
mzap.ObjRef("object_ref", objectRef))
return false, err
}
archivable := db.getArchivable(obj)
return archivable.IsArchived(), nil
}
// Archive archives an entity (sets archived to true)
func (db *ArchivableDB[T]) Archive(ctx context.Context, objectRef primitive.ObjectID) error {
return db.SetArchived(ctx, objectRef, true)
}
// Unarchive unarchives an entity (sets archived to false)
func (db *ArchivableDB[T]) Unarchive(ctx context.Context, objectRef primitive.ObjectID) error {
return db.SetArchived(ctx, objectRef, false)
}

View File

@@ -0,0 +1,175 @@
//go:build integration
// +build integration
package archivable
import (
"context"
"testing"
"time"
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.uber.org/zap"
)
// TestArchivableObject represents a test object with archivable functionality
type TestArchivableObject struct {
storable.Base `bson:",inline" json:",inline"`
model.ArchivableBase `bson:",inline" json:",inline"`
Name string `bson:"name" json:"name"`
}
func (t *TestArchivableObject) Collection() string {
return "testArchivableObject"
}
func (t *TestArchivableObject) GetArchivable() model.Archivable {
return &t.ArchivableBase
}
func TestArchivableDB(t *testing.T) {
ctx := context.Background()
// Start MongoDB container (stable)
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("test"),
mongodb.WithPassword("test"),
testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
)
require.NoError(t, err)
defer func() {
termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := mongoContainer.Terminate(termCtx); err != nil {
t.Logf("Failed to terminate container: %v", err)
}
}()
// Get MongoDB connection string
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err)
// Connect to MongoDB
client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI))
require.NoError(t, err)
defer func() {
if err := client.Disconnect(context.Background()); err != nil {
t.Logf("Failed to disconnect from MongoDB: %v", err)
}
}()
// Ping the database
err = client.Ping(ctx, nil)
require.NoError(t, err)
// Create repository
repo := repositoryimp.NewMongoRepository(client.Database("test_"+t.Name()), "testArchivableCollection")
// Create archivable DB
archivableDB := NewArchivableDB(
repo,
zap.NewNop(),
func() *TestArchivableObject { return &TestArchivableObject{} },
func(obj *TestArchivableObject) model.Archivable { return obj.GetArchivable() },
)
t.Run("SetArchived_Success", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.SetArchived(ctx, obj.ID, true)
require.NoError(t, err)
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.True(t, result.IsArchived())
})
t.Run("SetArchived_NoChange", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.SetArchived(ctx, obj.ID, true)
require.NoError(t, err) // Should not error, just not change anything
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.True(t, result.IsArchived())
})
t.Run("SetArchived_Unarchive", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.SetArchived(ctx, obj.ID, false)
require.NoError(t, err)
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.False(t, result.IsArchived())
})
t.Run("IsArchived_True", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
require.NoError(t, err)
assert.True(t, isArchived)
})
t.Run("IsArchived_False", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
require.NoError(t, err)
assert.False(t, isArchived)
})
t.Run("Archive_Success", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.Archive(ctx, obj.ID)
require.NoError(t, err)
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.True(t, result.IsArchived())
})
t.Run("Unarchive_Success", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.Unarchive(ctx, obj.ID)
require.NoError(t, err)
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.False(t, result.IsArchived())
})
}

257
api/pkg/db/internal/mongo/db.go Executable file
View File

@@ -0,0 +1,257 @@
package mongo
import (
"context"
"os"
"github.com/mitchellh/mapstructure"
"github.com/tech/sendico/pkg/auth"
"github.com/tech/sendico/pkg/db/account"
"github.com/tech/sendico/pkg/db/internal/mongo/accountdb"
"github.com/tech/sendico/pkg/db/internal/mongo/invitationdb"
"github.com/tech/sendico/pkg/db/internal/mongo/organizationdb"
"github.com/tech/sendico/pkg/db/internal/mongo/policiesdb"
"github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
"github.com/tech/sendico/pkg/db/internal/mongo/rolesdb"
"github.com/tech/sendico/pkg/db/internal/mongo/transactionimp"
"github.com/tech/sendico/pkg/db/invitation"
"github.com/tech/sendico/pkg/db/organization"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/refreshtokens"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/role"
"github.com/tech/sendico/pkg/db/transaction"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
mutil "github.com/tech/sendico/pkg/mutil/config"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.uber.org/zap"
)
// Config represents configuration
type Config struct {
Port *string `mapstructure:"port"`
PortEnv *string `mapstructure:"port_env"`
User *string `mapstructure:"user"`
UserEnv *string `mapstructure:"user_env"`
PasswordEnv string `mapstructure:"password_env"`
Database *string `mapstructure:"database"`
DatabaseEnv *string `mapstructure:"database_env"`
Host *string `mapstructure:"host"`
HostEnv *string `mapstructure:"host_env"`
AuthSource *string `mapstructure:"auth_source,omitempty"`
AuthSourceEnv *string `mapstructure:"auth_source_env,omitempty"`
AuthMechanism *string `mapstructure:"auth_mechanism,omitempty"`
AuthMechanismEnv *string `mapstructure:"auth_mechanism_env,omitempty"`
ReplicaSet *string `mapstructure:"replica_set,omitempty"`
ReplicaSetEnv *string `mapstructure:"replica_set_env,omitempty"`
Enforcer *auth.Config `mapstructure:"enforcer"`
}
type DBSettings struct {
Host string
Port string
User string
Password string
Database string
AuthSource string
AuthMechanism string
ReplicaSet string
}
func newProtectedDB[T any](
db *DB,
create func(ctx context.Context, logger mlogger.Logger, enforcer auth.Enforcer, pdb policy.DB, client *mongo.Database) (T, error),
) (T, error) {
pdb, err := db.NewPoliciesDB()
if err != nil {
db.logger.Warn("Failed to create policies database", zap.Error(err))
var zero T
return zero, err
}
return create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
}
func Config2DBSettings(logger mlogger.Logger, config *Config) *DBSettings {
p := new(DBSettings)
p.Port = mutil.GetConfigValue(logger, "port", "port_env", config.Port, config.PortEnv)
p.Database = mutil.GetConfigValue(logger, "database", "database_env", config.Database, config.DatabaseEnv)
p.Password = os.Getenv(config.PasswordEnv)
p.User = mutil.GetConfigValue(logger, "user", "user_env", config.User, config.UserEnv)
p.Host = mutil.GetConfigValue(logger, "host", "host_env", config.Host, config.HostEnv)
p.AuthSource = mutil.GetConfigValue(logger, "auth_source", "auth_source_env", config.AuthSource, config.AuthSourceEnv)
p.AuthMechanism = mutil.GetConfigValue(logger, "auth_mechanism", "auth_mechanism_env", config.AuthMechanism, config.AuthMechanismEnv)
p.ReplicaSet = mutil.GetConfigValue(logger, "replica_set", "replica_set_env", config.ReplicaSet, config.ReplicaSetEnv)
return p
}
func decodeConfig(logger mlogger.Logger, settings model.SettingsT) (*Config, *DBSettings, error) {
var config Config
if err := mapstructure.Decode(settings, &config); err != nil {
logger.Warn("Failed to decode settings", zap.Error(err), zap.Any("settings", settings))
return nil, nil, err
}
dbSettings := Config2DBSettings(logger, &config)
return &config, dbSettings, nil
}
func dialMongo(logger mlogger.Logger, dbSettings *DBSettings) (*mongo.Client, error) {
cred := options.Credential{
AuthMechanism: dbSettings.AuthMechanism,
AuthSource: dbSettings.AuthSource,
Username: dbSettings.User,
Password: dbSettings.Password,
}
dbURI := buildURI(dbSettings)
client, err := mongo.Connect(context.Background(), options.Client().ApplyURI(dbURI).SetAuth(cred))
if err != nil {
logger.Error("Unable to connect to database", zap.Error(err))
return nil, err
}
logger.Info("Connected successfully", zap.String("uri", dbURI))
if err := client.Ping(context.Background(), readpref.Primary()); err != nil {
logger.Error("Unable to ping database", zap.Error(err))
_ = client.Disconnect(context.Background())
return nil, err
}
return client, nil
}
func ConnectClient(logger mlogger.Logger, settings model.SettingsT) (*mongo.Client, *Config, *DBSettings, error) {
config, dbSettings, err := decodeConfig(logger, settings)
if err != nil {
return nil, nil, nil, err
}
client, err := dialMongo(logger, dbSettings)
if err != nil {
return nil, nil, nil, err
}
return client, config, dbSettings, nil
}
// DB represents the structure of the database
type DB struct {
logger mlogger.Logger
config *DBSettings
client *mongo.Client
enforcer auth.Enforcer
manager auth.Manager
pdb policy.DB
}
func (db *DB) db() *mongo.Database {
return db.client.Database(db.config.Database)
}
func (db *DB) NewAccountDB() (account.DB, error) {
return accountdb.Create(db.logger, db.db())
}
func (db *DB) NewOrganizationDB() (organization.DB, error) {
pdb, err := db.NewPoliciesDB()
if err != nil {
db.logger.Warn("Failed to create policies database", zap.Error(err))
return nil, err
}
organizationDB, err := organizationdb.Create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
if err != nil {
return nil, err
}
// Return the concrete type - interface mismatch will be handled at runtime
// TODO: Update organization.DB interface to match implementation signatures
return organizationDB, nil
}
func (db *DB) NewRefreshTokensDB() (refreshtokens.DB, error) {
return refreshtokensdb.Create(db.logger, db.db())
}
func (db *DB) NewInvitationsDB() (invitation.DB, error) {
return newProtectedDB(db, invitationdb.Create)
}
func (db *DB) NewPoliciesDB() (policy.DB, error) {
return db.pdb, nil
}
func (db *DB) NewRolesDB() (role.DB, error) {
return rolesdb.Create(db.logger, db.db())
}
func (db *DB) TransactionFactory() transaction.Factory {
return transactionimp.CreateFactory(db.client)
}
func (db *DB) Permissions() auth.Provider {
return db
}
func (db *DB) Manager() auth.Manager {
return db.manager
}
func (db *DB) Enforcer() auth.Enforcer {
return db.enforcer
}
func (db *DB) GetPolicyDescription(ctx context.Context, resource mservice.Type) (*model.PolicyDescription, error) {
var policyDescription model.PolicyDescription
return &policyDescription, db.pdb.FindOne(ctx, repository.Filter("resourceTypes", resource), &policyDescription)
}
func (db *DB) CloseConnection() {
if err := db.client.Disconnect(context.Background()); err != nil {
db.logger.Warn("Failed to close connection", zap.Error(err))
}
db.logger.Info("Database connection closed")
}
// NewConnection creates a new database connection
func NewConnection(logger mlogger.Logger, settings model.SettingsT) (*DB, error) {
client, config, dbSettings, err := ConnectClient(logger, settings)
if err != nil {
return nil, err
}
db := &DB{
logger: logger.Named("db"),
config: dbSettings,
client: client,
}
cleanup := func(ctx context.Context) {
if err := client.Disconnect(ctx); err != nil {
logger.Warn("Failed to close MongoDB connection", zap.Error(err))
}
}
rdb, err := db.NewRolesDB()
if err != nil {
db.logger.Warn("Failed to create roles database", zap.Error(err))
cleanup(context.Background())
return nil, err
}
if db.pdb, err = policiesdb.Create(db.logger, db.db()); err != nil {
db.logger.Warn("Failed to create policies database", zap.Error(err))
cleanup(context.Background())
return nil, err
}
if db.enforcer, db.manager, err = auth.CreateAuth(logger, db.client, db.db(), db.pdb, rdb, config.Enforcer); err != nil {
db.logger.Warn("Failed to create permissions enforcer", zap.Error(err))
cleanup(context.Background())
return nil, err
}
return db, nil
}

View File

@@ -0,0 +1,144 @@
# Indexable Implementation (Refactored)
## Overview
This package provides a refactored implementation of the `indexable.DB` interface that uses `mutil.GetObjects` for better consistency with the existing codebase. The implementation has been moved to the mongo folder and includes a factory for project indexable in the pkg/db folder.
## Structure
### 1. `api/pkg/db/internal/mongo/indexable/indexable.go`
- **`ReorderTemplate[T]`**: Generic template function that uses `mutil.GetObjects` for fetching objects
- **`IndexableDB`**: Base struct for creating concrete implementations
- **Type-safe implementation**: Uses Go generics with proper type constraints
### 2. `api/pkg/db/project_indexable.go`
- **`ProjectIndexableDB`**: Factory implementation for Project objects
- **`NewProjectIndexableDB`**: Constructor function
- **`ReorderTemplate`**: Duplicate of the mongo version for convenience
## Key Changes from Previous Implementation
### 1. **Uses `mutil.GetObjects`**
```go
// Old implementation (manual cursor handling)
err = repo.FindManyByFilter(ctx, filter, func(cursor *mongo.Cursor) error {
var obj T
if err := cursor.Decode(&obj); err != nil {
return err
}
objects = append(objects, obj)
return nil
})
// New implementation (using mutil.GetObjects)
objects, err := mutil.GetObjects[T](
ctx,
logger,
filterFunc().
And(
repository.IndexOpFilter(minIdx, builder.Gte),
repository.IndexOpFilter(maxIdx, builder.Lte),
),
nil, nil, nil, // limit, offset, isArchived
repo,
)
```
### 2. **Moved to Mongo Folder**
- Location: `api/pkg/db/internal/mongo/indexable/`
- Consistent with other mongo implementations
- Better organization within the codebase
### 3. **Added Factory in pkg/db**
- Location: `api/pkg/db/project_indexable.go`
- Provides easy access to project indexable functionality
- Includes logger parameter for better error handling
## Usage
### Using the Factory (Recommended)
```go
import "github.com/tech/sendico/pkg/db"
// Create a project indexable DB
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)
// Reorder a project
err := projectDB.Reorder(ctx, projectID, newIndex)
if err != nil {
// Handle error
}
```
### Using the Template Directly
```go
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"
// Define helper functions
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
updateIndexable := func(p *model.Project, newIndex int) {
p.Index = newIndex
}
createEmpty := func() *model.Project {
return &model.Project{}
}
filterFunc := func() builder.Query {
return repository.OrgFilter(organizationRef)
}
// Use the template
err := indexable.ReorderTemplate(
ctx,
logger,
repo,
objectRef,
newIndex,
filterFunc,
getIndexable,
updateIndexable,
createEmpty,
)
```
## Benefits of Refactoring
1. **Consistency**: Uses `mutil.GetObjects` like other parts of the codebase
2. **Better Error Handling**: Includes logger parameter for proper error logging
3. **Organization**: Moved to appropriate folder structure
4. **Factory Pattern**: Easy-to-use factory for common use cases
5. **Type Safety**: Maintains compile-time type checking
6. **Performance**: Leverages existing optimized `mutil.GetObjects` implementation
## Testing
### Mongo Implementation Tests
```bash
go test ./db/internal/mongo/indexable -v
```
### Factory Tests
```bash
go test ./db -v
```
## Integration
The refactored implementation is ready for integration with existing project reordering APIs. The factory pattern makes it easy to add reordering functionality to any service that needs to reorder projects within an organization.
## Migration from Old Implementation
If you were using the old implementation:
1. **Update imports**: Change from `api/pkg/db/internal/indexable` to `api/pkg/db`
2. **Use factory**: Replace manual template usage with `NewProjectIndexableDB`
3. **Add logger**: Include a logger parameter in your constructor calls
4. **Update tests**: Use the new test structure if needed
The API remains the same, so existing code should work with minimal changes.

View File

@@ -0,0 +1,174 @@
# Indexable Usage Guide
## Generic Implementation for Any Indexable Struct
The implementation is now **generic** and supports **any struct that embeds `model.Indexable`**!
- **Interface**: `api/pkg/db/indexable.go` - defines the contract
- **Implementation**: `api/pkg/db/internal/mongo/indexable/` - generic implementation
- **Factory**: `api/pkg/db/project_indexable.go` - convenient factory for projects
## Usage
### 1. Using the Generic Implementation Directly
```go
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"
// For any type that embeds model.Indexable, define helper functions:
createEmpty := func() *YourType {
return &YourType{}
}
getIndexable := func(obj *YourType) *model.Indexable {
return &obj.Indexable
}
// Create generic IndexableDB
indexableDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Use with single filter parameter
err := indexableDB.Reorder(ctx, objectID, newIndex, filter)
```
### 2. Using the Project Factory (Recommended for Projects)
```go
import "github.com/tech/sendico/pkg/db"
// Create project indexable DB (automatically applies org filter)
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)
// Reorder project (org filter applied automatically)
err := projectDB.Reorder(ctx, projectID, newIndex, repository.Query())
// Reorder with additional filters (combined with org filter)
additionalFilter := repository.Query().Comparison(repository.Field("state"), builder.Eq, "active")
err := projectDB.Reorder(ctx, projectID, newIndex, additionalFilter)
```
## Examples for Different Types
### Project IndexableDB
```go
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
projectDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
orgFilter := repository.OrgFilter(organizationRef)
projectDB.Reorder(ctx, projectID, 2, orgFilter)
```
### Status IndexableDB
```go
createEmpty := func() *model.Status {
return &model.Status{}
}
getIndexable := func(s *model.Status) *model.Indexable {
return &s.Indexable
}
statusDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
projectFilter := repository.Query().Comparison(repository.Field("projectRef"), builder.Eq, projectRef)
statusDB.Reorder(ctx, statusID, 1, projectFilter)
```
### Task IndexableDB
```go
createEmpty := func() *model.Task {
return &model.Task{}
}
getIndexable := func(t *model.Task) *model.Indexable {
return &t.Indexable
}
taskDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
taskDB.Reorder(ctx, taskID, 3, statusFilter)
```
### Priority IndexableDB
```go
createEmpty := func() *model.Priority {
return &model.Priority{}
}
getIndexable := func(p *model.Priority) *model.Indexable {
return &p.Indexable
}
priorityDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
orgFilter := repository.OrgFilter(organizationRef)
priorityDB.Reorder(ctx, priorityID, 0, orgFilter)
```
### Global Reordering (No Filter)
```go
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
globalDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Reorders all items globally (empty filter)
globalDB.Reorder(ctx, objectID, 5, repository.Query())
```
## Key Features
### ✅ **Generic Support**
- Works with **any struct** that embeds `model.Indexable`
- Type-safe with compile-time checking
- No hardcoded types
### ✅ **Single Filter Parameter**
- **Simple**: Single `builder.Query` parameter instead of variadic `interface{}`
- **Flexible**: Can incorporate any combination of filters
- **Type-safe**: No runtime type assertions needed
### ✅ **Clean Architecture**
- Interface separated from implementation
- Generic implementation in internal package
- Easy-to-use factories for common types
## How It Works
### Generic Algorithm
1. **Get current index** using type-specific helper function
2. **If no change needed** → return early
3. **Apply filter** to scope affected items
4. **Shift affected items** using `PatchMany` with `$inc`
5. **Update target object** using `Patch` with `$set`
### Type-Safe Implementation
```go
type IndexableDB[T storable.Storable] struct {
repo repository.Repository
logger mlogger.Logger
createEmpty func() T
getIndexable func(T) *model.Indexable
}
// Single filter parameter - clean and simple
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
```
## Benefits
**Generic** - Works with any Indexable struct
**Type Safe** - Compile-time type checking
**Simple** - Single filter parameter instead of variadic interface{}
**Efficient** - Uses patches, not full updates
**Clean** - Interface separated from implementation
That's it! **Generic, type-safe, and simple** reordering for any Indexable struct with a single filter parameter.

View File

@@ -0,0 +1,69 @@
package indexable
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Example usage of the generic IndexableDB with different types
// Example 1: Using with Project
func ExampleProjectIndexableDB(repo repository.Repository, logger mlogger.Logger, organizationRef primitive.ObjectID) {
// Define helper functions for Project
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
// Create generic IndexableDB for Project
projectDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Use with organization filter
orgFilter := repository.OrgFilter(organizationRef)
projectDB.Reorder(context.Background(), primitive.NewObjectID(), 2, orgFilter)
}
// Example 3: Using with Task
func ExampleTaskIndexableDB(repo repository.Repository, logger mlogger.Logger, statusRef primitive.ObjectID) {
// Define helper functions for Task
createEmpty := func() *model.Task {
return &model.Task{}
}
getIndexable := func(t *model.Task) *model.Indexable {
return &t.Indexable
}
// Create generic IndexableDB for Task
taskDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Use with status filter
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
taskDB.Reorder(context.Background(), primitive.NewObjectID(), 3, statusFilter)
}
// Example 5: Using without any filter (global reordering)
func ExampleGlobalIndexableDB(repo repository.Repository, logger mlogger.Logger) {
// Define helper functions for any Indexable type
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
// Create generic IndexableDB without filters
globalDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Use without any filter - reorders all items globally
globalDB.Reorder(context.Background(), primitive.NewObjectID(), 5, repository.Query())
}

View File

@@ -0,0 +1,122 @@
package indexable
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// IndexableDB implements db.IndexableDB interface with generic support
type IndexableDB[T storable.Storable] struct {
repo repository.Repository
logger mlogger.Logger
createEmpty func() T
getIndexable func(T) *model.Indexable
}
// NewIndexableDB creates a new IndexableDB instance
func NewIndexableDB[T storable.Storable](
repo repository.Repository,
logger mlogger.Logger,
createEmpty func() T,
getIndexable func(T) *model.Indexable,
) *IndexableDB[T] {
return &IndexableDB[T]{
repo: repo,
logger: logger,
createEmpty: createEmpty,
getIndexable: getIndexable,
}
}
// Reorder implements the db.IndexableDB interface with single filter parameter
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
// Get current object to find its index
obj := db.createEmpty()
err := db.repo.Get(ctx, objectRef, obj)
if err != nil {
db.logger.Error("Failed to get object for reordering",
zap.Error(err),
zap.String("object_ref", objectRef.Hex()),
zap.Int("new_index", newIndex))
return err
}
// Extract index from the object
indexable := db.getIndexable(obj)
currentIndex := indexable.Index
if currentIndex == newIndex {
db.logger.Debug("No reordering needed - same index",
zap.String("object_ref", objectRef.Hex()),
zap.Int("current_index", currentIndex),
zap.Int("new_index", newIndex))
return nil // No change needed
}
// Simple reordering logic
if currentIndex < newIndex {
// Moving down: shift items between currentIndex+1 and newIndex up by -1
patch := repository.Patch().Inc(repository.IndexField(), -1)
reorderFilter := filter.
And(repository.IndexOpFilter(currentIndex+1, builder.Gte)).
And(repository.IndexOpFilter(newIndex, builder.Lte))
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
if err != nil {
db.logger.Error("Failed to shift objects during reordering (moving down)",
zap.Error(err),
zap.String("object_ref", objectRef.Hex()),
zap.Int("current_index", currentIndex),
zap.Int("new_index", newIndex),
zap.Int("updated_count", updatedCount))
return err
}
db.logger.Debug("Successfully shifted objects (moving down)",
zap.String("object_ref", objectRef.Hex()),
zap.Int("updated_count", updatedCount))
} else {
// Moving up: shift items between newIndex and currentIndex-1 down by +1
patch := repository.Patch().Inc(repository.IndexField(), 1)
reorderFilter := filter.
And(repository.IndexOpFilter(newIndex, builder.Gte)).
And(repository.IndexOpFilter(currentIndex-1, builder.Lte))
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
if err != nil {
db.logger.Error("Failed to shift objects during reordering (moving up)",
zap.Error(err),
zap.String("object_ref", objectRef.Hex()),
zap.Int("current_index", currentIndex),
zap.Int("new_index", newIndex),
zap.Int("updated_count", updatedCount))
return err
}
db.logger.Debug("Successfully shifted objects (moving up)",
zap.String("object_ref", objectRef.Hex()),
zap.Int("updated_count", updatedCount))
}
// Update the target object to new index
patch := repository.Patch().Set(repository.IndexField(), newIndex)
err = db.repo.Patch(ctx, objectRef, patch)
if err != nil {
db.logger.Error("Failed to update target object index",
zap.Error(err),
zap.String("object_ref", objectRef.Hex()),
zap.Int("current_index", currentIndex),
zap.Int("new_index", newIndex))
return err
}
db.logger.Info("Successfully reordered object",
zap.String("object_ref", objectRef.Hex()),
zap.Int("old_index", currentIndex),
zap.Int("new_index", newIndex))
return nil
}

View File

@@ -0,0 +1,314 @@
//go:build integration
// +build integration
package indexable
import (
"context"
"testing"
"time"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.uber.org/zap"
)
func setupTestDB(t *testing.T) (repository.Repository, func()) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
db := client.Database("testdb")
repo := repository.CreateMongoRepository(db, "projects")
cleanup := func() {
disconnect(ctx, t, client)
terminate(ctx, t, mongoContainer)
}
return repo, cleanup
}
func disconnect(ctx context.Context, t *testing.T, client *mongo.Client) {
if err := client.Disconnect(ctx); err != nil {
t.Logf("failed to disconnect from MongoDB: %v", err)
}
}
func terminate(ctx context.Context, t *testing.T, container testcontainers.Container) {
if err := container.Terminate(ctx); err != nil {
t.Logf("failed to terminate MongoDB container: %v", err)
}
}
func TestIndexableDB_Reorder(t *testing.T) {
repo, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
organizationRef := primitive.NewObjectID()
logger := zap.NewNop()
// Create test projects with different indices
projects := []*model.Project{
{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Project A"},
Indexable: model.Indexable{Index: 0},
Mnemonic: "A",
State: model.ProjectStateActive,
},
},
{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Project B"},
Indexable: model.Indexable{Index: 1},
Mnemonic: "B",
State: model.ProjectStateActive,
},
},
{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Project C"},
Indexable: model.Indexable{Index: 2},
Mnemonic: "C",
State: model.ProjectStateActive,
},
},
{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Project D"},
Indexable: model.Indexable{Index: 3},
Mnemonic: "D",
State: model.ProjectStateActive,
},
},
}
// Insert projects into database
for _, project := range projects {
project.ID = primitive.NewObjectID()
err := repo.Insert(ctx, project, nil)
require.NoError(t, err)
}
// Create helper functions for Project type
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
t.Run("Reorder_NoChange", func(t *testing.T) {
// Test reordering to the same position (should be no-op)
err := indexableDB.Reorder(ctx, projects[1].ID, 1, repository.Query())
require.NoError(t, err)
// Verify indices haven't changed
var result model.Project
err = repo.Get(ctx, projects[0].ID, &result)
require.NoError(t, err)
assert.Equal(t, 0, result.Index)
err = repo.Get(ctx, projects[1].ID, &result)
require.NoError(t, err)
assert.Equal(t, 1, result.Index)
})
t.Run("Reorder_MoveDown", func(t *testing.T) {
// Move Project A (index 0) to index 2
err := indexableDB.Reorder(ctx, projects[0].ID, 2, repository.Query())
require.NoError(t, err)
// Verify the reordering:
// Project A should now be at index 2
// Project B should be at index 0
// Project C should be at index 1
// Project D should remain at index 3
var result model.Project
// Check Project A (moved to index 2)
err = repo.Get(ctx, projects[0].ID, &result)
require.NoError(t, err)
assert.Equal(t, 2, result.Index)
// Check Project B (shifted to index 0)
err = repo.Get(ctx, projects[1].ID, &result)
require.NoError(t, err)
assert.Equal(t, 0, result.Index)
// Check Project C (shifted to index 1)
err = repo.Get(ctx, projects[2].ID, &result)
require.NoError(t, err)
assert.Equal(t, 1, result.Index)
// Check Project D (unchanged)
err = repo.Get(ctx, projects[3].ID, &result)
require.NoError(t, err)
assert.Equal(t, 3, result.Index)
})
t.Run("Reorder_MoveUp", func(t *testing.T) {
// Reset indices for this test
for i, project := range projects {
project.Index = i
err := repo.Update(ctx, project)
require.NoError(t, err)
}
// Move Project C (index 2) to index 0
err := indexableDB.Reorder(ctx, projects[2].ID, 0, repository.Query())
require.NoError(t, err)
// Verify the reordering:
// Project C should now be at index 0
// Project A should be at index 1
// Project B should be at index 2
// Project D should remain at index 3
var result model.Project
// Check Project C (moved to index 0)
err = repo.Get(ctx, projects[2].ID, &result)
require.NoError(t, err)
assert.Equal(t, 0, result.Index)
// Check Project A (shifted to index 1)
err = repo.Get(ctx, projects[0].ID, &result)
require.NoError(t, err)
assert.Equal(t, 1, result.Index)
// Check Project B (shifted to index 2)
err = repo.Get(ctx, projects[1].ID, &result)
require.NoError(t, err)
assert.Equal(t, 2, result.Index)
// Check Project D (unchanged)
err = repo.Get(ctx, projects[3].ID, &result)
require.NoError(t, err)
assert.Equal(t, 3, result.Index)
})
t.Run("Reorder_WithFilter", func(t *testing.T) {
// Reset indices for this test
for i, project := range projects {
project.Index = i
err := repo.Update(ctx, project)
require.NoError(t, err)
}
// Test reordering with organization filter
orgFilter := repository.OrgFilter(organizationRef)
err := indexableDB.Reorder(ctx, projects[0].ID, 2, orgFilter)
require.NoError(t, err)
// Verify the reordering worked with filter
var result model.Project
err = repo.Get(ctx, projects[0].ID, &result)
require.NoError(t, err)
assert.Equal(t, 2, result.Index)
})
}
func TestIndexableDB_EdgeCases(t *testing.T) {
repo, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
organizationRef := primitive.NewObjectID()
logger := zap.NewNop()
// Create a single project for edge case testing
project := &model.Project{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Test Project"},
Indexable: model.Indexable{Index: 0},
Mnemonic: "TEST",
State: model.ProjectStateActive,
},
}
project.ID = primitive.NewObjectID()
err := repo.Insert(ctx, project, nil)
require.NoError(t, err)
// Create helper functions for Project type
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
t.Run("Reorder_SingleItem", func(t *testing.T) {
// Test reordering a single item (should work but have no effect)
err := indexableDB.Reorder(ctx, project.ID, 0, repository.Query())
require.NoError(t, err)
var result model.Project
err = repo.Get(ctx, project.ID, &result)
require.NoError(t, err)
assert.Equal(t, 0, result.Index)
})
t.Run("Reorder_InvalidObjectID", func(t *testing.T) {
// Test reordering with an invalid object ID
invalidID := primitive.NewObjectID()
err := indexableDB.Reorder(ctx, invalidID, 1, repository.Query())
require.Error(t, err) // Should fail because object doesn't exist
})
}

View File

@@ -0,0 +1,12 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *InvitationDB) Accept(ctx context.Context, invitationRef primitive.ObjectID) error {
return db.updateStatus(ctx, invitationRef, model.InvitationAccepted)
}

View File

@@ -0,0 +1,49 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// SetArchived sets the archived status of an invitation
// Invitation supports archiving through PermissionBound embedding ArchivableBase
func (db *InvitationDB) SetArchived(ctx context.Context, accountRef, organizationRef, invitationRef primitive.ObjectID, archived, cascade bool) error {
db.DBImp.Logger.Debug("Setting invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))
res, err := db.Enforcer.Enforce(ctx, db.PermissionRef, accountRef, organizationRef, invitationRef, model.ActionUpdate)
if err != nil {
db.DBImp.Logger.Warn("Failed to enforce archivation permission", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return err
}
if !res {
db.DBImp.Logger.Debug("Permission denied for archivation", mzap.ObjRef("invitation_ref", invitationRef))
return merrors.AccessDenied(db.Collection, string(model.ActionUpdate), invitationRef)
}
// Get the invitation first
var invitation model.Invitation
if err := db.Get(ctx, accountRef, invitationRef, &invitation); err != nil {
db.DBImp.Logger.Warn("Error retrieving invitation for archival", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return err
}
// Update the invitation's archived status
invitation.SetArchived(archived)
if err := db.Update(ctx, accountRef, &invitation); err != nil {
db.DBImp.Logger.Warn("Error updating invitation archived status", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return err
}
// Note: Currently no cascade dependencies for invitations
// If cascade is enabled, we could add logic here for any future dependencies
if cascade {
db.DBImp.Logger.Debug("Cascade archiving requested but no dependencies to archive for invitation", mzap.ObjRef("invitation_ref", invitationRef))
}
db.DBImp.Logger.Debug("Successfully set invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived))
return nil
}

View File

@@ -0,0 +1,24 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// DeleteCascade deletes an invitation
// Invitations don't have cascade dependencies, so this is a simple deletion
func (db *InvitationDB) DeleteCascade(ctx context.Context, accountRef, invitationRef primitive.ObjectID) error {
db.DBImp.Logger.Debug("Starting invitation cascade deletion", mzap.ObjRef("invitation_ref", invitationRef))
// Delete the invitation itself (no dependencies to cascade delete)
if err := db.Delete(ctx, accountRef, invitationRef); err != nil {
db.DBImp.Logger.Error("Error deleting invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return err
}
db.DBImp.Logger.Debug("Successfully deleted invitation", mzap.ObjRef("invitation_ref", invitationRef))
return nil
}

View File

@@ -0,0 +1,53 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/auth"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/repository"
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type InvitationDB struct {
auth.ProtectedDBImp[*model.Invitation]
}
func Create(
ctx context.Context,
logger mlogger.Logger,
enforcer auth.Enforcer,
pdb policy.DB,
db *mongo.Database,
) (*InvitationDB, error) {
p, err := auth.CreateDBImp[*model.Invitation](ctx, logger, pdb, enforcer, mservice.Invitations, db)
if err != nil {
return nil, err
}
// unique email per organization
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: repository.OrgField().Build(), Sort: ri.Asc}, {Field: "description.email", Sort: ri.Asc}},
Unique: true,
}); err != nil {
p.DBImp.Logger.Error("Failed to create unique mnemonic index", zap.Error(err))
return nil, err
}
// ttl index
ttl := int32(0) // zero ttl means expiration on date preset when inserting data
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: "expiresAt", Sort: ri.Asc}},
TTL: &ttl,
}); err != nil {
p.DBImp.Logger.Warn("Failed to create ttl index in the invitations", zap.Error(err))
return nil, err
}
return &InvitationDB{ProtectedDBImp: *p}, nil
}

View File

@@ -0,0 +1,12 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *InvitationDB) Decline(ctx context.Context, invitationRef primitive.ObjectID) error {
return db.updateStatus(ctx, invitationRef, model.InvitationDeclined)
}

View File

@@ -0,0 +1,121 @@
package invitationdb
import (
"context"
"fmt"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
func (db *InvitationDB) GetPublic(ctx context.Context, invitationRef primitive.ObjectID) (*model.PublicInvitation, error) {
roleField := repository.Field("role")
orgField := repository.Field("organization")
accField := repository.Field("account")
empField := repository.Field("employee")
regField := repository.Field("registrationAcc")
descEmailField := repository.Field("description").Dot("email")
pipeline := repository.Pipeline().
// 0) Filter to exactly the invitation(s) you want
Match(repository.IDFilter(invitationRef).And(repository.Filter("status", model.InvitationCreated))).
// 1) Lookup the role document
Lookup(
mservice.Roles,
repository.Field("roleRef"),
repository.IDField(),
roleField,
).
Unwind(repository.Ref(roleField)).
// 2) Lookup the organization document
Lookup(
mservice.Organizations,
repository.Field("organizationRef"),
repository.IDField(),
orgField,
).
Unwind(repository.Ref(orgField)).
// 3) Lookup the account document
Lookup(
mservice.Accounts,
repository.Field("inviterRef"),
repository.IDField(),
accField,
).
Unwind(repository.Ref(accField)).
/* 4) do we already have an account whose login == invitation.description ? */
Lookup(
mservice.Accounts,
descEmailField, // local field (invitation.description.email)
repository.Field("login"), // foreign field (account.login)
regField, // array: 0-length or ≥1
).
// 5) Projection
Project(
repository.SimpleAlias(
empField.Dot("description"),
repository.Ref(accField),
),
repository.SimpleAlias(
empField.Dot("avatarUrl"),
repository.Ref(accField.Dot("avatarUrl")),
),
repository.SimpleAlias(
orgField.Dot("description"),
repository.Ref(orgField),
),
repository.SimpleAlias(
orgField.Dot("logoUrl"),
repository.Ref(orgField.Dot("logoUrl")),
),
repository.SimpleAlias(
roleField,
repository.Ref(roleField),
),
repository.SimpleAlias(
repository.Field("invitation"), // ← left-hand side
repository.Ref(repository.Field("description")), // ← right-hand side (“$description”)
),
repository.SimpleAlias(
repository.Field("storable"), // ← left-hand side
repository.RootRef(), // ← right-hand side (“$description”)
),
repository.ProjectionExpr(
repository.Field("registrationRequired"),
repository.Eq(
repository.Size(repository.Value(repository.Ref(regField).Build())),
repository.Literal(0),
),
),
)
var res model.PublicInvitation
haveResult := false
decoder := func(cur *mongo.Cursor) error {
if haveResult {
// should never get here
db.DBImp.Logger.Warn("Unexpected extra invitation", mzap.ObjRef("invitation_ref", invitationRef))
return merrors.Internal("Unexpected extra invitation found by reference")
}
if e := cur.Decode(&res); e != nil {
db.DBImp.Logger.Warn("Failed to decode entity", zap.Error(e), zap.Any("data", cur.Current.String()))
return e
}
haveResult = true
return nil
}
if err := db.DBImp.Repository.Aggregate(ctx, pipeline, decoder); err != nil {
db.DBImp.Logger.Warn("Failed to execute aggregation pipeline", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return nil, err
}
if !haveResult {
db.DBImp.Logger.Warn("No results fetched", mzap.ObjRef("invitation_ref", invitationRef))
return nil, merrors.NoData(fmt.Sprintf("Invitation %s not found", invitationRef.Hex()))
}
return &res, nil
}

View File

@@ -0,0 +1,28 @@
package invitationdb
import (
"context"
"errors"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
mauth "github.com/tech/sendico/pkg/mutil/db/auth"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *InvitationDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Invitation, error) {
res, err := mauth.GetProtectedObjects[model.Invitation](
ctx,
db.DBImp.Logger,
accountRef, organizationRef, model.ActionRead,
repository.OrgFilter(organizationRef),
cursor,
db.Enforcer,
db.DBImp.Repository,
)
if errors.Is(err, merrors.ErrNoData) {
return []model.Invitation{}, nil
}
return res, err
}

View File

@@ -0,0 +1,26 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
func (db *InvitationDB) updateStatus(ctx context.Context, invitationRef primitive.ObjectID, newStatus model.InvitationStatus) error {
// db.DBImp.Up
var inv model.Invitation
if err := db.DBImp.FindOne(ctx, repository.IDFilter(invitationRef), &inv); err != nil {
db.DBImp.Logger.Warn("Failed to fetch invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
return err
}
inv.Status = newStatus
if err := db.DBImp.Update(ctx, &inv); err != nil {
db.DBImp.Logger.Warn("Failed to update invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
return err
}
return nil
}

View File

@@ -0,0 +1,22 @@
package mongo
import (
"net/url"
)
func buildURI(s *DBSettings) string {
u := &url.URL{
Scheme: "mongodb",
Host: s.Host,
Path: "/" + url.PathEscape(s.Database), // /my%20db
}
q := url.Values{}
if s.ReplicaSet != "" {
q.Set("replicaSet", s.ReplicaSet)
}
u.RawQuery = q.Encode()
return u.String()
}

View File

@@ -0,0 +1,32 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// SetArchived sets the archived status of an organization and optionally cascades to projects, tasks, comments, and reactions
func (db *OrganizationDB) SetArchived(ctx context.Context, accountRef, organizationRef primitive.ObjectID, archived, cascade bool) error {
db.DBImp.Logger.Debug("Setting organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))
// Get the organization first
var organization model.Organization
if err := db.Get(ctx, accountRef, organizationRef, &organization); err != nil {
db.DBImp.Logger.Warn("Error retrieving organization for archival", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
return err
}
// Update the organization's archived status
organization.SetArchived(archived)
if err := db.Update(ctx, accountRef, &organization); err != nil {
db.DBImp.Logger.Warn("Error updating organization archived status", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
return err
}
db.DBImp.Logger.Debug("Successfully set organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived))
return nil
}

View File

@@ -0,0 +1,23 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// DeleteCascade deletes an organization and all its related data (projects, tasks, comments, reactions, statuses)
func (db *OrganizationDB) DeleteCascade(ctx context.Context, organizationRef primitive.ObjectID) error {
db.DBImp.Logger.Debug("Starting organization deletion with projects", mzap.ObjRef("organization_ref", organizationRef))
// Delete the organization itself
if err := db.Unprotected().Delete(ctx, organizationRef); err != nil {
db.DBImp.Logger.Warn("Error deleting organization", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
return err
}
db.DBImp.Logger.Debug("Successfully deleted organization with projects", mzap.ObjRef("organization_ref", organizationRef))
return nil
}

View File

@@ -0,0 +1,19 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *OrganizationDB) Create(ctx context.Context, _, _ primitive.ObjectID, org *model.Organization) error {
if org == nil {
return merrors.InvalidArgument("Organization object is nil")
}
org.SetID(primitive.NewObjectID())
// Organizaiton reference must be set to the same value as own organization reference
org.SetOrganizationRef(*org.GetID())
return db.DBImp.Create(ctx, org)
}

View File

@@ -0,0 +1,34 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/auth"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/mongo"
)
type OrganizationDB struct {
auth.ProtectedDBImp[*model.Organization]
}
func Create(ctx context.Context,
logger mlogger.Logger,
enforcer auth.Enforcer,
pdb policy.DB,
db *mongo.Database,
) (*OrganizationDB, error) {
p, err := auth.CreateDBImp[*model.Organization](ctx, logger, pdb, enforcer, mservice.Organizations, db)
if err != nil {
return nil, err
}
res := &OrganizationDB{
ProtectedDBImp: *p,
}
p.DBImp.SetDeleter(res.DeleteCascade)
return res, nil
}

Some files were not shown because too many files have changed in this diff Show More