service backend
This commit is contained in:
BIN
api/pkg/.DS_Store
vendored
Normal file
BIN
api/pkg/.DS_Store
vendored
Normal file
Binary file not shown.
6
api/pkg/.gitignore
vendored
Normal file
6
api/pkg/.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
proto/billing
|
||||
proto/common
|
||||
proto/chain
|
||||
proto/ledger
|
||||
proto/oracle
|
||||
proto/payments
|
||||
36
api/pkg/api/http/methods.go
Normal file
36
api/pkg/api/http/methods.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package api
|
||||
|
||||
import "fmt"
|
||||
|
||||
type HTTPMethod int
|
||||
|
||||
const (
|
||||
Get HTTPMethod = iota
|
||||
Post
|
||||
Put
|
||||
Patch
|
||||
Delete
|
||||
Options
|
||||
Head
|
||||
)
|
||||
|
||||
func HTTPMethod2String(method HTTPMethod) string {
|
||||
switch method {
|
||||
case Get:
|
||||
return "GET"
|
||||
case Post:
|
||||
return "POST"
|
||||
case Put:
|
||||
return "PUT"
|
||||
case Delete:
|
||||
return "DELETE"
|
||||
case Patch:
|
||||
return "PATCH"
|
||||
case Options:
|
||||
return "OPTIONS"
|
||||
case Head:
|
||||
return "HEAD"
|
||||
default:
|
||||
return fmt.Sprintf("unknown: %d", method)
|
||||
}
|
||||
}
|
||||
205
api/pkg/api/http/response/response.go
Normal file
205
api/pkg/api/http/response/response.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
api "github.com/tech/sendico/pkg/api/http"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// BaseResponse is a general structure for all API responses.
|
||||
type BaseResponse struct {
|
||||
Status string `json:"status"` // "success" or "error"
|
||||
Data any `json:"data"` // The actual data payload or the error details
|
||||
}
|
||||
|
||||
// ErrorResponse provides more details about an error.
|
||||
type ErrorResponse struct {
|
||||
Code int `json:"code"` // A unique identifier for the error type, useful for client handling
|
||||
Error string `json:"error"`
|
||||
Source string `json:"source"`
|
||||
Details string `json:"details"` // Additional details or hints about the error, if necessary
|
||||
}
|
||||
|
||||
func errMessage(err error) string {
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func logRequest(logger mlogger.Logger, r *http.Request, message string) {
|
||||
logger.Debug(
|
||||
message,
|
||||
zap.String("host", r.Host),
|
||||
zap.String("address", r.RemoteAddr),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("request_uri", r.RequestURI),
|
||||
zap.String("proto", r.Proto),
|
||||
zap.String("user_agent", r.UserAgent()),
|
||||
)
|
||||
}
|
||||
|
||||
func writeJSON(logger mlogger.Logger, w http.ResponseWriter, r *http.Request, code int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
w.WriteHeader(code)
|
||||
if err := json.NewEncoder(w).Encode(&payload); err != nil {
|
||||
logger.Warn("Failed to encode JSON response",
|
||||
zap.Error(err),
|
||||
zap.Any("response", payload),
|
||||
zap.String("host", r.Host),
|
||||
zap.String("address", r.RemoteAddr),
|
||||
zap.String("method", r.Method),
|
||||
zap.String("request_uri", r.RequestURI),
|
||||
zap.String("proto", r.Proto),
|
||||
zap.String("user_agent", r.UserAgent()))
|
||||
}
|
||||
}
|
||||
|
||||
func errorf(
|
||||
logger mlogger.Logger,
|
||||
w http.ResponseWriter, r *http.Request,
|
||||
source mservice.Type, code int, message, details string,
|
||||
) {
|
||||
logRequest(logger, r, message)
|
||||
|
||||
errorMessage := BaseResponse{
|
||||
Status: api.MSError,
|
||||
Data: ErrorResponse{
|
||||
Code: code,
|
||||
Details: details,
|
||||
Source: source,
|
||||
Error: message,
|
||||
},
|
||||
}
|
||||
|
||||
writeJSON(logger, w, r, code, errorMessage)
|
||||
}
|
||||
|
||||
func Accepted(logger mlogger.Logger, data any) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := BaseResponse{
|
||||
Status: api.MSProcessed,
|
||||
Data: data,
|
||||
}
|
||||
writeJSON(logger, w, r, http.StatusAccepted, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func Ok(logger mlogger.Logger, data any) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := BaseResponse{
|
||||
Status: api.MSSuccess,
|
||||
Data: data,
|
||||
}
|
||||
writeJSON(logger, w, r, http.StatusOK, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func Created(logger mlogger.Logger, data any) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := BaseResponse{
|
||||
Status: api.MSSuccess,
|
||||
Data: data,
|
||||
}
|
||||
writeJSON(logger, w, r, http.StatusCreated, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func Auto(logger mlogger.Logger, source mservice.Type, err error) http.HandlerFunc {
|
||||
if err == nil {
|
||||
return Success(logger)
|
||||
}
|
||||
if errors.Is(err, merrors.ErrAccessDenied) {
|
||||
return AccessDenied(logger, source, errMessage(err))
|
||||
}
|
||||
if errors.Is(err, merrors.ErrDataConflict) {
|
||||
return DataConflict(logger, source, errMessage(err))
|
||||
}
|
||||
if errors.Is(err, merrors.ErrInvalidArg) {
|
||||
return BadRequest(logger, source, "invalid_argument", errMessage(err))
|
||||
}
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
return NotFound(logger, source, errMessage(err))
|
||||
}
|
||||
if errors.Is(err, merrors.ErrUnauthorized) {
|
||||
return Unauthorized(logger, source, errMessage(err))
|
||||
}
|
||||
return Internal(logger, source, err)
|
||||
}
|
||||
|
||||
func Internal(logger mlogger.Logger, source mservice.Type, err error) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
errorf(logger, w, r, source, http.StatusInternalServerError, "internal_error", errMessage(err))
|
||||
}
|
||||
}
|
||||
|
||||
func NotImplemented(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
errorf(logger, w, r, source, http.StatusNotImplemented, "not_implemented", hint)
|
||||
}
|
||||
}
|
||||
|
||||
func BadRequest(logger mlogger.Logger, source mservice.Type, err, hint string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
errorf(logger, w, r, source, http.StatusBadRequest, err, hint)
|
||||
}
|
||||
}
|
||||
|
||||
func BadQueryParam(logger mlogger.Logger, source mservice.Type, param string, err error) http.HandlerFunc {
|
||||
return BadRequest(logger, source, "invalid_query_parameter", fmt.Sprintf("Failed to parse '%s': %v", param, err))
|
||||
}
|
||||
|
||||
func BadReference(logger mlogger.Logger, source mservice.Type, refName, refVal string, err error) http.HandlerFunc {
|
||||
return BadRequest(logger, source, "broken_reference",
|
||||
fmt.Sprintf("broken object reference: %s = %s, error: %v", refName, refVal, err))
|
||||
}
|
||||
|
||||
func BadPayload(logger mlogger.Logger, source mservice.Type, err error) http.HandlerFunc {
|
||||
return BadRequest(logger, source, "broken_payload",
|
||||
fmt.Sprintf("broken '%s' object payload, error: %v", source, err))
|
||||
}
|
||||
|
||||
func DataConflict(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
errorf(logger, w, r, source, http.StatusConflict, "data_conflict", hint)
|
||||
}
|
||||
}
|
||||
|
||||
func Error(logger mlogger.Logger, source mservice.Type, code int, errType, hint string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
errorf(logger, w, r, source, code, errType, hint)
|
||||
}
|
||||
}
|
||||
|
||||
func AccessDenied(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
|
||||
return Error(logger, source, http.StatusForbidden, "access_denied", hint)
|
||||
}
|
||||
|
||||
func Forbidden(logger mlogger.Logger, source mservice.Type, errType, hint string) http.HandlerFunc {
|
||||
return Error(logger, source, http.StatusForbidden, errType, hint)
|
||||
}
|
||||
|
||||
func LicenseRequired(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
errorf(logger, w, r, source, http.StatusPaymentRequired, "license_required", hint)
|
||||
}
|
||||
}
|
||||
|
||||
func Unauthorized(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
errorf(logger, w, r, source, http.StatusUnauthorized, "unauthorized", hint)
|
||||
}
|
||||
}
|
||||
|
||||
func NotFound(logger mlogger.Logger, source mservice.Type, hint string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
errorf(logger, w, r, source, http.StatusNotFound, "not_found", hint)
|
||||
}
|
||||
}
|
||||
19
api/pkg/api/http/response/result.go
Normal file
19
api/pkg/api/http/response/result.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Result bool `json:"result"`
|
||||
}
|
||||
|
||||
func Success(logger mlogger.Logger) http.HandlerFunc {
|
||||
return Ok(logger, Result{Result: true})
|
||||
}
|
||||
|
||||
func Failed(logger mlogger.Logger) http.HandlerFunc {
|
||||
return Accepted(logger, Result{Result: false})
|
||||
}
|
||||
8
api/pkg/api/http/status.go
Normal file
8
api/pkg/api/http/status.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package api
|
||||
|
||||
const (
|
||||
MSSuccess string = "success"
|
||||
MSProcessed string = "processed"
|
||||
MSError string = "error"
|
||||
MSRequest string = "request"
|
||||
)
|
||||
61
api/pkg/api/routers/grpc.go
Normal file
61
api/pkg/api/routers/grpc.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package routers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/tech/sendico/pkg/api/routers/internal/grpcimp"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type (
|
||||
GRPCServiceRegistration = func(grpc.ServiceRegistrar)
|
||||
)
|
||||
|
||||
type GRPC interface {
|
||||
Register(registration GRPCServiceRegistration) error
|
||||
Start(ctx context.Context) error
|
||||
Finish(ctx context.Context) error
|
||||
Addr() net.Addr
|
||||
Done() <-chan error
|
||||
}
|
||||
|
||||
type (
|
||||
GRPCConfig = grpcimp.Config
|
||||
GRPCTLSConfig = grpcimp.TLSConfig
|
||||
)
|
||||
|
||||
type GRPCOption func(*grpcimp.Options)
|
||||
|
||||
func WithUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) GRPCOption {
|
||||
return func(o *grpcimp.Options) {
|
||||
o.UnaryInterceptors = append(o.UnaryInterceptors, interceptors...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) GRPCOption {
|
||||
return func(o *grpcimp.Options) {
|
||||
o.StreamInterceptors = append(o.StreamInterceptors, interceptors...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithListener(listener net.Listener) GRPCOption {
|
||||
return func(o *grpcimp.Options) {
|
||||
o.Listener = listener
|
||||
}
|
||||
}
|
||||
|
||||
func WithServerOptions(opts ...grpc.ServerOption) GRPCOption {
|
||||
return func(o *grpcimp.Options) {
|
||||
o.ServerOptions = append(o.ServerOptions, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func NewGRPCRouter(logger mlogger.Logger, config *GRPCConfig, opts ...GRPCOption) (GRPC, error) {
|
||||
options := &grpcimp.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return grpcimp.NewRouter(logger, config, options)
|
||||
}
|
||||
149
api/pkg/api/routers/gsresponse/response.go
Normal file
149
api/pkg/api/routers/gsresponse/response.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package gsresponse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// Responder produces a response or a gRPC status error when executed.
|
||||
type Responder[T any] func(ctx context.Context) (*T, error)
|
||||
|
||||
func message(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func Success[T any](resp *T) Responder[T] {
|
||||
return func(context.Context) (*T, error) {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func Empty[T any]() Responder[T] {
|
||||
return func(context.Context) (*T, error) {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func Error[T any](logger mlogger.Logger, service mservice.Type, code codes.Code, hint string, err error) Responder[T] {
|
||||
return func(ctx context.Context) (*T, error) {
|
||||
fields := []zap.Field{
|
||||
zap.String("service", string(service)),
|
||||
zap.String("status_code", code.String()),
|
||||
}
|
||||
if hint != "" {
|
||||
fields = append(fields, zap.String("error_hint", hint))
|
||||
}
|
||||
if err != nil {
|
||||
fields = append(fields, zap.Error(err))
|
||||
}
|
||||
logFn := logger.Warn
|
||||
switch code {
|
||||
case codes.Internal, codes.DataLoss, codes.Unavailable:
|
||||
logFn = logger.Error
|
||||
}
|
||||
logFn("gRPC request failed", fields...)
|
||||
|
||||
msg := message(err)
|
||||
switch {
|
||||
case hint == "" && msg == "":
|
||||
return nil, status.Error(code, code.String())
|
||||
case hint == "":
|
||||
return nil, status.Error(code, msg)
|
||||
case msg == "":
|
||||
return nil, status.Error(code, hint)
|
||||
default:
|
||||
return nil, status.Error(code, fmt.Sprintf("%s: %s", hint, msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Internal[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.Internal, "internal_error", err)
|
||||
}
|
||||
|
||||
func InvalidArgument[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.InvalidArgument, "invalid_argument", err)
|
||||
}
|
||||
|
||||
func NotFound[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.NotFound, "not_found", err)
|
||||
}
|
||||
|
||||
func Unauthorized[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.Unauthenticated, "unauthorized", err)
|
||||
}
|
||||
|
||||
func PermissionDenied[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.PermissionDenied, "access_denied", err)
|
||||
}
|
||||
|
||||
func FailedPrecondition[T any](logger mlogger.Logger, service mservice.Type, hint string, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.FailedPrecondition, hint, err)
|
||||
}
|
||||
|
||||
func Conflict[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.Aborted, "conflict", err)
|
||||
}
|
||||
|
||||
func DeadlineExceeded[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.DeadlineExceeded, "deadline_exceeded", err)
|
||||
}
|
||||
|
||||
func Unavailable[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.Unavailable, "service_unavailable", err)
|
||||
}
|
||||
|
||||
func Unimplemented[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.Unimplemented, "not_implemented", err)
|
||||
}
|
||||
|
||||
func AlreadyExists[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
return Error[T](logger, service, codes.AlreadyExists, "already_exists", err)
|
||||
}
|
||||
|
||||
func Auto[T any](logger mlogger.Logger, service mservice.Type, err error) Responder[T] {
|
||||
switch {
|
||||
case err == nil:
|
||||
return Empty[T]()
|
||||
case errors.Is(err, merrors.ErrInvalidArg):
|
||||
return InvalidArgument[T](logger, service, err)
|
||||
case errors.Is(err, merrors.ErrAccessDenied):
|
||||
return PermissionDenied[T](logger, service, err)
|
||||
case errors.Is(err, merrors.ErrNoData):
|
||||
return NotFound[T](logger, service, err)
|
||||
case errors.Is(err, merrors.ErrUnauthorized):
|
||||
return Unauthorized[T](logger, service, err)
|
||||
case errors.Is(err, merrors.ErrDataConflict):
|
||||
return Conflict[T](logger, service, err)
|
||||
default:
|
||||
return Internal[T](logger, service, err)
|
||||
}
|
||||
}
|
||||
|
||||
func Execute[T any](ctx context.Context, responder Responder[T]) (*T, error) {
|
||||
if responder == nil {
|
||||
return nil, status.Error(codes.Internal, "missing responder")
|
||||
}
|
||||
return responder(ctx)
|
||||
}
|
||||
|
||||
func Unary[TReq any, TResp any](logger mlogger.Logger, service mservice.Type, handler func(context.Context, *TReq) Responder[TResp]) func(context.Context, *TReq) (*TResp, error) {
|
||||
return func(ctx context.Context, req *TReq) (*TResp, error) {
|
||||
if handler == nil {
|
||||
return nil, status.Error(codes.Internal, "missing handler")
|
||||
}
|
||||
responder := handler(ctx, req)
|
||||
return Execute(ctx, responder)
|
||||
}
|
||||
}
|
||||
75
api/pkg/api/routers/gsresponse/response_test.go
Normal file
75
api/pkg/api/routers/gsresponse/response_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package gsresponse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type testRequest struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
type testResponse struct {
|
||||
Result string
|
||||
}
|
||||
|
||||
func TestUnarySuccess(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
handler := func(ctx context.Context, req *testRequest) Responder[testResponse] {
|
||||
require.NotNil(t, req)
|
||||
require.Equal(t, "hello", req.Value)
|
||||
resp := &testResponse{Result: "ok"}
|
||||
return Success(resp)
|
||||
}
|
||||
|
||||
unary := Unary[testRequest, testResponse](logger, mservice.Type("test"), handler)
|
||||
resp, err := unary(context.Background(), &testRequest{Value: "hello"})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, "ok", resp.Result)
|
||||
}
|
||||
|
||||
func TestAutoMappings(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
service := mservice.Type("test")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
code codes.Code
|
||||
}{
|
||||
{"invalid_argument", merrors.InvalidArgument("bad"), codes.InvalidArgument},
|
||||
{"access_denied", merrors.AccessDenied("object", "action", primitive.NilObjectID), codes.PermissionDenied},
|
||||
{"not_found", merrors.NoData("missing"), codes.NotFound},
|
||||
{"unauthorized", fmt.Errorf("%w: %s", merrors.ErrUnauthorized, "bad"), codes.Unauthenticated},
|
||||
{"conflict", merrors.DataConflict("conflict"), codes.Aborted},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
responder := Auto[testResponse](logger, service, tc.err)
|
||||
_, err := responder(context.Background())
|
||||
require.Error(t, err)
|
||||
st, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tc.code, st.Code())
|
||||
})
|
||||
}
|
||||
|
||||
responder := Auto[testResponse](logger, service, errors.New("boom"))
|
||||
_, err := responder(context.Background())
|
||||
require.Error(t, err)
|
||||
st, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, codes.Internal, st.Code())
|
||||
}
|
||||
17
api/pkg/api/routers/health.go
Normal file
17
api/pkg/api/routers/health.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package routers
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/tech/sendico/pkg/api/routers/health"
|
||||
"github.com/tech/sendico/pkg/api/routers/internal/healthimp"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
)
|
||||
|
||||
type Health interface {
|
||||
SetStatus(status health.ServiceStatus)
|
||||
Finish()
|
||||
}
|
||||
|
||||
func NewHealthRouter(logger mlogger.Logger, router chi.Router, endpoint string) (Health, error) {
|
||||
return healthimp.NewRouter(logger, router, endpoint), nil
|
||||
}
|
||||
10
api/pkg/api/routers/health/status.go
Normal file
10
api/pkg/api/routers/health/status.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package health
|
||||
|
||||
type ServiceStatus string
|
||||
|
||||
const (
|
||||
SSCreated ServiceStatus = "created"
|
||||
SSStarting ServiceStatus = "starting"
|
||||
SSRunning ServiceStatus = "ok"
|
||||
SSTerminating ServiceStatus = "deactivating"
|
||||
)
|
||||
18
api/pkg/api/routers/internal/grpcimp/config.go
Normal file
18
api/pkg/api/routers/internal/grpcimp/config.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package grpcimp
|
||||
|
||||
type Config struct {
|
||||
Network string `yaml:"network"`
|
||||
Address string `yaml:"address"`
|
||||
EnableReflection bool `yaml:"enable_reflection"`
|
||||
EnableHealth bool `yaml:"enable_health"`
|
||||
MaxRecvMsgSize int `yaml:"max_recv_msg_size"`
|
||||
MaxSendMsgSize int `yaml:"max_send_msg_size"`
|
||||
TLS *TLSConfig `yaml:"tls"`
|
||||
}
|
||||
|
||||
type TLSConfig struct {
|
||||
CertFile string `yaml:"cert_file"`
|
||||
KeyFile string `yaml:"key_file"`
|
||||
CAFile string `yaml:"ca_file"`
|
||||
RequireClientCert bool `yaml:"require_client_cert"`
|
||||
}
|
||||
103
api/pkg/api/routers/internal/grpcimp/metrics.go
Normal file
103
api/pkg/api/routers/internal/grpcimp/metrics.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package grpcimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
metricsOnce sync.Once
|
||||
grpcServerRequestsTotal *prometheus.CounterVec
|
||||
grpcServerLatency *prometheus.HistogramVec
|
||||
)
|
||||
|
||||
func initPrometheusMetrics() {
|
||||
metricsOnce.Do(func() {
|
||||
grpcServerRequestsTotal = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "grpc_server_requests_total",
|
||||
Help: "Total number of gRPC requests handled by the server.",
|
||||
},
|
||||
[]string{"grpc_service", "grpc_method", "grpc_type", "grpc_code"},
|
||||
)
|
||||
|
||||
grpcServerLatency = prometheus.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "grpc_server_handling_seconds",
|
||||
Help: "Duration of gRPC requests handled by the server.",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
},
|
||||
[]string{"grpc_service", "grpc_method", "grpc_type", "grpc_code"},
|
||||
)
|
||||
|
||||
prometheus.MustRegister(grpcServerRequestsTotal, grpcServerLatency)
|
||||
})
|
||||
}
|
||||
|
||||
func prometheusUnaryInterceptor() grpc.UnaryServerInterceptor {
|
||||
initPrometheusMetrics()
|
||||
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
start := time.Now()
|
||||
resp, err := handler(ctx, req)
|
||||
|
||||
recordMetrics(info.FullMethod, "unary", time.Since(start), err)
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func prometheusStreamInterceptor() grpc.StreamServerInterceptor {
|
||||
initPrometheusMetrics()
|
||||
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
start := time.Now()
|
||||
err := handler(srv, ss)
|
||||
|
||||
recordMetrics(info.FullMethod, streamType(info), time.Since(start), err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func streamType(info *grpc.StreamServerInfo) string {
|
||||
if info == nil {
|
||||
return "stream"
|
||||
}
|
||||
if info.IsServerStream && info.IsClientStream {
|
||||
return "bidi"
|
||||
}
|
||||
if info.IsServerStream {
|
||||
return "server_stream"
|
||||
}
|
||||
if info.IsClientStream {
|
||||
return "client_stream"
|
||||
}
|
||||
return "stream"
|
||||
}
|
||||
|
||||
func recordMetrics(fullMethod string, callType string, duration time.Duration, err error) {
|
||||
service, method := splitMethod(fullMethod)
|
||||
code := status.Code(err).String()
|
||||
|
||||
grpcServerRequestsTotal.WithLabelValues(service, method, callType, code).Inc()
|
||||
grpcServerLatency.WithLabelValues(service, method, callType, code).Observe(duration.Seconds())
|
||||
}
|
||||
|
||||
func splitMethod(fullMethod string) (string, string) {
|
||||
if fullMethod == "" {
|
||||
return "unknown", "unknown"
|
||||
}
|
||||
if fullMethod[0] == '/' {
|
||||
fullMethod = fullMethod[1:]
|
||||
}
|
||||
parts := strings.Split(fullMethod, "/")
|
||||
if len(parts) < 2 {
|
||||
return fullMethod, "unknown"
|
||||
}
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
14
api/pkg/api/routers/internal/grpcimp/options.go
Normal file
14
api/pkg/api/routers/internal/grpcimp/options.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package grpcimp
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
UnaryInterceptors []grpc.UnaryServerInterceptor
|
||||
StreamInterceptors []grpc.StreamServerInterceptor
|
||||
ServerOptions []grpc.ServerOption
|
||||
Listener net.Listener
|
||||
}
|
||||
293
api/pkg/api/routers/internal/grpcimp/router.go
Normal file
293
api/pkg/api/routers/internal/grpcimp/router.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package grpcimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/health"
|
||||
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
type routerError string
|
||||
|
||||
func (e routerError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
type routerErrorWithCause struct {
|
||||
message string
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e *routerErrorWithCause) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.cause == nil {
|
||||
return e.message
|
||||
}
|
||||
return e.message + ": " + e.cause.Error()
|
||||
}
|
||||
|
||||
func (e *routerErrorWithCause) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.cause
|
||||
}
|
||||
|
||||
func newRouterErrorWithCause(message string, cause error) error {
|
||||
return &routerErrorWithCause{
|
||||
message: message,
|
||||
cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
errMsgAlreadyStarted = "grpc router already started"
|
||||
errMsgListenFailed = "failed to listen on requested address"
|
||||
errMsgNilContext = "nil context"
|
||||
errMsgTLSMissingCertAndKey = "tls configuration requires cert_file and key_file"
|
||||
errMsgLoadServerCertificate = "failed to load server certificate"
|
||||
errMsgReadCAFile = "failed to read CA file"
|
||||
errMsgAppendCACertificates = "failed to append CA certificates"
|
||||
errMsgClientCertRequiresCAFile = "client certificate verification requested but ca_file is empty"
|
||||
)
|
||||
|
||||
var (
|
||||
errAlreadyStarted = routerError(errMsgAlreadyStarted)
|
||||
errNilContext = routerError(errMsgNilContext)
|
||||
errTLSMissingCertAndKey = routerError(errMsgTLSMissingCertAndKey)
|
||||
errAppendCACertificates = routerError(errMsgAppendCACertificates)
|
||||
errClientCertRequiresCAFile = routerError(errMsgClientCertRequiresCAFile)
|
||||
)
|
||||
|
||||
type Router struct {
|
||||
logger mlogger.Logger
|
||||
config Config
|
||||
server *grpc.Server
|
||||
listener net.Listener
|
||||
options *Options
|
||||
mu sync.RWMutex
|
||||
started bool
|
||||
serveErr chan error
|
||||
healthSrv *health.Server
|
||||
}
|
||||
|
||||
func NewRouter(logger mlogger.Logger, cfg *Config, opts *Options) (*Router, error) {
|
||||
if cfg == nil {
|
||||
cfg = &Config{}
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
network := cfg.Network
|
||||
if network == "" {
|
||||
network = "tcp"
|
||||
}
|
||||
address := cfg.Address
|
||||
if address == "" {
|
||||
address = ":0"
|
||||
}
|
||||
|
||||
listener := opts.Listener
|
||||
var err error
|
||||
if listener == nil {
|
||||
listener, err = net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, newRouterErrorWithCause(errMsgListenFailed, err)
|
||||
}
|
||||
}
|
||||
|
||||
serverOpts := make([]grpc.ServerOption, 0, len(opts.ServerOptions)+4)
|
||||
serverOpts = append(serverOpts, opts.ServerOptions...)
|
||||
|
||||
if cfg.MaxRecvMsgSize > 0 {
|
||||
serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(cfg.MaxRecvMsgSize))
|
||||
}
|
||||
if cfg.MaxSendMsgSize > 0 {
|
||||
serverOpts = append(serverOpts, grpc.MaxSendMsgSize(cfg.MaxSendMsgSize))
|
||||
}
|
||||
|
||||
if creds, err := configureTLS(cfg.TLS); err != nil {
|
||||
return nil, err
|
||||
} else if creds != nil {
|
||||
serverOpts = append(serverOpts, grpc.Creds(creds))
|
||||
}
|
||||
|
||||
unaryInterceptors := append([]grpc.UnaryServerInterceptor{prometheusUnaryInterceptor()}, opts.UnaryInterceptors...)
|
||||
streamInterceptors := append([]grpc.StreamServerInterceptor{prometheusStreamInterceptor()}, opts.StreamInterceptors...)
|
||||
|
||||
if len(unaryInterceptors) > 0 {
|
||||
serverOpts = append(serverOpts, grpc.ChainUnaryInterceptor(unaryInterceptors...))
|
||||
}
|
||||
if len(streamInterceptors) > 0 {
|
||||
serverOpts = append(serverOpts, grpc.ChainStreamInterceptor(streamInterceptors...))
|
||||
}
|
||||
|
||||
srv := grpc.NewServer(serverOpts...)
|
||||
r := &Router{
|
||||
logger: logger.Named("grpc"),
|
||||
config: *cfg,
|
||||
server: srv,
|
||||
listener: listener,
|
||||
options: opts,
|
||||
serveErr: make(chan error, 1),
|
||||
}
|
||||
|
||||
if cfg.EnableReflection {
|
||||
reflection.Register(srv)
|
||||
}
|
||||
if cfg.EnableHealth {
|
||||
r.healthSrv = health.NewServer()
|
||||
r.healthSrv.SetServingStatus("", healthpb.HealthCheckResponse_NOT_SERVING)
|
||||
healthpb.RegisterHealthServer(srv, r.healthSrv)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Router) Register(registration func(grpc.ServiceRegistrar)) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.started {
|
||||
return errAlreadyStarted
|
||||
}
|
||||
|
||||
registration(r.server)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) Start(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errNilContext
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
if r.started {
|
||||
r.mu.Unlock()
|
||||
return errAlreadyStarted
|
||||
}
|
||||
r.started = true
|
||||
r.mu.Unlock()
|
||||
|
||||
if r.healthSrv != nil {
|
||||
r.healthSrv.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
r.logger.Info("Context cancelled, stopping gRPC server")
|
||||
r.server.GracefulStop()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := r.server.Serve(r.listener)
|
||||
if err != nil && !errors.Is(err, grpc.ErrServerStopped) {
|
||||
select {
|
||||
case r.serveErr <- err:
|
||||
default:
|
||||
r.logger.Error("Failed to report gRPC serve error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
close(r.serveErr)
|
||||
}()
|
||||
|
||||
r.logger.Info("gRPC server started", zap.String("network", r.listener.Addr().Network()), zap.String("address", r.listener.Addr().String()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) Finish(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errNilContext
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
started := r.started
|
||||
r.mu.RUnlock()
|
||||
if !started {
|
||||
return nil
|
||||
}
|
||||
|
||||
if r.healthSrv != nil {
|
||||
r.healthSrv.SetServingStatus("", healthpb.HealthCheckResponse_NOT_SERVING)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.server.GracefulStop()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
r.logger.Warn("Graceful stop timed out, forcing stop", zap.Error(ctx.Err()))
|
||||
r.server.Stop()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if err, ok := <-r.serveErr; ok {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) Addr() net.Addr {
|
||||
return r.listener.Addr()
|
||||
}
|
||||
|
||||
func (r *Router) Done() <-chan error {
|
||||
return r.serveErr
|
||||
}
|
||||
|
||||
func configureTLS(cfg *TLSConfig) (credentials.TransportCredentials, error) {
|
||||
if cfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if cfg.CertFile == "" || cfg.KeyFile == "" {
|
||||
return nil, errTLSMissingCertAndKey
|
||||
}
|
||||
|
||||
certificate, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
|
||||
if err != nil {
|
||||
return nil, newRouterErrorWithCause(errMsgLoadServerCertificate, err)
|
||||
}
|
||||
|
||||
tlsCfg := &tls.Config{
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
if cfg.CAFile != "" {
|
||||
caPem, err := os.ReadFile(cfg.CAFile)
|
||||
if err != nil {
|
||||
return nil, newRouterErrorWithCause(errMsgReadCAFile, err)
|
||||
}
|
||||
|
||||
certPool := x509.NewCertPool()
|
||||
if ok := certPool.AppendCertsFromPEM(caPem); !ok {
|
||||
return nil, errAppendCACertificates
|
||||
}
|
||||
tlsCfg.ClientCAs = certPool
|
||||
if cfg.RequireClientCert {
|
||||
tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
} else if cfg.RequireClientCert {
|
||||
return nil, errClientCertRequiresCAFile
|
||||
}
|
||||
|
||||
return credentials.NewTLS(tlsCfg), nil
|
||||
}
|
||||
150
api/pkg/api/routers/internal/grpcimp/router_test.go
Normal file
150
api/pkg/api/routers/internal/grpcimp/router_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package grpcimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
|
||||
const bufconnSize = 1024 * 1024
|
||||
|
||||
func newBufferedListener(t *testing.T) *bufconn.Listener {
|
||||
t.Helper()
|
||||
|
||||
listener := bufconn.Listen(bufconnSize)
|
||||
t.Cleanup(func() {
|
||||
listener.Close()
|
||||
})
|
||||
|
||||
return listener
|
||||
}
|
||||
|
||||
func newTestRouter(t *testing.T, cfg *Config) *Router {
|
||||
t.Helper()
|
||||
|
||||
logger := zap.NewNop()
|
||||
if cfg == nil {
|
||||
cfg = &Config{}
|
||||
}
|
||||
|
||||
router, err := NewRouter(logger, cfg, &Options{Listener: newBufferedListener(t)})
|
||||
require.NoError(t, err)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func TestRouterStartAndFinish(t *testing.T) {
|
||||
router := newTestRouter(t, &Config{})
|
||||
|
||||
doneCh := router.Done()
|
||||
require.NotNil(t, doneCh)
|
||||
|
||||
require.NoError(t, router.Register(func(grpc.ServiceRegistrar) {}))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, router.Start(ctx))
|
||||
|
||||
addr := router.Addr()
|
||||
require.NotNil(t, addr)
|
||||
require.NotEmpty(t, addr.String())
|
||||
|
||||
finishCtx, finishCancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer finishCancel()
|
||||
|
||||
require.NoError(t, router.Finish(finishCtx))
|
||||
|
||||
select {
|
||||
case err, ok := <-doneCh:
|
||||
if ok {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for done channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterRejectsRegistrationAfterStart(t *testing.T) {
|
||||
router := newTestRouter(t, &Config{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, router.Start(ctx))
|
||||
|
||||
doneCh := router.Done()
|
||||
|
||||
err := router.Register(func(grpc.ServiceRegistrar) {})
|
||||
require.ErrorIs(t, err, errAlreadyStarted)
|
||||
|
||||
finishCtx, finishCancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer finishCancel()
|
||||
|
||||
require.NoError(t, router.Finish(finishCtx))
|
||||
|
||||
select {
|
||||
case <-doneCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for done channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterStartOnlyOnce(t *testing.T) {
|
||||
router := newTestRouter(t, &Config{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, router.Start(ctx))
|
||||
require.ErrorIs(t, router.Start(ctx), errAlreadyStarted)
|
||||
|
||||
doneCh := router.Done()
|
||||
|
||||
finishCtx, finishCancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer finishCancel()
|
||||
|
||||
require.NoError(t, router.Finish(finishCtx))
|
||||
|
||||
select {
|
||||
case <-doneCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for done channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterUsesProvidedListener(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
listener := newBufferedListener(t)
|
||||
|
||||
cfg := &Config{}
|
||||
router, err := NewRouter(logger, cfg, &Options{Listener: listener})
|
||||
require.NoError(t, err)
|
||||
|
||||
actualListener, ok := router.listener.(*bufconn.Listener)
|
||||
require.True(t, ok)
|
||||
require.Same(t, listener, actualListener)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, router.Start(ctx))
|
||||
|
||||
doneCh := router.Done()
|
||||
|
||||
finishCtx, finishCancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer finishCancel()
|
||||
|
||||
require.NoError(t, router.Finish(finishCtx))
|
||||
|
||||
select {
|
||||
case <-doneCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for done channel")
|
||||
}
|
||||
}
|
||||
45
api/pkg/api/routers/internal/healthimp/health.go
Normal file
45
api/pkg/api/routers/internal/healthimp/health.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package healthimp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/tech/sendico/pkg/api/routers/health"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Router struct {
|
||||
logger mlogger.Logger
|
||||
status *Status
|
||||
}
|
||||
|
||||
func (hr *Router) SetStatus(status health.ServiceStatus) {
|
||||
hr.status.setStatus(status)
|
||||
hr.logger.Info("New status set", zap.String("status", string(status)))
|
||||
}
|
||||
|
||||
func (hr *Router) Finish() {
|
||||
hr.status.Finish()
|
||||
hr.logger.Debug("Stopped")
|
||||
}
|
||||
|
||||
func (hr *Router) handle(w http.ResponseWriter, r *http.Request) {
|
||||
hr.status.healthHandler()(w, r)
|
||||
}
|
||||
|
||||
func NewRouter(logger mlogger.Logger, router chi.Router, endpoint string) *Router {
|
||||
hr := Router{
|
||||
logger: logger.Named("health_check"),
|
||||
}
|
||||
hr.status = StatusHandler(hr.logger)
|
||||
|
||||
logger.Debug("Installing healthcheck middleware...")
|
||||
router.Group(func(r chi.Router) {
|
||||
ep := endpoint + "/health"
|
||||
r.Get(ep, hr.handle)
|
||||
logger.Info("Health handler installed", zap.String("endpoint", ep))
|
||||
})
|
||||
|
||||
return &hr
|
||||
}
|
||||
38
api/pkg/api/routers/internal/healthimp/status.go
Normal file
38
api/pkg/api/routers/internal/healthimp/status.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package healthimp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/tech/sendico/pkg/api/http/response"
|
||||
"github.com/tech/sendico/pkg/api/routers/health"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
)
|
||||
|
||||
type Status struct {
|
||||
logger mlogger.Logger
|
||||
status health.ServiceStatus
|
||||
}
|
||||
|
||||
func (hs *Status) healthHandler() http.HandlerFunc {
|
||||
return response.Ok(hs.logger, struct {
|
||||
Status health.ServiceStatus `json:"status"`
|
||||
}{
|
||||
hs.status,
|
||||
})
|
||||
}
|
||||
|
||||
func (hr *Status) Finish() {
|
||||
hr.logger.Info("Finished")
|
||||
}
|
||||
|
||||
func (hs *Status) setStatus(status health.ServiceStatus) {
|
||||
hs.status = status
|
||||
}
|
||||
|
||||
func StatusHandler(logger mlogger.Logger) *Status {
|
||||
hs := Status{
|
||||
status: health.SSCreated,
|
||||
logger: logger.Named("status"),
|
||||
}
|
||||
return &hs
|
||||
}
|
||||
66
api/pkg/api/routers/internal/messagingimp/consumer.go
Normal file
66
api/pkg/api/routers/internal/messagingimp/consumer.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package messagingimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/messaging"
|
||||
mb "github.com/tech/sendico/pkg/messaging/broker"
|
||||
me "github.com/tech/sendico/pkg/messaging/envelope"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ChannelConsumer struct {
|
||||
logger mlogger.Logger
|
||||
broker mb.Broker
|
||||
event model.NotificationEvent
|
||||
ch <-chan me.Envelope
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (c *ChannelConsumer) ConsumeMessages(handleFunc messaging.MessageHandlerT) error {
|
||||
c.logger.Info("Message consumer is ready")
|
||||
for {
|
||||
select {
|
||||
case msg := <-c.ch:
|
||||
if msg == nil { // nil message indicates the channel was closed
|
||||
c.logger.Info("Consumer shutting down")
|
||||
return nil
|
||||
}
|
||||
if err := handleFunc(c.ctx, msg); err != nil {
|
||||
c.logger.Warn("Error processing message", zap.Error(err))
|
||||
}
|
||||
case <-c.ctx.Done():
|
||||
c.logger.Info("Context done, shutting down")
|
||||
return c.ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChannelConsumer) Close() {
|
||||
c.logger.Info("Shutting down...")
|
||||
c.cancel()
|
||||
if err := c.broker.Unsubscribe(c.event, c.ch); err != nil {
|
||||
c.logger.Warn("Failed to unsubscribe", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func NewConsumer(logger mlogger.Logger, broker mb.Broker, event model.NotificationEvent) (*ChannelConsumer, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ch, err := broker.Subscribe(event)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to create channel consumer", zap.Error(err), zap.String("topic", event.ToString()))
|
||||
cancel() // Ensure resources are released properly
|
||||
return nil, err
|
||||
}
|
||||
return &ChannelConsumer{
|
||||
logger: logger.Named("consumer").Named(event.ToString()),
|
||||
broker: broker,
|
||||
event: event,
|
||||
ch: ch,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}, nil
|
||||
}
|
||||
67
api/pkg/api/routers/internal/messagingimp/messsaging.go
Normal file
67
api/pkg/api/routers/internal/messagingimp/messsaging.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package messagingimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/tech/sendico/pkg/messaging"
|
||||
mb "github.com/tech/sendico/pkg/messaging/broker"
|
||||
notifications "github.com/tech/sendico/pkg/messaging/notifications/processor"
|
||||
mip "github.com/tech/sendico/pkg/messaging/producer"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type MessagingRouter struct {
|
||||
logger mlogger.Logger
|
||||
messaging mb.Broker
|
||||
consumers []messaging.Consumer
|
||||
producer messaging.Producer
|
||||
}
|
||||
|
||||
func (mr *MessagingRouter) consumeMessages(c messaging.Consumer, processor notifications.EnvelopeProcessor) {
|
||||
if err := c.ConsumeMessages(processor.Process); err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
mr.logger.Warn("Error consuming messages", zap.Error(err), zap.String("event", processor.GetSubject().ToString()))
|
||||
} else {
|
||||
mr.logger.Info("Finishing as context has been cancelled", zap.String("event", processor.GetSubject().ToString()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *MessagingRouter) Consumer(processor notifications.EnvelopeProcessor) error {
|
||||
c, err := NewConsumer(mr.logger, mr.messaging, processor.GetSubject())
|
||||
if err != nil {
|
||||
mr.logger.Warn("Failed to register message consumer", zap.Error(err), zap.String("event", processor.GetSubject().ToString()))
|
||||
return err
|
||||
}
|
||||
mr.consumers = append(mr.consumers, c)
|
||||
go mr.consumeMessages(c, processor)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *MessagingRouter) Finish() {
|
||||
mr.logger.Info("Closing consumer channels")
|
||||
for _, consumer := range mr.consumers {
|
||||
consumer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *MessagingRouter) Producer() messaging.Producer {
|
||||
return mr.producer
|
||||
}
|
||||
|
||||
func NewMessagingRouterImp(logger mlogger.Logger, config *messaging.Config) (*MessagingRouter, error) {
|
||||
l := logger.Named("messaging")
|
||||
broker, err := messaging.CreateMessagingBroker(l, config)
|
||||
if err != nil {
|
||||
l.Error("Failed to create messaging broker", zap.Error(err), zap.String("broker", string(config.Driver)))
|
||||
return nil, err
|
||||
}
|
||||
return &MessagingRouter{
|
||||
logger: l,
|
||||
messaging: broker,
|
||||
producer: mip.NewProducer(logger, broker),
|
||||
consumers: make([]messaging.Consumer, 0),
|
||||
}, nil
|
||||
}
|
||||
16
api/pkg/api/routers/messaging.go
Normal file
16
api/pkg/api/routers/messaging.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package routers
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/api/routers/internal/messagingimp"
|
||||
"github.com/tech/sendico/pkg/messaging"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
)
|
||||
|
||||
type Messaging interface {
|
||||
messaging.Register
|
||||
Finish()
|
||||
}
|
||||
|
||||
func NewMessagingRouter(logger mlogger.Logger, config *messaging.Config) (Messaging, error) {
|
||||
return messagingimp.NewMessagingRouterImp(logger, config)
|
||||
}
|
||||
202
api/pkg/auth/USAGE.md
Normal file
202
api/pkg/auth/USAGE.md
Normal file
@@ -0,0 +1,202 @@
|
||||
# Auth.Indexable Usage Guide
|
||||
|
||||
## Secure Reordering with Permission Checking
|
||||
|
||||
The `auth.Indexable` implementation adds **permission checking** to the generic reordering functionality using `EnforceBatch`.
|
||||
|
||||
- **Core Implementation**: `api/pkg/auth/indexable.go` - generic implementation with permission checking
|
||||
- **Project Factory**: `api/pkg/auth/project_indexable.go` - convenient factory for projects
|
||||
- **Key Feature**: Uses `EnforceBatch` to check permissions for all affected objects
|
||||
|
||||
## How It Works
|
||||
|
||||
### Permission Checking Flow
|
||||
1. **Get current object** to find its index
|
||||
2. **Determine affected objects** that will be shifted during reordering
|
||||
3. **Check permissions** using `EnforceBatch` for all affected objects + target object
|
||||
4. **Verify all permissions** - if any object lacks update permission, return error
|
||||
5. **Proceed with reordering** only if all permissions are granted
|
||||
|
||||
### Key Differences from Basic Indexable
|
||||
- **Additional parameter**: `accountRef` for permission checking
|
||||
- **Permission validation**: All affected objects must have `ActionUpdate` permission
|
||||
- **Security**: Prevents unauthorized reordering that could affect other users' data
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Using the Generic Auth.Indexable Implementation
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/auth"
|
||||
|
||||
// For any type that embeds model.Indexable, define helper functions:
|
||||
createEmpty := func() *YourType {
|
||||
return &YourType{}
|
||||
}
|
||||
|
||||
getIndexable := func(obj *YourType) *model.Indexable {
|
||||
return &obj.Indexable
|
||||
}
|
||||
|
||||
// Create auth.IndexableDB with enforcer
|
||||
indexableDB := auth.NewIndexableDB(repo, logger, enforcer, createEmpty, getIndexable)
|
||||
|
||||
// Use with account reference for permission checking
|
||||
err := indexableDB.Reorder(ctx, accountRef, objectID, newIndex, filter)
|
||||
```
|
||||
|
||||
### 2. Using the Project Factory (Recommended for Projects)
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/auth"
|
||||
|
||||
// Create auth.ProjectIndexableDB (automatically applies org filter)
|
||||
projectDB := auth.NewProjectIndexableDB(repo, logger, enforcer, organizationRef)
|
||||
|
||||
// Reorder project with permission checking
|
||||
err := projectDB.Reorder(ctx, accountRef, projectID, newIndex, repository.Query())
|
||||
|
||||
// Reorder with additional filters (combined with org filter)
|
||||
additionalFilter := repository.Query().Comparison(repository.Field("state"), builder.Eq, "active")
|
||||
err := projectDB.Reorder(ctx, accountRef, projectID, newIndex, additionalFilter)
|
||||
```
|
||||
|
||||
## Examples for Different Types
|
||||
|
||||
### Project Auth.IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
projectDB := auth.NewIndexableDB(repo, logger, enforcer, createEmpty, getIndexable)
|
||||
orgFilter := repository.OrgFilter(organizationRef)
|
||||
projectDB.Reorder(ctx, accountRef, projectID, 2, orgFilter)
|
||||
```
|
||||
|
||||
### Status Auth.IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Status {
|
||||
return &model.Status{}
|
||||
}
|
||||
|
||||
getIndexable := func(s *model.Status) *model.Indexable {
|
||||
return &s.Indexable
|
||||
}
|
||||
|
||||
statusDB := auth.NewIndexableDB(repo, logger, enforcer, createEmpty, getIndexable)
|
||||
projectFilter := repository.Query().Comparison(repository.Field("projectRef"), builder.Eq, projectRef)
|
||||
statusDB.Reorder(ctx, accountRef, statusID, 1, projectFilter)
|
||||
```
|
||||
|
||||
### Task Auth.IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Task {
|
||||
return &model.Task{}
|
||||
}
|
||||
|
||||
getIndexable := func(t *model.Task) *model.Indexable {
|
||||
return &t.Indexable
|
||||
}
|
||||
|
||||
taskDB := auth.NewIndexableDB(repo, logger, enforcer, createEmpty, getIndexable)
|
||||
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
|
||||
taskDB.Reorder(ctx, accountRef, taskID, 3, statusFilter)
|
||||
```
|
||||
|
||||
## Permission Checking Details
|
||||
|
||||
### What Gets Checked
|
||||
When reordering an object from index `A` to index `B`:
|
||||
|
||||
1. **Target object** - the object being moved
|
||||
2. **Affected objects** - all objects whose indices will be shifted:
|
||||
- Moving down: objects between `A+1` and `B` (shifted up by -1)
|
||||
- Moving up: objects between `B` and `A-1` (shifted down by +1)
|
||||
|
||||
### Permission Requirements
|
||||
- **Action**: `model.ActionUpdate`
|
||||
- **Scope**: All affected objects must be `PermissionBoundStorable`
|
||||
- **Result**: If any object lacks permission, the entire operation fails
|
||||
|
||||
### Error Handling
|
||||
```go
|
||||
// Permission denied error
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "accessDenied") {
|
||||
// Handle permission denied
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Security Benefits
|
||||
|
||||
### ✅ **Comprehensive Permission Checking**
|
||||
- Checks permissions for **all affected objects**, not just the target
|
||||
- Prevents unauthorized reordering that could affect other users' data
|
||||
- Uses efficient `EnforceBatch` for bulk permission checking
|
||||
|
||||
### ✅ **Type Safety**
|
||||
- Generic implementation works with any `Indexable` struct
|
||||
- Compile-time type checking
|
||||
- No runtime type assertions
|
||||
|
||||
### ✅ **Flexible Filtering**
|
||||
- Single `builder.Query` parameter for scoping
|
||||
- Can combine organization filters with additional criteria
|
||||
- Project factory automatically applies organization filtering
|
||||
|
||||
### ✅ **Clean Architecture**
|
||||
- Separates permission logic from reordering logic
|
||||
- Easy to test with mock enforcers
|
||||
- Follows existing auth patterns
|
||||
|
||||
## Testing
|
||||
|
||||
### Mock Enforcer Setup
|
||||
```go
|
||||
mockEnforcer := &MockEnforcer{}
|
||||
|
||||
// Grant all permissions
|
||||
permissions := map[primitive.ObjectID]bool{
|
||||
objectID1: true,
|
||||
objectID2: true,
|
||||
}
|
||||
mockEnforcer.On("EnforceBatch", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(permissions, nil)
|
||||
|
||||
// Deny specific permission
|
||||
permissions[objectID2] = false
|
||||
mockEnforcer.On("EnforceBatch", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(permissions, nil)
|
||||
```
|
||||
|
||||
### Test Scenarios
|
||||
- ✅ **Permission granted** - reordering succeeds
|
||||
- ❌ **Permission denied** - reordering fails with access denied error
|
||||
- 🔄 **No change needed** - early return, minimal permission checking
|
||||
- 🏢 **Organization filtering** - automatic org scope for projects
|
||||
|
||||
## Comparison: Basic vs Auth.Indexable
|
||||
|
||||
| Feature | Basic Indexable | Auth.Indexable |
|
||||
|---------|----------------|----------------|
|
||||
| Permission checking | ❌ No | ✅ Yes |
|
||||
| Account parameter | ❌ No | ✅ Required |
|
||||
| Security | ❌ None | ✅ Comprehensive |
|
||||
| Performance | ✅ Fast | ⚠️ Slower (permission checks) |
|
||||
| Use case | Internal operations | User-facing operations |
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use Auth.Indexable** for user-facing reordering operations
|
||||
2. **Use Basic Indexable** for internal/system operations
|
||||
3. **Always provide account reference** for proper permission checking
|
||||
4. **Test permission scenarios** thoroughly with mock enforcers
|
||||
5. **Handle permission errors** gracefully in user interfaces
|
||||
|
||||
That's it! **Secure, type-safe reordering** with comprehensive permission checking using `EnforceBatch`.
|
||||
3
api/pkg/auth/anyobject/anyobject.go
Normal file
3
api/pkg/auth/anyobject/anyobject.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package anyobject
|
||||
|
||||
const ID = "*"
|
||||
35
api/pkg/auth/archivable.go
Normal file
35
api/pkg/auth/archivable.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// ArchivableDB implements archive operations with permission checking
|
||||
type ArchivableDB[T model.PermissionBoundStorable] interface {
|
||||
// SetArchived sets the archived status of an entity with permission checking
|
||||
SetArchived(ctx context.Context, accountRef, objectRef primitive.ObjectID, archived bool) error
|
||||
// IsArchived checks if an entity is archived with permission checking
|
||||
IsArchived(ctx context.Context, accountRef, objectRef primitive.ObjectID) (bool, error)
|
||||
|
||||
// Archive archives an entity with permission checking (sets archived to true)
|
||||
Archive(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
|
||||
|
||||
// Unarchive unarchives an entity with permission checking (sets archived to false)
|
||||
Unarchive(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
|
||||
}
|
||||
|
||||
// NewArchivableDB creates a new auth.ArchivableDB instance
|
||||
func NewArchivableDB[T model.PermissionBoundStorable](
|
||||
dbImp *template.DBImp[T],
|
||||
logger mlogger.Logger,
|
||||
enforcer Enforcer,
|
||||
createEmpty func() T,
|
||||
getArchivable func(T) model.Archivable,
|
||||
) ArchivableDB[T] {
|
||||
return newArchivableDBImp(dbImp, logger, enforcer, createEmpty, getArchivable)
|
||||
}
|
||||
107
api/pkg/auth/archivableimp.go
Normal file
107
api/pkg/auth/archivableimp.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ArchivableDB implements archive operations with permission checking
|
||||
type ArchivableDBImp[T model.PermissionBoundStorable] struct {
|
||||
dbImp *template.DBImp[T]
|
||||
logger mlogger.Logger
|
||||
enforcer Enforcer
|
||||
createEmpty func() T
|
||||
getArchivable func(T) model.Archivable
|
||||
}
|
||||
|
||||
// NewArchivableDB creates a new auth.ArchivableDB instance
|
||||
func newArchivableDBImp[T model.PermissionBoundStorable](
|
||||
dbImp *template.DBImp[T],
|
||||
logger mlogger.Logger,
|
||||
enforcer Enforcer,
|
||||
createEmpty func() T,
|
||||
getArchivable func(T) model.Archivable,
|
||||
) ArchivableDB[T] {
|
||||
return &ArchivableDBImp[T]{
|
||||
dbImp: dbImp,
|
||||
logger: logger.Named("archivable"),
|
||||
enforcer: enforcer,
|
||||
createEmpty: createEmpty,
|
||||
getArchivable: getArchivable,
|
||||
}
|
||||
}
|
||||
|
||||
// SetArchived sets the archived status of an entity with permission checking
|
||||
func (db *ArchivableDBImp[T]) SetArchived(ctx context.Context, accountRef, objectRef primitive.ObjectID, archived bool) error {
|
||||
// Check permissions using enforceObject helper
|
||||
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
|
||||
db.logger.Warn("Failed to enforce object permission", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the object to check current archived status
|
||||
obj := db.createEmpty()
|
||||
if err := db.dbImp.Get(ctx, objectRef, obj); err != nil {
|
||||
db.logger.Warn("Failed to get object for setting archived status", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract archivable from the object
|
||||
archivable := db.getArchivable(obj)
|
||||
currentArchived := archivable.IsArchived()
|
||||
if currentArchived == archived {
|
||||
db.logger.Debug("No change needed - same archived status", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
|
||||
return nil // No change needed
|
||||
}
|
||||
|
||||
// Set the archived status
|
||||
patch := repository.Patch().Set(repository.IsArchivedField(), archived)
|
||||
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
|
||||
db.logger.Warn("Failed to set archived status on object", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully set archived status on object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", archived))
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsArchived checks if an entity is archived with permission checking
|
||||
func (db *ArchivableDBImp[T]) IsArchived(ctx context.Context, accountRef, objectRef primitive.ObjectID) (bool, error) {
|
||||
// // Check permissions using single Enforce
|
||||
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionRead, accountRef, objectRef); err != nil {
|
||||
db.logger.Debug("Permission denied for checking archived status", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(model.ActionRead)))
|
||||
return false, merrors.AccessDenied("read", "object", objectRef)
|
||||
}
|
||||
obj := db.createEmpty()
|
||||
if err := db.dbImp.Get(ctx, objectRef, obj); err != nil {
|
||||
db.logger.Warn("Failed to get object for checking archived status", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return false, err
|
||||
}
|
||||
archivable := db.getArchivable(obj)
|
||||
return archivable.IsArchived(), nil
|
||||
}
|
||||
|
||||
// Archive archives an entity with permission checking (sets archived to true)
|
||||
func (db *ArchivableDBImp[T]) Archive(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
|
||||
return db.SetArchived(ctx, accountRef, objectRef, true)
|
||||
}
|
||||
|
||||
// Unarchive unarchives an entity with permission checking (sets archived to false)
|
||||
func (db *ArchivableDBImp[T]) Unarchive(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
|
||||
return db.SetArchived(ctx, accountRef, objectRef, false)
|
||||
}
|
||||
12
api/pkg/auth/config.go
Normal file
12
api/pkg/auth/config.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package auth
|
||||
|
||||
import "github.com/tech/sendico/pkg/model"
|
||||
|
||||
type EnforcerType string
|
||||
|
||||
const (
|
||||
Casbin EnforcerType = "casbin"
|
||||
Native EnforcerType = "native"
|
||||
)
|
||||
|
||||
type Config = model.DriverConfig[EnforcerType]
|
||||
8
api/pkg/auth/customizable/customizable.go
Normal file
8
api/pkg/auth/customizable/customizable.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package customizable
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
type DB[T model.PermissionBoundStorable] interface {
|
||||
}
|
||||
8
api/pkg/auth/customizable/manager.go
Normal file
8
api/pkg/auth/customizable/manager.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package customizable
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
type Manager[T model.PermissionBoundStorable] interface {
|
||||
}
|
||||
38
api/pkg/auth/db.go
Normal file
38
api/pkg/auth/db.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type ProtectedDB[T model.PermissionBoundStorable] interface {
|
||||
Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, object T) error
|
||||
InsertMany(ctx context.Context, accountRef, organizationRef primitive.ObjectID, objects []T) error
|
||||
Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error
|
||||
Update(ctx context.Context, accountRef primitive.ObjectID, object T) error
|
||||
Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
|
||||
DeleteCascadeAuth(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
|
||||
Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error
|
||||
PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error)
|
||||
Unprotected() template.DB[T]
|
||||
ListIDs(ctx context.Context, action model.Action, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error)
|
||||
}
|
||||
|
||||
func CreateDB[T model.PermissionBoundStorable](
|
||||
ctx context.Context,
|
||||
l mlogger.Logger,
|
||||
pdb policy.DB,
|
||||
enforcer Enforcer,
|
||||
collection mservice.Type,
|
||||
db *mongo.Database,
|
||||
) (ProtectedDB[T], error) {
|
||||
return CreateDBImp[T](ctx, l, pdb, enforcer, collection, db)
|
||||
}
|
||||
51
api/pkg/auth/dbab.go
Normal file
51
api/pkg/auth/dbab.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AccountBoundDB[T model.AccountBoundStorable] interface {
|
||||
Create(ctx context.Context, accountRef primitive.ObjectID, object T) error
|
||||
Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error
|
||||
Update(ctx context.Context, accountRef primitive.ObjectID, object T) error
|
||||
Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error
|
||||
Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
|
||||
DeleteMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query) error
|
||||
FindOne(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, result T) error
|
||||
ListIDs(ctx context.Context, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error)
|
||||
ListAccountBound(ctx context.Context, accountRef, organizationRef primitive.ObjectID, query builder.Query) ([]model.AccountBoundStorable, error)
|
||||
}
|
||||
|
||||
func CreateAccountBound[T model.AccountBoundStorable](
|
||||
ctx context.Context,
|
||||
logger mlogger.Logger,
|
||||
pdb policy.DB,
|
||||
enforcer Enforcer,
|
||||
collection mservice.Type,
|
||||
db *mongo.Database,
|
||||
) (AccountBoundDB[T], error) {
|
||||
logger = logger.Named("account_bound")
|
||||
var policy model.PolicyDescription
|
||||
if err := pdb.GetBuiltInPolicy(ctx, mservice.Organizations, &policy); err != nil {
|
||||
logger.Warn("Failed to fetch organization policy description", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
res := &AccountBoundDBImp[T]{
|
||||
Logger: logger,
|
||||
DBImp: template.Create[T](logger, collection, db),
|
||||
Enforcer: enforcer,
|
||||
PermissionRef: policy.ID,
|
||||
Collection: collection,
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
319
api/pkg/auth/dbimp.go
Normal file
319
api/pkg/auth/dbimp.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ProtectedDBImp[T model.PermissionBoundStorable] struct {
|
||||
DBImp *template.DBImp[T]
|
||||
Enforcer Enforcer
|
||||
PermissionRef primitive.ObjectID
|
||||
Collection mservice.Type
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) enforce(ctx context.Context, action model.Action, object model.PermissionBoundStorable, accountRef, objectRef primitive.ObjectID) error {
|
||||
res, err := db.Enforcer.Enforce(ctx, object.GetPermissionRef(), accountRef, object.GetOrganizationRef(), objectRef, action)
|
||||
if err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to enforce permission",
|
||||
zap.Error(err), mzap.ObjRef("permission_ref", object.GetPermissionRef()),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
return err
|
||||
}
|
||||
if !res {
|
||||
db.DBImp.Logger.Debug("Access denied", mzap.ObjRef("permission_ref", object.GetPermissionRef()),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
return merrors.AccessDenied(db.Collection, string(action), objectRef)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, object T) error {
|
||||
db.DBImp.Logger.Debug("Attempting to create object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)))
|
||||
|
||||
if object.GetPermissionRef() == primitive.NilObjectID {
|
||||
object.SetPermissionRef(db.PermissionRef)
|
||||
}
|
||||
object.SetOrganizationRef(organizationRef)
|
||||
|
||||
if err := db.enforce(ctx, model.ActionCreate, object, accountRef, primitive.NilObjectID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DBImp.Create(ctx, object); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to create object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully created object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) InsertMany(ctx context.Context, accountRef, organizationRef primitive.ObjectID, objects []T) error {
|
||||
if len(objects) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Attempting to insert many objects", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)),
|
||||
zap.Int("count", len(objects)))
|
||||
|
||||
// Set permission and organization refs for all objects and enforce permissions
|
||||
for _, object := range objects {
|
||||
if object.GetPermissionRef() == primitive.NilObjectID {
|
||||
object.SetPermissionRef(db.PermissionRef)
|
||||
}
|
||||
object.SetOrganizationRef(organizationRef)
|
||||
|
||||
if err := db.enforce(ctx, model.ActionCreate, object, accountRef, primitive.NilObjectID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.DBImp.InsertMany(ctx, objects); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to insert many objects", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)),
|
||||
zap.Int("count", len(objects)))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully inserted many objects", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), zap.String("collection", string(db.Collection)),
|
||||
zap.Int("count", len(objects)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) enforceObject(ctx context.Context, action model.Action, accountRef, objectRef primitive.ObjectID) error {
|
||||
l, err := db.ListIDs(ctx, action, accountRef, repository.IDFilter(objectRef))
|
||||
if err != nil {
|
||||
db.DBImp.Logger.Warn("Error occured while checking access rights", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
return err
|
||||
}
|
||||
if len(l) == 0 {
|
||||
db.DBImp.Logger.Debug("Access denied", zap.String("action", string(action)), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return merrors.AccessDenied(db.Collection, string(action), objectRef)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error {
|
||||
db.DBImp.Logger.Debug("Attempting to get object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
|
||||
if err := db.enforceObject(ctx, model.ActionRead, accountRef, objectRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DBImp.Get(ctx, objectRef, result); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to get object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection)))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully retrieved object",
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", result.GetOrganizationRef()),
|
||||
mzap.StorableRef(result), mzap.ObjRef("permission_ref", result.GetPermissionRef()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) Update(ctx context.Context, accountRef primitive.ObjectID, object T) error {
|
||||
db.DBImp.Logger.Debug("Attempting to update object", mzap.ObjRef("account_ref", accountRef), mzap.StorableRef(object))
|
||||
|
||||
if err := db.enforceObject(ctx, model.ActionUpdate, accountRef, *object.GetID()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DBImp.Update(ctx, object); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to update object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.StorableRef(object))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully updated object",
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", object.GetOrganizationRef()),
|
||||
mzap.StorableRef(object), mzap.ObjRef("permission_ref", object.GetPermissionRef()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
|
||||
db.DBImp.Logger.Debug("Attempting to delete object",
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
|
||||
if err := db.enforceObject(ctx, model.ActionDelete, accountRef, objectRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DBImp.Delete(ctx, objectRef); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to delete object", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully deleted object",
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) ListIDs(
|
||||
ctx context.Context,
|
||||
action model.Action,
|
||||
accountRef primitive.ObjectID,
|
||||
query builder.Query,
|
||||
) ([]primitive.ObjectID, error) {
|
||||
db.DBImp.Logger.Debug("Attempting to list object IDs",
|
||||
mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)), zap.Any("filter", query.BuildQuery()))
|
||||
|
||||
// 1. Fetch all candidate IDs from the underlying DB
|
||||
allIDs, err := db.DBImp.ListPermissionBound(ctx, query)
|
||||
if err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to list object IDs", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
zap.String("collection", string(db.Collection)), zap.String("action", string(action)))
|
||||
return nil, err
|
||||
}
|
||||
if len(allIDs) == 0 {
|
||||
db.DBImp.Logger.Debug("No objects found matching filter", mzap.ObjRef("account_ref", accountRef),
|
||||
zap.String("collection", string(db.Collection)), zap.Any("filter", query.BuildQuery()))
|
||||
return []primitive.ObjectID{}, merrors.NoData(fmt.Sprintf("no %s found", db.Collection))
|
||||
}
|
||||
|
||||
// 2. Check read permission for each ID
|
||||
var allowedIDs []primitive.ObjectID
|
||||
for _, desc := range allIDs {
|
||||
enforceErr := db.enforce(ctx, action, desc, accountRef, *desc.GetID())
|
||||
if enforceErr == nil {
|
||||
allowedIDs = append(allowedIDs, *desc.GetID())
|
||||
} else if !errors.Is(err, merrors.ErrAccessDenied) {
|
||||
// If the error is something other than AccessDenied, we want to fail
|
||||
db.DBImp.Logger.Warn("Error while enforcing read permission", zap.Error(enforceErr),
|
||||
mzap.ObjRef("permission_ref", desc.GetPermissionRef()), zap.String("action", string(action)),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", desc.GetOrganizationRef()),
|
||||
mzap.ObjRef("object_ref", *desc.GetID()), zap.String("collection", string(db.Collection)),
|
||||
)
|
||||
return nil, enforceErr
|
||||
}
|
||||
// If AccessDenied, we simply skip that ID.
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully enforced read permission on IDs", zap.Int("fetched_count", len(allIDs)),
|
||||
zap.Int("allowed_count", len(allowedIDs)), mzap.ObjRef("account_ref", accountRef),
|
||||
zap.String("collection", string(db.Collection)), zap.String("action", string(action)))
|
||||
|
||||
// 3. Return only the IDs that passed permission checks
|
||||
return allowedIDs, nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) Unprotected() template.DB[T] {
|
||||
return db.DBImp
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) DeleteCascadeAuth(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
|
||||
if err := db.enforceObject(ctx, model.ActionDelete, accountRef, objectRef); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := db.DBImp.DeleteCascade(ctx, objectRef); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to delete dependent object", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateDBImp[T model.PermissionBoundStorable](
|
||||
ctx context.Context,
|
||||
l mlogger.Logger,
|
||||
pdb policy.DB,
|
||||
enforcer Enforcer,
|
||||
collection mservice.Type,
|
||||
db *mongo.Database,
|
||||
) (*ProtectedDBImp[T], error) {
|
||||
logger := l.Named("protected")
|
||||
var policy model.PolicyDescription
|
||||
if err := pdb.GetBuiltInPolicy(ctx, collection, &policy); err != nil {
|
||||
logger.Warn("Failed to fetch policy description", zap.Error(err), zap.String("resource_type", string(collection)))
|
||||
return nil, err
|
||||
}
|
||||
p := &ProtectedDBImp[T]{
|
||||
DBImp: template.Create[T](logger, collection, db),
|
||||
PermissionRef: policy.ID,
|
||||
Collection: collection,
|
||||
Enforcer: enforcer,
|
||||
}
|
||||
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: storable.OrganizationRefField, Sort: ri.Asc}},
|
||||
}); err != nil {
|
||||
logger.Warn("Failed to create index", zap.Error(err), zap.String("resource_type", string(collection)))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error {
|
||||
db.DBImp.Logger.Debug("Attempting to patch object",
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
|
||||
if err := db.enforceObject(ctx, model.ActionUpdate, accountRef, objectRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DBImp.Repository.Patch(ctx, objectRef, patch); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to patch object", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully patched object",
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *ProtectedDBImp[T]) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) {
|
||||
db.DBImp.Logger.Debug("Attempting to patch many objects",
|
||||
mzap.ObjRef("account_ref", accountRef), zap.Any("filter", query.BuildQuery()))
|
||||
|
||||
ids, err := db.ListIDs(ctx, model.ActionUpdate, accountRef, query)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
values := make([]any, len(ids))
|
||||
for i, id := range ids {
|
||||
values[i] = id
|
||||
}
|
||||
idFilter := repository.Query().In(repository.IDField(), values...)
|
||||
finalQuery := query.And(idFilter)
|
||||
|
||||
modified, err := db.DBImp.Repository.PatchMany(ctx, finalQuery, patch)
|
||||
if err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to patch many objects", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully patched many objects",
|
||||
mzap.ObjRef("account_ref", accountRef), zap.Int("modified_count", modified))
|
||||
return modified, nil
|
||||
}
|
||||
420
api/pkg/auth/dbimpab.go
Normal file
420
api/pkg/auth/dbimpab.go
Normal file
@@ -0,0 +1,420 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AccountBoundDBImp[T model.AccountBoundStorable] struct {
|
||||
Logger mlogger.Logger
|
||||
DBImp *template.DBImp[T]
|
||||
Enforcer Enforcer
|
||||
PermissionRef primitive.ObjectID
|
||||
Collection mservice.Type
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) enforce(ctx context.Context, action model.Action, object model.AccountBoundStorable, accountRef primitive.ObjectID) error {
|
||||
// FIRST: Check if the object's AccountRef equals the calling accountRef - if so, ALLOW
|
||||
objectAccountRef := object.GetAccountRef()
|
||||
if objectAccountRef != nil && *objectAccountRef == accountRef {
|
||||
db.Logger.Debug("Access granted - object belongs to calling account",
|
||||
mzap.ObjRef("object_account_ref", *objectAccountRef),
|
||||
mzap.ObjRef("calling_account_ref", accountRef),
|
||||
zap.String("action", string(action)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// SECOND: If not owned by calling account, check organization-level permissions
|
||||
organizationRef := object.GetOrganizationRef()
|
||||
res, err := db.Enforcer.Enforce(ctx, db.PermissionRef, accountRef, organizationRef, organizationRef, action)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to enforce permission",
|
||||
zap.Error(err), mzap.ObjRef("permission_ref", db.PermissionRef),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
|
||||
zap.String("action", string(action)))
|
||||
return err
|
||||
}
|
||||
if !res {
|
||||
db.Logger.Debug("Access denied", mzap.ObjRef("permission_ref", db.PermissionRef),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
|
||||
zap.String("action", string(action)))
|
||||
return merrors.AccessDenied(db.Collection, string(action), primitive.NilObjectID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) enforceInterface(ctx context.Context, action model.Action, object model.AccountBoundStorable, accountRef primitive.ObjectID) error {
|
||||
// FIRST: Check if the object's AccountRef equals the calling accountRef - if so, ALLOW
|
||||
objectAccountRef := object.GetAccountRef()
|
||||
if objectAccountRef != nil && *objectAccountRef == accountRef {
|
||||
db.Logger.Debug("Access granted - object belongs to calling account",
|
||||
mzap.ObjRef("object_account_ref", *objectAccountRef),
|
||||
mzap.ObjRef("calling_account_ref", accountRef),
|
||||
zap.String("action", string(action)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// SECOND: If not owned by calling account, check organization-level permissions
|
||||
organizationRef := object.GetOrganizationRef()
|
||||
res, err := db.Enforcer.Enforce(ctx, db.PermissionRef, accountRef, organizationRef, organizationRef, action)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to enforce permission",
|
||||
zap.Error(err), mzap.ObjRef("permission_ref", db.PermissionRef),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
|
||||
zap.String("action", string(action)))
|
||||
return err
|
||||
}
|
||||
if !res {
|
||||
db.Logger.Debug("Access denied", mzap.ObjRef("permission_ref", db.PermissionRef),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
|
||||
zap.String("action", string(action)))
|
||||
return merrors.AccessDenied(db.Collection, string(action), primitive.NilObjectID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) Create(ctx context.Context, accountRef primitive.ObjectID, object T) error {
|
||||
orgRef := object.GetOrganizationRef()
|
||||
db.Logger.Debug("Attempting to create object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", orgRef), zap.String("collection", string(db.Collection)))
|
||||
|
||||
// Check organization update permission for create operations
|
||||
if err := db.enforce(ctx, model.ActionUpdate, object, accountRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DBImp.Create(ctx, object); err != nil {
|
||||
db.Logger.Warn("Failed to create object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", orgRef), zap.String("collection", string(db.Collection)))
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully created object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", orgRef), zap.String("collection", string(db.Collection)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) Get(ctx context.Context, accountRef, objectRef primitive.ObjectID, result T) error {
|
||||
db.Logger.Debug("Attempting to get object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
|
||||
// First get the object to check its organization
|
||||
if err := db.DBImp.Get(ctx, objectRef, result); err != nil {
|
||||
db.Logger.Warn("Failed to get object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection)))
|
||||
return err
|
||||
}
|
||||
|
||||
// Check organization read permission
|
||||
if err := db.enforce(ctx, model.ActionRead, result, accountRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully retrieved object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", result.GetOrganizationRef()), zap.String("collection", string(db.Collection)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) Update(ctx context.Context, accountRef primitive.ObjectID, object T) error {
|
||||
db.Logger.Debug("Attempting to update object", mzap.ObjRef("account_ref", accountRef), mzap.StorableRef(object))
|
||||
|
||||
// Check organization update permission
|
||||
if err := db.enforce(ctx, model.ActionUpdate, object, accountRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DBImp.Update(ctx, object); err != nil {
|
||||
db.Logger.Warn("Failed to update object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.StorableRef(object))
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully updated object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", object.GetOrganizationRef()), mzap.StorableRef(object))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error {
|
||||
db.Logger.Debug("Attempting to patch object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
|
||||
// First get the object to check its organization
|
||||
objs, err := db.DBImp.Repository.ListAccountBound(ctx, repository.IDFilter(objectRef))
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to get object for permission check when deleting", zap.Error(err), mzap.ObjRef("object_ref", objectRef))
|
||||
return err
|
||||
}
|
||||
if len(objs) == 0 {
|
||||
db.Logger.Debug("Permission denied for deletion", mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("account_ref", accountRef))
|
||||
return merrors.AccessDenied(db.Collection, string(model.ActionDelete), objectRef)
|
||||
}
|
||||
|
||||
// Check organization update permission
|
||||
if err := db.enforce(ctx, model.ActionUpdate, objs[0], accountRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DBImp.Patch(ctx, objectRef, patch); err != nil {
|
||||
db.Logger.Warn("Failed to patch object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection)))
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully patched object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) Delete(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
|
||||
db.Logger.Debug("Attempting to delete object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
|
||||
// First get the object to check its organization
|
||||
objs, err := db.DBImp.Repository.ListAccountBound(ctx, repository.IDFilter(objectRef))
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to get object for permission check when deleting", zap.Error(err), mzap.ObjRef("object_ref", objectRef))
|
||||
return err
|
||||
}
|
||||
if len(objs) == 0 {
|
||||
db.Logger.Debug("Permission denied for deletion", mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("account_ref", accountRef))
|
||||
return merrors.AccessDenied(db.Collection, string(model.ActionDelete), objectRef)
|
||||
}
|
||||
// Check organization update permission for delete operations
|
||||
if err := db.enforce(ctx, model.ActionUpdate, objs[0], accountRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DBImp.Delete(ctx, objectRef); err != nil {
|
||||
db.Logger.Warn("Failed to delete object", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("collection", string(db.Collection)))
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully deleted object", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) DeleteMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query) error {
|
||||
db.Logger.Debug("Attempting to delete many objects", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)))
|
||||
|
||||
// Get all candidate objects for batch permission checking
|
||||
allObjects, err := db.DBImp.Repository.ListPermissionBound(ctx, query)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to list objects for delete many", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Use batch enforcement for efficiency
|
||||
allowedResults, err := db.Enforcer.EnforceBatch(ctx, allObjects, accountRef, model.ActionUpdate)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to enforce batch permissions for delete many", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Build query for objects that passed permission check
|
||||
var allowedIDs []primitive.ObjectID
|
||||
for _, obj := range allObjects {
|
||||
if allowedResults[*obj.GetID()] {
|
||||
allowedIDs = append(allowedIDs, *obj.GetID())
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedIDs) == 0 {
|
||||
db.Logger.Debug("No objects allowed for deletion", mzap.ObjRef("account_ref", accountRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete only the allowed objects
|
||||
allowedQuery := query.And(repository.Query().In(repository.IDField(), allowedIDs))
|
||||
if err := db.DBImp.DeleteMany(ctx, allowedQuery); err != nil {
|
||||
db.Logger.Warn("Failed to delete many objects", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully deleted many objects", mzap.ObjRef("account_ref", accountRef), zap.Int("count", len(allowedIDs)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) FindOne(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, result T) error {
|
||||
db.Logger.Debug("Attempting to find one object", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)))
|
||||
|
||||
// For FindOne, we need to check read permission after finding the object
|
||||
if err := db.DBImp.FindOne(ctx, query, result); err != nil {
|
||||
db.Logger.Warn("Failed to find one object", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Check organization read permission for the found object
|
||||
if err := db.enforce(ctx, model.ActionRead, result, accountRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully found one object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", result.GetOrganizationRef()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) ListIDs(ctx context.Context, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error) {
|
||||
db.Logger.Debug("Attempting to list object IDs", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)))
|
||||
|
||||
// Get all candidate objects for batch permission checking
|
||||
allObjects, err := db.DBImp.Repository.ListPermissionBound(ctx, query)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to list objects for ID filtering", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use batch enforcement for efficiency
|
||||
allowedResults, err := db.Enforcer.EnforceBatch(ctx, allObjects, accountRef, model.ActionRead)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to enforce batch permissions for ID listing", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter to only allowed object IDs
|
||||
var allowedIDs []primitive.ObjectID
|
||||
for _, obj := range allObjects {
|
||||
if allowedResults[*obj.GetID()] {
|
||||
allowedIDs = append(allowedIDs, *obj.GetID())
|
||||
}
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully filtered object IDs", zap.Int("total_count", len(allObjects)),
|
||||
zap.Int("allowed_count", len(allowedIDs)), mzap.ObjRef("account_ref", accountRef))
|
||||
return allowedIDs, nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) ListAccountBound(ctx context.Context, accountRef, organizationRef primitive.ObjectID, query builder.Query) ([]model.AccountBoundStorable, error) {
|
||||
db.Logger.Debug("Attempting to list account bound objects", mzap.ObjRef("account_ref", accountRef), zap.String("collection", string(db.Collection)))
|
||||
|
||||
// Build query to find objects where accountRef matches OR is null/absent
|
||||
accountQuery := repository.WithOrg(accountRef, organizationRef)
|
||||
|
||||
// Combine with the provided query
|
||||
finalQuery := query.And(accountQuery)
|
||||
|
||||
// Get all candidate objects
|
||||
allObjects, err := db.DBImp.Repository.ListAccountBound(ctx, finalQuery)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to list account bound objects", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter objects based on read permissions (AccountBoundStorable doesn't have permission info, so we check organization level)
|
||||
var allowedObjects []model.AccountBoundStorable
|
||||
for _, obj := range allObjects {
|
||||
if err := db.enforceInterface(ctx, model.ActionRead, obj, accountRef); err == nil {
|
||||
allowedObjects = append(allowedObjects, obj)
|
||||
} else if !errors.Is(err, merrors.ErrAccessDenied) {
|
||||
// If the error is something other than AccessDenied, we want to fail
|
||||
db.Logger.Warn("Error while enforcing read permission", zap.Error(err), mzap.ObjRef("object_ref", *obj.GetID()))
|
||||
return nil, err
|
||||
}
|
||||
// If AccessDenied, we simply skip that object
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully filtered account bound objects", zap.Int("total_count", len(allObjects)),
|
||||
zap.Int("allowed_count", len(allowedObjects)), mzap.ObjRef("account_ref", accountRef))
|
||||
return allowedObjects, nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) GetByAccountRef(ctx context.Context, accountRef primitive.ObjectID, result T) error {
|
||||
db.Logger.Debug("Attempting to get object by account ref", mzap.ObjRef("account_ref", accountRef))
|
||||
|
||||
// Build query to find objects where accountRef matches OR is null/absent
|
||||
query := repository.WithoutOrg(accountRef)
|
||||
|
||||
if err := db.DBImp.FindOne(ctx, query, result); err != nil {
|
||||
db.Logger.Warn("Failed to get object by account ref", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Check organization read permission for the found object
|
||||
if err := db.enforce(ctx, model.ActionRead, result, accountRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully retrieved object by account ref", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", result.GetOrganizationRef()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) DeleteByAccountRef(ctx context.Context, accountRef primitive.ObjectID) error {
|
||||
db.Logger.Debug("Attempting to delete objects by account ref", mzap.ObjRef("account_ref", accountRef))
|
||||
|
||||
// Build query to find objects where accountRef matches OR is null/absent
|
||||
query := repository.WithoutOrg(accountRef)
|
||||
|
||||
// Get all candidate objects for individual permission checking
|
||||
allObjects, err := db.DBImp.Repository.ListAccountBound(ctx, query)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to list objects for delete by account ref", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Check permissions for each object individually (AccountBoundStorable doesn't have permission info)
|
||||
var allowedIDs []primitive.ObjectID
|
||||
for _, obj := range allObjects {
|
||||
if err := db.enforceInterface(ctx, model.ActionUpdate, obj, accountRef); err == nil {
|
||||
allowedIDs = append(allowedIDs, *obj.GetID())
|
||||
} else if !errors.Is(err, merrors.ErrAccessDenied) {
|
||||
// If the error is something other than AccessDenied, we want to fail
|
||||
db.Logger.Warn("Error while enforcing update permission", zap.Error(err), mzap.ObjRef("object_ref", *obj.GetID()))
|
||||
return err
|
||||
}
|
||||
// If AccessDenied, we simply skip that object
|
||||
}
|
||||
|
||||
if len(allowedIDs) == 0 {
|
||||
db.Logger.Debug("No objects allowed for deletion by account ref", mzap.ObjRef("account_ref", accountRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete only the allowed objects
|
||||
allowedQuery := query.And(repository.Query().In(repository.IDField(), allowedIDs))
|
||||
if err := db.DBImp.DeleteMany(ctx, allowedQuery); err != nil {
|
||||
db.Logger.Warn("Failed to delete objects by account ref", zap.Error(err), mzap.ObjRef("account_ref", accountRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Debug("Successfully deleted objects by account ref", mzap.ObjRef("account_ref", accountRef), zap.Int("count", len(allowedIDs)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *AccountBoundDBImp[T]) DeleteCascade(ctx context.Context, objectRef primitive.ObjectID) error {
|
||||
return db.DBImp.DeleteCascade(ctx, objectRef)
|
||||
}
|
||||
|
||||
// CreateAccountBoundImp creates a concrete AccountBoundDBImp instance for internal use
|
||||
func CreateAccountBoundImp[T model.AccountBoundStorable](
|
||||
ctx context.Context,
|
||||
logger mlogger.Logger,
|
||||
pdb policy.DB,
|
||||
enforcer Enforcer,
|
||||
collection mservice.Type,
|
||||
db *mongo.Database,
|
||||
) (*AccountBoundDBImp[T], error) {
|
||||
logger = logger.Named("account_bound")
|
||||
var policy model.PolicyDescription
|
||||
if err := pdb.GetBuiltInPolicy(ctx, mservice.Organizations, &policy); err != nil {
|
||||
logger.Warn("Failed to fetch organization policy description", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
res := &AccountBoundDBImp[T]{
|
||||
Logger: logger,
|
||||
DBImp: template.Create[T](logger, collection, db),
|
||||
Enforcer: enforcer,
|
||||
PermissionRef: policy.ID,
|
||||
Collection: collection,
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
81
api/pkg/auth/dbimpab_test.go
Normal file
81
api/pkg/auth/dbimpab_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestAccountBoundDBImp_Enforce tests the enforce method
|
||||
func TestAccountBoundDBImp_Enforce(t *testing.T) {
|
||||
logger := mlogger.Logger(zap.NewNop())
|
||||
db := &AccountBoundDBImp[model.AccountBoundStorable]{
|
||||
Logger: logger,
|
||||
PermissionRef: primitive.NewObjectID(),
|
||||
Collection: "test_collection",
|
||||
}
|
||||
|
||||
t.Run("EnforceMethodExists", func(t *testing.T) {
|
||||
// Test that the enforce method exists and can be called
|
||||
// This is a basic test to ensure the method signature is correct
|
||||
assert.NotNil(t, db.enforce)
|
||||
})
|
||||
|
||||
t.Run("PermissionRefSet", func(t *testing.T) {
|
||||
// Test that PermissionRef is properly set
|
||||
assert.NotEqual(t, primitive.NilObjectID, db.PermissionRef)
|
||||
})
|
||||
|
||||
t.Run("CollectionSet", func(t *testing.T) {
|
||||
// Test that Collection is properly set
|
||||
assert.Equal(t, "test_collection", string(db.Collection))
|
||||
})
|
||||
}
|
||||
|
||||
// TestAccountBoundDBImp_InterfaceCompliance tests that the struct implements required interfaces
|
||||
func TestAccountBoundDBImp_InterfaceCompliance(t *testing.T) {
|
||||
logger := mlogger.Logger(zap.NewNop())
|
||||
db := &AccountBoundDBImp[model.AccountBoundStorable]{
|
||||
Logger: logger,
|
||||
PermissionRef: primitive.NewObjectID(),
|
||||
Collection: "test_collection",
|
||||
}
|
||||
|
||||
t.Run("StructInitialization", func(t *testing.T) {
|
||||
// Test that the struct can be initialized
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, db.Logger)
|
||||
assert.NotEqual(t, primitive.NilObjectID, db.PermissionRef)
|
||||
assert.NotEmpty(t, db.Collection)
|
||||
})
|
||||
|
||||
t.Run("LoggerInitialization", func(t *testing.T) {
|
||||
// Test that logger is properly initialized
|
||||
assert.NotNil(t, db.Logger)
|
||||
})
|
||||
}
|
||||
|
||||
// TestAccountBoundDBImp_ErrorHandling tests error handling patterns
|
||||
func TestAccountBoundDBImp_ErrorHandling(t *testing.T) {
|
||||
t.Run("AccessDeniedError", func(t *testing.T) {
|
||||
// Test that AccessDenied error is properly created
|
||||
err := merrors.AccessDenied("test_collection", "read", primitive.NilObjectID)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrAccessDenied))
|
||||
})
|
||||
|
||||
t.Run("ErrorTypeChecking", func(t *testing.T) {
|
||||
// Test error type checking
|
||||
accessDeniedErr := merrors.AccessDenied("test", "read", primitive.NilObjectID)
|
||||
otherErr := errors.New("other error")
|
||||
|
||||
assert.True(t, errors.Is(accessDeniedErr, merrors.ErrAccessDenied))
|
||||
assert.False(t, errors.Is(otherErr, merrors.ErrAccessDenied))
|
||||
})
|
||||
}
|
||||
32
api/pkg/auth/enforcer.go
Normal file
32
api/pkg/auth/enforcer.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type Enforcer interface {
|
||||
// Enforce checks if accountRef can do `action` on objectRef in an org (domainRef).
|
||||
Enforce(
|
||||
ctx context.Context,
|
||||
permissionRef, accountRef, orgRef, objectRef primitive.ObjectID,
|
||||
action model.Action,
|
||||
) (bool, error)
|
||||
|
||||
// Enforce batch of objects
|
||||
EnforceBatch(
|
||||
ctx context.Context,
|
||||
objectRefs []model.PermissionBoundStorable,
|
||||
accountRef primitive.ObjectID,
|
||||
action model.Action,
|
||||
) (map[primitive.ObjectID]bool, error)
|
||||
|
||||
// GetRoles returns the user's roles in a given org domain, plus any partial scopes if relevant.
|
||||
GetRoles(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, error)
|
||||
|
||||
// GetPermissions returns all effective permissions (with effect, object scoping) for a user in org domain.
|
||||
// Merges from all roles the user holds, plus any denies/exceptions.
|
||||
GetPermissions(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, []model.Permission, error)
|
||||
}
|
||||
52
api/pkg/auth/factory.go
Normal file
52
api/pkg/auth/factory.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/auth/internal/casbin"
|
||||
"github.com/tech/sendico/pkg/auth/internal/native"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/role"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func CreateAuth(
|
||||
logger mlogger.Logger,
|
||||
client *mongo.Client,
|
||||
db *mongo.Database,
|
||||
pdb policy.DB,
|
||||
rdb role.DB,
|
||||
config *Config,
|
||||
) (Enforcer, Manager, error) {
|
||||
lg := logger.Named("auth")
|
||||
lg.Debug("Creating enforcer...", zap.String("driver", string(config.Driver)))
|
||||
l := lg.Named(string(config.Driver))
|
||||
if config.Driver == Casbin {
|
||||
enforcer, err := casbin.NewEnforcer(l, client, config.Settings)
|
||||
if err != nil {
|
||||
lg.Warn("Failed to create enforcer", zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
manager, err := casbin.NewManager(l, pdb, rdb, enforcer, config.Settings)
|
||||
if err != nil {
|
||||
lg.Warn("Failed to create managment interface", zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
return enforcer, manager, nil
|
||||
}
|
||||
if config.Driver == Native {
|
||||
enforcer, err := native.NewEnforcer(l, db)
|
||||
if err != nil {
|
||||
lg.Warn("Failed to create enforcer", zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
manager, err := native.NewManager(l, pdb, rdb, enforcer)
|
||||
if err != nil {
|
||||
lg.Warn("Failed to create managment interface", zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
return enforcer, manager, nil
|
||||
}
|
||||
return nil, nil, merrors.InvalidArgument("Unknown enforcer type: " + string(config.Driver))
|
||||
}
|
||||
61
api/pkg/auth/helper.go
Normal file
61
api/pkg/auth/helper.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func enforceObject[T model.PermissionBoundStorable](ctx context.Context, db *template.DBImp[T], enforcer Enforcer, action model.Action, accountRef primitive.ObjectID, query builder.Query) error {
|
||||
l, err := db.ListPermissionBound(ctx, query)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Error occured while checking access rights", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), zap.String("action", string(action)))
|
||||
return err
|
||||
}
|
||||
if len(l) == 0 {
|
||||
db.Logger.Debug("Access denied", mzap.ObjRef("account_ref", accountRef), zap.String("action", string(action)))
|
||||
return merrors.AccessDenied(db.Repository.Collection(), string(action), primitive.NilObjectID)
|
||||
}
|
||||
for _, item := range l {
|
||||
db.Logger.Debug("Object found", mzap.ObjRef("object_ref", *item.GetID()),
|
||||
mzap.ObjRef("organization_ref", item.GetOrganizationRef()),
|
||||
mzap.ObjRef("permission_ref", item.GetPermissionRef()),
|
||||
zap.String("collection", item.Collection()))
|
||||
}
|
||||
res, err := enforcer.EnforceBatch(ctx, l, accountRef, action)
|
||||
if err != nil {
|
||||
db.Logger.Warn("Failed to enforce permission", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), zap.String("action", string(action)))
|
||||
}
|
||||
for objectRef, hasPermission := range res {
|
||||
if !hasPermission {
|
||||
db.Logger.Info("Permission denied for object during reordering", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(model.ActionUpdate)))
|
||||
return merrors.AccessDenied(db.Repository.Collection(), string(action), objectRef)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func enforceObjectByRef[T model.PermissionBoundStorable](ctx context.Context, db *template.DBImp[T], enforcer Enforcer, action model.Action, accountRef, objectRef primitive.ObjectID) error {
|
||||
err := enforceObject(ctx, db, enforcer, action, accountRef, repository.IDFilter(objectRef))
|
||||
if err != nil {
|
||||
if errors.Is(err, merrors.ErrAccessDenied) {
|
||||
db.Logger.Debug("Access denied", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
return merrors.AccessDenied(db.Repository.Collection(), string(action), objectRef)
|
||||
} else {
|
||||
db.Logger.Warn("Error occurred while checking permissions", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
29
api/pkg/auth/indexable.go
Normal file
29
api/pkg/auth/indexable.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// IndexableDB implements reordering with permission checking
|
||||
type IndexableDB[T storable.Storable] interface {
|
||||
// Reorder implements reordering with permission checking using EnforceBatch
|
||||
Reorder(ctx context.Context, accountRef, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
|
||||
}
|
||||
|
||||
// NewIndexableDB creates a new auth.IndexableDB instance
|
||||
func NewIndexableDB[T storable.Storable](
|
||||
repo repository.Repository,
|
||||
logger mlogger.Logger,
|
||||
enforcer Enforcer,
|
||||
createEmpty func() T,
|
||||
getIndexable func(T) *model.Indexable,
|
||||
) IndexableDB[T] {
|
||||
return newIndexableDBImp(repo, logger, enforcer, createEmpty, getIndexable)
|
||||
}
|
||||
182
api/pkg/auth/indexableimp.go
Normal file
182
api/pkg/auth/indexableimp.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// IndexableDB implements reordering with permission checking
|
||||
type indexableDBImp[T storable.Storable] struct {
|
||||
repo repository.Repository
|
||||
logger mlogger.Logger
|
||||
enforcer Enforcer
|
||||
createEmpty func() T
|
||||
getIndexable func(T) *model.Indexable
|
||||
}
|
||||
|
||||
// NewIndexableDB creates a new auth.IndexableDB instance
|
||||
func newIndexableDBImp[T storable.Storable](
|
||||
repo repository.Repository,
|
||||
logger mlogger.Logger,
|
||||
enforcer Enforcer,
|
||||
createEmpty func() T,
|
||||
getIndexable func(T) *model.Indexable,
|
||||
) IndexableDB[T] {
|
||||
return &indexableDBImp[T]{
|
||||
repo: repo,
|
||||
logger: logger.Named("indexable"),
|
||||
enforcer: enforcer,
|
||||
createEmpty: createEmpty,
|
||||
getIndexable: getIndexable,
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder implements reordering with permission checking using EnforceBatch
|
||||
func (db *indexableDBImp[T]) Reorder(ctx context.Context, accountRef, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
|
||||
// Get current object to find its index
|
||||
obj := db.createEmpty()
|
||||
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
|
||||
db.logger.Warn("Failed to get object for reordering", zap.Error(err), zap.Int("new_index", newIndex),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract index from the object
|
||||
indexable := db.getIndexable(obj)
|
||||
currentIndex := indexable.Index
|
||||
if currentIndex == newIndex {
|
||||
db.logger.Debug("No reordering needed - same index", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
|
||||
return nil // No change needed
|
||||
}
|
||||
|
||||
// Determine which objects will be affected by the reordering
|
||||
var affectedObjects []model.PermissionBoundStorable
|
||||
|
||||
if currentIndex < newIndex {
|
||||
// Moving down: items between currentIndex+1 and newIndex will be shifted up by -1
|
||||
reorderFilter := filter.
|
||||
And(repository.IndexOpFilter(currentIndex+1, builder.Gte)).
|
||||
And(repository.IndexOpFilter(newIndex, builder.Lte))
|
||||
|
||||
// Get all affected objects using ListPermissionBound
|
||||
objects, err := db.repo.ListPermissionBound(ctx, reorderFilter)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to get affected objects for reordering (moving down)",
|
||||
zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
|
||||
return err
|
||||
}
|
||||
affectedObjects = append(affectedObjects, objects...)
|
||||
db.logger.Debug("Found affected objects for moving down",
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Int("affected_count", len(objects)))
|
||||
} else {
|
||||
// Moving up: items between newIndex and currentIndex-1 will be shifted down by +1
|
||||
reorderFilter := filter.
|
||||
And(repository.IndexOpFilter(newIndex, builder.Gte)).
|
||||
And(repository.IndexOpFilter(currentIndex-1, builder.Lte))
|
||||
|
||||
// Get all affected objects using ListPermissionBound
|
||||
objects, err := db.repo.ListPermissionBound(ctx, reorderFilter)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to get affected objects for reordering (moving up)", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
|
||||
return err
|
||||
}
|
||||
affectedObjects = append(affectedObjects, objects...)
|
||||
db.logger.Debug("Found affected objects for moving up", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("affected_count", len(objects)))
|
||||
}
|
||||
|
||||
// Add the target object to the list of objects that need permission checking
|
||||
targetObjects, err := db.repo.ListPermissionBound(ctx, repository.IDFilter(objectRef))
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to get target object for permission checking", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return err
|
||||
}
|
||||
if len(targetObjects) > 0 {
|
||||
affectedObjects = append(affectedObjects, targetObjects[0])
|
||||
}
|
||||
|
||||
// Check permissions for all affected objects using EnforceBatch
|
||||
db.logger.Debug("Checking permissions for reordering", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("affected_count", len(affectedObjects)),
|
||||
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
|
||||
|
||||
permissions, err := db.enforcer.EnforceBatch(ctx, affectedObjects, accountRef, model.ActionUpdate)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to check permissions for reordering", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Int("affected_count", len(affectedObjects)))
|
||||
return merrors.Internal("failed to check permissions for reordering")
|
||||
}
|
||||
|
||||
// Verify all objects have update permission
|
||||
for resObjectRef, hasPermission := range permissions {
|
||||
if !hasPermission {
|
||||
db.logger.Info("Permission denied for object during reordering", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(model.ActionUpdate)))
|
||||
return merrors.AccessDenied(db.repo.Collection(), string(model.ActionUpdate), resObjectRef)
|
||||
}
|
||||
}
|
||||
|
||||
db.logger.Debug("All permissions granted, proceeding with reordering", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("permission_count", len(permissions)))
|
||||
|
||||
// All permissions checked, proceed with reordering
|
||||
if currentIndex < newIndex {
|
||||
// Moving down: shift items between currentIndex+1 and newIndex up by -1
|
||||
patch := repository.Patch().Inc(repository.IndexField(), -1)
|
||||
reorderFilter := filter.
|
||||
And(repository.IndexOpFilter(currentIndex+1, builder.Gte)).
|
||||
And(repository.IndexOpFilter(newIndex, builder.Lte))
|
||||
|
||||
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to shift objects during reordering (moving down)", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex), zap.Int("updated_count", updatedCount))
|
||||
return err
|
||||
}
|
||||
db.logger.Debug("Successfully shifted objects (moving down)", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("updated_count", updatedCount))
|
||||
} else {
|
||||
// Moving up: shift items between newIndex and currentIndex-1 down by +1
|
||||
patch := repository.Patch().Inc(repository.IndexField(), 1)
|
||||
reorderFilter := filter.
|
||||
And(repository.IndexOpFilter(newIndex, builder.Gte)).
|
||||
And(repository.IndexOpFilter(currentIndex-1, builder.Lte))
|
||||
|
||||
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to shift objects during reordering (moving up)", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex), zap.Int("updated_count", updatedCount))
|
||||
return err
|
||||
}
|
||||
db.logger.Debug("Successfully shifted objects (moving up)", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("updated_count", updatedCount))
|
||||
}
|
||||
|
||||
// Update the target object to new index
|
||||
if err := db.repo.Patch(ctx, objectRef, repository.Patch().Set(repository.IndexField(), newIndex)); err != nil {
|
||||
db.logger.Warn("Failed to update target object index", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("current_index", currentIndex), zap.Int("new_index", newIndex))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully reordered object with permission checking",
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), zap.Int("old_index", currentIndex),
|
||||
zap.Int("new_index", newIndex), zap.Int("affected_count", len(affectedObjects)))
|
||||
return nil
|
||||
}
|
||||
23
api/pkg/auth/internal/casbin/action.go
Normal file
23
api/pkg/auth/internal/casbin/action.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package casbin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
func stringToAction(actionStr string) (model.Action, error) {
|
||||
switch actionStr {
|
||||
case string(model.ActionCreate):
|
||||
return model.ActionCreate, nil
|
||||
case string(model.ActionRead):
|
||||
return model.ActionRead, nil
|
||||
case string(model.ActionUpdate):
|
||||
return model.ActionUpdate, nil
|
||||
case string(model.ActionDelete):
|
||||
return model.ActionDelete, nil
|
||||
default:
|
||||
return "", merrors.InvalidArgument(fmt.Sprintf("invalid action: %s", actionStr))
|
||||
}
|
||||
}
|
||||
126
api/pkg/auth/internal/casbin/config/config.go
Normal file
126
api/pkg/auth/internal/casbin/config/config.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package casbin
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
mongodbadapter "github.com/casbin/mongodb-adapter/v3"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AdapterConfig struct {
|
||||
DatabaseName *string `mapstructure:"database_name"`
|
||||
DatabaseNameEnv *string `mapstructure:"database_name_env"`
|
||||
CollectionName *string `mapstructure:"collection_name"`
|
||||
CollectionNameEnv *string `mapstructure:"collection_name_env"`
|
||||
TimeoutSeconds *int `mapstructure:"timeout_seconds"`
|
||||
TimeoutSecondsEnv *string `mapstructure:"timeout_seconds_env"`
|
||||
IsFiltered *bool `mapstructure:"is_filtered"`
|
||||
IsFilteredEnv *string `mapstructure:"is_filtered_env"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ModelPath *string `mapstructure:"model_path"`
|
||||
ModelPathEnv *string `mapstructure:"model_path_env"`
|
||||
Adapter *AdapterConfig `mapstructure:"adapter"`
|
||||
}
|
||||
|
||||
type EnforcerConfig struct {
|
||||
ModelPath string
|
||||
Adapter *mongodbadapter.AdapterConfig
|
||||
}
|
||||
|
||||
func getEnvValue(logger mlogger.Logger, varName, envVarName string, value, envValue *string) string {
|
||||
if value != nil && envValue != nil {
|
||||
logger.Warn("Both variable and environment variable are set, using environment variable value",
|
||||
zap.String("variable", varName), zap.String("environment_variable", envVarName), zap.String("value", *value), zap.String("env_value", os.Getenv(*envValue)))
|
||||
}
|
||||
|
||||
if envValue != nil {
|
||||
return os.Getenv(*envValue)
|
||||
}
|
||||
|
||||
if value != nil {
|
||||
return *value
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func getEnvIntValue(logger mlogger.Logger, varName, envVarName string, value *int, envValue *string) int {
|
||||
if value != nil && envValue != nil {
|
||||
logger.Warn("Both variable and environment variable are set, using environment variable value",
|
||||
zap.String("variable", varName), zap.String("environment_variable", envVarName), zap.Int("value", *value), zap.String("env_value", os.Getenv(*envValue)))
|
||||
}
|
||||
|
||||
if envValue != nil {
|
||||
envStr := os.Getenv(*envValue)
|
||||
if envStr != "" {
|
||||
if parsed, err := time.ParseDuration(envStr + "s"); err == nil {
|
||||
return int(parsed.Seconds())
|
||||
}
|
||||
logger.Warn("Invalid environment variable value for timeout", zap.String("environment_variable", envVarName), zap.String("value", envStr))
|
||||
}
|
||||
}
|
||||
|
||||
if value != nil {
|
||||
return *value
|
||||
}
|
||||
|
||||
return 30 // Default timeout in seconds
|
||||
}
|
||||
|
||||
func getEnvBoolValue(logger mlogger.Logger, varName, envVarName string, value *bool, envValue *string) bool {
|
||||
if value != nil && envValue != nil {
|
||||
logger.Warn("Both variable and environment variable are set, using environment variable value",
|
||||
zap.String("variable", varName), zap.String("environment_variable", envVarName), zap.Bool("value", *value), zap.String("env_value", os.Getenv(*envValue)))
|
||||
}
|
||||
|
||||
if envValue != nil {
|
||||
envStr := os.Getenv(*envValue)
|
||||
if envStr == "true" || envStr == "1" {
|
||||
return true
|
||||
} else if envStr == "false" || envStr == "0" {
|
||||
return false
|
||||
}
|
||||
logger.Warn("Invalid environment variable value for boolean", zap.String("environment_variable", envVarName), zap.String("value", envStr))
|
||||
}
|
||||
|
||||
if value != nil {
|
||||
return *value
|
||||
}
|
||||
|
||||
return false // Default for boolean
|
||||
}
|
||||
|
||||
func PrepareConfig(logger mlogger.Logger, config *Config) (*EnforcerConfig, error) {
|
||||
if config == nil {
|
||||
return nil, merrors.Internal("No configuration provided")
|
||||
}
|
||||
|
||||
adapter := &mongodbadapter.AdapterConfig{
|
||||
DatabaseName: getEnvValue(logger, "database_name", "database_name_env", config.Adapter.DatabaseName, config.Adapter.DatabaseNameEnv),
|
||||
CollectionName: getEnvValue(logger, "collection_name", "collection_name_env", config.Adapter.CollectionName, config.Adapter.CollectionNameEnv),
|
||||
Timeout: time.Duration(getEnvIntValue(logger, "timeout_seconds", "timeout_seconds_env", config.Adapter.TimeoutSeconds, config.Adapter.TimeoutSecondsEnv)) * time.Second,
|
||||
IsFiltered: getEnvBoolValue(logger, "is_filtered", "is_filtered_env", config.Adapter.IsFiltered, config.Adapter.IsFilteredEnv),
|
||||
}
|
||||
|
||||
if len(adapter.DatabaseName) == 0 {
|
||||
logger.Error("Database name is not set")
|
||||
return nil, merrors.InvalidArgument("database name must be provided")
|
||||
}
|
||||
|
||||
path := getEnvValue(logger, "model_path", "model_path_env", config.ModelPath, config.ModelPathEnv)
|
||||
|
||||
logger.Info("Configuration prepared",
|
||||
zap.String("model_path", path),
|
||||
zap.String("database_name", adapter.DatabaseName),
|
||||
zap.String("collection_name", adapter.CollectionName),
|
||||
zap.Duration("timeout", adapter.Timeout),
|
||||
zap.Bool("is_filtered", adapter.IsFiltered),
|
||||
)
|
||||
|
||||
return &EnforcerConfig{ModelPath: path, Adapter: adapter}, nil
|
||||
}
|
||||
206
api/pkg/auth/internal/casbin/enforcer.go
Normal file
206
api/pkg/auth/internal/casbin/enforcer.go
Normal file
@@ -0,0 +1,206 @@
|
||||
// casbin_enforcer.go
|
||||
package casbin
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/casbin/casbin/v2"
|
||||
"github.com/tech/sendico/pkg/auth/anyobject"
|
||||
cc "github.com/tech/sendico/pkg/auth/internal/casbin/config"
|
||||
"github.com/tech/sendico/pkg/auth/internal/casbin/serialization"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CasbinEnforcer implements the Enforcer interface using Casbin.
|
||||
type CasbinEnforcer struct {
|
||||
logger mlogger.Logger
|
||||
enforcer *casbin.Enforcer
|
||||
roleSerializer serialization.Role
|
||||
permissionSerializer serialization.Policy
|
||||
}
|
||||
|
||||
// NewCasbinEnforcer initializes a new CasbinEnforcer with a MongoDB adapter, logger, and PolicySerializer.
|
||||
// The 'domain' parameter is no longer stored internally, as the interface requires passing a domainRef per method call.
|
||||
func NewEnforcer(
|
||||
logger mlogger.Logger,
|
||||
client *mongo.Client,
|
||||
settings model.SettingsT,
|
||||
) (*CasbinEnforcer, error) {
|
||||
var config cc.Config
|
||||
if err := mapstructure.Decode(settings, &config); err != nil {
|
||||
logger.Warn("Failed to decode Casbin configuration", zap.Error(err), zap.Any("settings", settings))
|
||||
return nil, merrors.Internal("failed to decode Casbin configuration")
|
||||
}
|
||||
|
||||
// Create a Casbin adapter + enforcer from your config and client.
|
||||
l := logger.Named("enforcer")
|
||||
e, err := createAdapter(l, &config, client)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to create Casbin enforcer", zap.Error(err))
|
||||
return nil, merrors.Internal("failed to create Casbin enforcer")
|
||||
}
|
||||
|
||||
logger.Info("Casbin enforcer created")
|
||||
return &CasbinEnforcer{
|
||||
logger: l,
|
||||
enforcer: e,
|
||||
permissionSerializer: serialization.NewPolicySerializer(),
|
||||
roleSerializer: serialization.NewRoleSerializer(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Enforce checks if a user has the specified action permission on an object within a domain.
|
||||
func (c *CasbinEnforcer) Enforce(
|
||||
_ context.Context,
|
||||
permissionRef, accountRef, organizationRef, objectRef primitive.ObjectID,
|
||||
action model.Action,
|
||||
) (bool, error) {
|
||||
// Convert ObjectIDs to strings for Casbin
|
||||
account := accountRef.Hex()
|
||||
organization := organizationRef.Hex()
|
||||
permission := permissionRef.Hex()
|
||||
object := anyobject.ID
|
||||
if objectRef != primitive.NilObjectID {
|
||||
object = objectRef.Hex()
|
||||
}
|
||||
act := string(action)
|
||||
|
||||
c.logger.Debug("Enforcing policy",
|
||||
zap.String("account", account), zap.String("organization", organization),
|
||||
zap.String("permission", permission), zap.String("object", object),
|
||||
zap.String("action", act))
|
||||
|
||||
// Perform the enforcement
|
||||
result, err := c.enforcer.Enforce(account, organization, permission, object, act)
|
||||
if err != nil {
|
||||
c.logger.Warn("Failed to enforce policy", zap.Error(err),
|
||||
zap.String("account", account), zap.String("organization", organization),
|
||||
zap.String("permission", permission), zap.String("object", object),
|
||||
zap.String("action", act))
|
||||
return false, err
|
||||
}
|
||||
|
||||
c.logger.Debug("Policy enforcement result", zap.Bool("result", result))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EnforceBatch checks a user’s permission for multiple objects at once.
|
||||
// It returns a map from objectRef -> boolean indicating whether access is granted.
|
||||
func (c *CasbinEnforcer) EnforceBatch(
|
||||
ctx context.Context,
|
||||
objectRefs []model.PermissionBoundStorable,
|
||||
accountRef primitive.ObjectID,
|
||||
action model.Action,
|
||||
) (map[primitive.ObjectID]bool, error) {
|
||||
results := make(map[primitive.ObjectID]bool, len(objectRefs))
|
||||
for _, desc := range objectRefs {
|
||||
ok, err := c.Enforce(ctx, desc.GetPermissionRef(), accountRef, desc.GetOrganizationRef(), *desc.GetID(), action)
|
||||
if err != nil {
|
||||
c.logger.Warn("Failed to enforce", zap.Error(err), mzap.ObjRef("permission_ref", desc.GetPermissionRef()),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", desc.GetOrganizationRef()),
|
||||
mzap.ObjRef("object_ref", *desc.GetID()), zap.String("action", string(action)))
|
||||
return nil, err
|
||||
}
|
||||
results[*desc.GetID()] = ok
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetRoles retrieves all roles assigned to the user within the domain.
|
||||
func (c *CasbinEnforcer) GetRoles(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, error) {
|
||||
sub := accountRef.Hex()
|
||||
dom := orgRef.Hex()
|
||||
|
||||
c.logger.Debug("Fetching roles for user", zap.String("subject", sub), zap.String("domain", dom))
|
||||
|
||||
// Get all roles for the user in the domain
|
||||
sroles, err := c.enforcer.GetFilteredGroupingPolicy(0, sub, "", dom)
|
||||
if err != nil {
|
||||
c.logger.Warn("Failed to get roles from policies", zap.Error(err),
|
||||
zap.String("account_ref", sub), zap.String("organization_ref", dom),
|
||||
)
|
||||
return nil, merrors.Internal("failed to fetch roles from policies")
|
||||
}
|
||||
|
||||
roles := make([]model.Role, 0, len(sroles))
|
||||
for _, srole := range sroles {
|
||||
role, err := c.roleSerializer.Deserialize(srole)
|
||||
if err != nil {
|
||||
c.logger.Warn("Failed to deserialize role", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
roles = append(roles, *role)
|
||||
}
|
||||
|
||||
c.logger.Debug("Roles fetched successfully", zap.Int("count", len(roles)))
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// GetPermissions retrieves all effective policies for the user within the domain.
|
||||
func (c *CasbinEnforcer) GetPermissions(ctx context.Context, accountRef, orgRef primitive.ObjectID) ([]model.Role, []model.Permission, error) {
|
||||
c.logger.Debug("Fetching policies for user", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", orgRef))
|
||||
|
||||
// Step 1: Retrieve all roles assigned to the user within the domain
|
||||
roles, err := c.GetRoles(ctx, accountRef, orgRef)
|
||||
if err != nil {
|
||||
c.logger.Warn("Failed to get roles", zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Map to hold unique policies
|
||||
permissionsMap := make(map[string]*model.Permission)
|
||||
for _, role := range roles {
|
||||
// Step 2a: Retrieve all policies associated with the role within the domain
|
||||
policies, err := c.enforcer.GetFilteredPolicy(0, role.DescriptionRef.Hex())
|
||||
if err != nil {
|
||||
c.logger.Warn("Failed to get policies for role", zap.Error(err), mzap.ObjRef("role_ref", role.DescriptionRef))
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 2b: Process each policy to extract Permission, Action, and Effect
|
||||
for _, policy := range policies {
|
||||
|
||||
if len(policy) < 5 {
|
||||
c.logger.Warn("Incomplete policy encountered", zap.Strings("policy", policy))
|
||||
continue // Ensure the policy line has enough fields
|
||||
}
|
||||
|
||||
// Deserialize the policy using
|
||||
deserializedPolicy, err := c.permissionSerializer.Deserialize(policy)
|
||||
if err != nil {
|
||||
c.logger.Warn("Failed to deserialize policy", zap.Error(err), zap.Strings("policy", policy))
|
||||
continue
|
||||
}
|
||||
|
||||
// Construct a unique key combining Permission ID and Action to prevent duplicates
|
||||
policyKey := deserializedPolicy.DescriptionRef.Hex() + ":" + string(deserializedPolicy.Effect.Action)
|
||||
if _, exists := permissionsMap[policyKey]; exists {
|
||||
continue // Policy-action pair already accounted for
|
||||
}
|
||||
|
||||
// Add the Policy to the map
|
||||
permissionsMap[policyKey] = &model.Permission{
|
||||
RolePolicy: *deserializedPolicy,
|
||||
AccountRef: accountRef,
|
||||
}
|
||||
c.logger.Debug("Policy added to policyMap", zap.Any("policy_key", policyKey))
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the map to a slice
|
||||
permissions := make([]model.Permission, 0, len(permissionsMap))
|
||||
for _, permission := range permissionsMap {
|
||||
permissions = append(permissions, *permission)
|
||||
}
|
||||
|
||||
c.logger.Debug("Permissions fetched successfully", zap.Int("count", len(permissions)))
|
||||
return roles, permissions, nil
|
||||
}
|
||||
34
api/pkg/auth/internal/casbin/factory.go
Normal file
34
api/pkg/auth/internal/casbin/factory.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package casbin
|
||||
|
||||
import (
|
||||
"github.com/casbin/casbin/v2"
|
||||
mongodbadapter "github.com/casbin/mongodb-adapter/v3"
|
||||
cc "github.com/tech/sendico/pkg/auth/internal/casbin/config"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func createAdapter(logger mlogger.Logger, config *cc.Config, client *mongo.Client) (*casbin.Enforcer, error) {
|
||||
dbc, err := cc.PrepareConfig(logger, config)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to prepare database configuration", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
adapter, err := mongodbadapter.NewAdapterByDB(client, dbc.Adapter)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to create DB adapter", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e, err := casbin.NewEnforcer(dbc.ModelPath, adapter, NewCasbinLogger(logger))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to create permissions enforcer", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
e.EnableAutoSave(true)
|
||||
|
||||
// No need to manually load policy. Casbin does it for us
|
||||
return e, nil
|
||||
}
|
||||
61
api/pkg/auth/internal/casbin/logger.go
Normal file
61
api/pkg/auth/internal/casbin/logger.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package casbin
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CasbinZapLogger wraps a zap.Logger to implement Casbin's Logger interface.
|
||||
type CasbinZapLogger struct {
|
||||
logger mlogger.Logger
|
||||
}
|
||||
|
||||
// NewCasbinLogger constructs a new CasbinZapLogger.
|
||||
func NewCasbinLogger(logger mlogger.Logger) *CasbinZapLogger {
|
||||
return &CasbinZapLogger{
|
||||
logger: logger.Named("driver"),
|
||||
}
|
||||
}
|
||||
|
||||
// EnableLog enables or disables logging.
|
||||
func (l *CasbinZapLogger) EnableLog(_ bool) {
|
||||
// ignore
|
||||
}
|
||||
|
||||
// IsEnabled returns whether logging is currently enabled.
|
||||
func (l *CasbinZapLogger) IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// LogModel is called by Casbin when loading model settings (you can customize if you want).
|
||||
func (l *CasbinZapLogger) LogModel(m [][]string) {
|
||||
l.logger.Info("Model loaded", zap.Any("model", m))
|
||||
}
|
||||
|
||||
func (l *CasbinZapLogger) LogPolicy(m map[string][][]string) {
|
||||
l.logger.Info("Policy loaded", zap.Int("entries", len(m)))
|
||||
}
|
||||
|
||||
func (l *CasbinZapLogger) LogError(err error, msg ...string) {
|
||||
// If no custom message was passed, log a generic one
|
||||
if len(msg) == 0 {
|
||||
l.logger.Warn("Error occurred", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, join any provided messages and include them
|
||||
l.logger.Warn(strings.Join(msg, " "), zap.Error(err))
|
||||
}
|
||||
|
||||
// LogEnforce is called by Casbin to log each Enforce() call if logging is enabled.
|
||||
func (l *CasbinZapLogger) LogEnforce(matcher string, request []any, result bool, explains [][]string) {
|
||||
l.logger.Debug("Enforcing policy...", zap.String("matcher", matcher), zap.Any("request", request),
|
||||
zap.Bool("result", result), zap.Any("explains", explains))
|
||||
}
|
||||
|
||||
// LogRole is called by Casbin when role manager adds or deletes a role.
|
||||
func (l *CasbinZapLogger) LogRole(roles []string) {
|
||||
l.logger.Debug("Changing roles...", zap.Strings("roles", roles))
|
||||
}
|
||||
54
api/pkg/auth/internal/casbin/manager.go
Normal file
54
api/pkg/auth/internal/casbin/manager.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// package casbin
|
||||
|
||||
package casbin
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/management"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/role"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CasbinManager implements the auth.Manager interface by aggregating Role and Permission managers.
|
||||
type CasbinManager struct {
|
||||
logger mlogger.Logger
|
||||
roleManager management.Role
|
||||
permManager management.Permission
|
||||
}
|
||||
|
||||
// NewManager creates a new CasbinManager with specified domains and role-domain mappings.
|
||||
func NewManager(
|
||||
l mlogger.Logger,
|
||||
pdb policy.DB,
|
||||
rdb role.DB,
|
||||
enforcer *CasbinEnforcer,
|
||||
settings model.SettingsT,
|
||||
) (*CasbinManager, error) {
|
||||
logger := l.Named("manager")
|
||||
|
||||
var pdesc model.PolicyDescription
|
||||
if err := pdb.GetBuiltInPolicy(context.Background(), "roles", &pdesc); err != nil {
|
||||
logger.Warn("Failed to fetch roles permission reference", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CasbinManager{
|
||||
logger: logger,
|
||||
roleManager: NewRoleManager(logger, enforcer, pdesc.ID, rdb),
|
||||
permManager: NewPermissionManager(logger, enforcer),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Permission returns the Permission manager.
|
||||
func (m *CasbinManager) Permission() management.Permission {
|
||||
return m.permManager
|
||||
}
|
||||
|
||||
// Role returns the Role manager.
|
||||
func (m *CasbinManager) Role() management.Role {
|
||||
return m.roleManager
|
||||
}
|
||||
54
api/pkg/auth/internal/casbin/models/auth.conf
Normal file
54
api/pkg/auth/internal/casbin/models/auth.conf
Normal file
@@ -0,0 +1,54 @@
|
||||
######################################################
|
||||
# Request Definition
|
||||
######################################################
|
||||
[request_definition]
|
||||
# Explanation:
|
||||
# - `accountRef`: The account (user) making the request.
|
||||
# - `organizationRef`: The organization in which the role applies.
|
||||
# - `permissionRef`: The specific permission being requested.
|
||||
# - `objectRef`: The object/resource being accessed (specific object or all objects).
|
||||
# - `action`: The action being requested (CRUD: read, write, update, delete).
|
||||
r = accountRef, organizationRef, permissionRef, objectRef, action
|
||||
|
||||
|
||||
######################################################
|
||||
# Policy Definition
|
||||
######################################################
|
||||
[policy_definition]
|
||||
# Explanation:
|
||||
# - `roleRef`: The role to which the policy is assigned.
|
||||
# - `organizationRef`: The organization in which the role applies.
|
||||
# - `permissionRef`: The permission associated with the policy.
|
||||
# - `objectRef`: The specific object/resource the policy applies to (or all objects).
|
||||
# - `action`: The CRUD action permitted or denied.
|
||||
# - `eft`: Effect of the policy (`allow` or `deny`).
|
||||
p = roleRef, organizationRef, permissionRef, objectRef, action, eft
|
||||
|
||||
|
||||
######################################################
|
||||
# Role Definition
|
||||
######################################################
|
||||
[role_definition]
|
||||
# Explanation:
|
||||
# - Maps `accountRef` (user) to `roleRef` (role) within `organizationRef` (scope).
|
||||
# Casbin requires underscores for placeholders, so we do not literally use accountRef, roleRef, etc. here.
|
||||
g = _, _, _
|
||||
|
||||
|
||||
######################################################
|
||||
# Policy Effect
|
||||
######################################################
|
||||
[policy_effect]
|
||||
# Explanation:
|
||||
# - Grants access if any `allow` policy matches and no `deny` policies match.
|
||||
e = some(where (p.eft == allow)) && !some(where (p.eft == deny))
|
||||
|
||||
|
||||
######################################################
|
||||
# Matchers
|
||||
######################################################
|
||||
[matchers]
|
||||
# Explanation:
|
||||
# - Checks if the user (accountRef) belongs to the roleRef within an organizationRef via `g()`.
|
||||
# - Ensures the organizationRef, permissionRef, objectRef, and action match the policy.
|
||||
m = g(r.accountRef, p.roleRef, r.organizationRef) && r.organizationRef == p.organizationRef && r.permissionRef == p.permissionRef && (p.objectRef == r.objectRef || p.objectRef == "*") && r.action == p.action
|
||||
167
api/pkg/auth/internal/casbin/permissions.go
Normal file
167
api/pkg/auth/internal/casbin/permissions.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package casbin
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/anyobject"
|
||||
"github.com/tech/sendico/pkg/auth/internal/casbin/serialization"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CasbinPermissionManager manages permissions using Casbin.
|
||||
type CasbinPermissionManager struct {
|
||||
logger mlogger.Logger // Logger for logging operations
|
||||
enforcer *CasbinEnforcer // Casbin enforcer for managing policies
|
||||
serializer serialization.Policy // Serializer for converting policies to/from Casbin
|
||||
}
|
||||
|
||||
// GrantToRole adds a permission to a role in Casbin.
|
||||
func (m *CasbinPermissionManager) GrantToRole(ctx context.Context, policy *model.RolePolicy) error {
|
||||
objRef := anyobject.ID
|
||||
if (policy.ObjectRef != nil) && (*policy.ObjectRef != primitive.NilObjectID) {
|
||||
objRef = policy.ObjectRef.Hex()
|
||||
}
|
||||
|
||||
m.logger.Debug("Granting permission to role",
|
||||
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef),
|
||||
zap.String("object_ref", objRef),
|
||||
zap.String("action", string(policy.Effect.Action)),
|
||||
zap.String("effect", string(policy.Effect.Effect)),
|
||||
)
|
||||
|
||||
// Serialize permission
|
||||
serializedPolicy, err := m.serializer.Serialize(policy)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to serialize permission while granting permission", zap.Error(err),
|
||||
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef),
|
||||
mzap.ObjRef("organization_ref", policy.OrganizationRef),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Add policy to Casbin
|
||||
added, err := m.enforcer.enforcer.AddPolicy(serializedPolicy...)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to add policy to Casbin", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if added {
|
||||
m.logger.Info("Policy added to Casbin",
|
||||
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef),
|
||||
zap.String("object_ref", objRef),
|
||||
)
|
||||
} else {
|
||||
m.logger.Warn("Policy already exists in Casbin",
|
||||
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef),
|
||||
zap.String("object_ref", objRef),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeFromRole removes a permission from a role in Casbin.
|
||||
func (m *CasbinPermissionManager) RevokeFromRole(ctx context.Context, policy *model.RolePolicy) error {
|
||||
objRef := anyobject.ID
|
||||
if policy.ObjectRef != nil {
|
||||
objRef = policy.ObjectRef.Hex()
|
||||
}
|
||||
m.logger.Debug("Revoking permission from role",
|
||||
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef),
|
||||
zap.String("object_ref", objRef),
|
||||
zap.String("action", string(policy.Effect.Action)),
|
||||
zap.String("effect", string(policy.Effect.Effect)),
|
||||
)
|
||||
|
||||
// Serialize policy
|
||||
serializedPolicy, err := m.serializer.Serialize(policy)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to serialize policy while revoking permission from role",
|
||||
zap.Error(err), mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("policy_ref", policy.DescriptionRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove policy from Casbin
|
||||
removed, err := m.enforcer.enforcer.RemovePolicy(serializedPolicy...)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to remove policy from Casbin", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
if removed {
|
||||
m.logger.Info("Policy removed from Casbin",
|
||||
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef),
|
||||
zap.String("object_ref", objRef),
|
||||
)
|
||||
} else {
|
||||
m.logger.Warn("Policy does not exist in Casbin",
|
||||
mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef),
|
||||
zap.String("object_ref", objRef),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPolicies retrieves all policies for a specific role.
|
||||
func (m *CasbinPermissionManager) GetPolicies(
|
||||
ctx context.Context,
|
||||
roleRef primitive.ObjectID,
|
||||
) ([]model.RolePolicy, error) {
|
||||
m.logger.Debug("Fetching policies for role", mzap.ObjRef("role_ref", roleRef))
|
||||
|
||||
// Retrieve Casbin policies for the role
|
||||
policies, err := m.enforcer.enforcer.GetFilteredPolicy(0, roleRef.Hex())
|
||||
if err != nil {
|
||||
m.logger.Warn("Failed to get policies", zap.Error(err), mzap.ObjRef("role_ref", roleRef))
|
||||
return nil, err
|
||||
}
|
||||
if len(policies) == 0 {
|
||||
m.logger.Info("No policies found for role", mzap.ObjRef("role_ref", roleRef))
|
||||
return nil, merrors.NoData("no policies")
|
||||
}
|
||||
|
||||
// Deserialize policies
|
||||
var result []model.RolePolicy
|
||||
for _, policy := range policies {
|
||||
permission, err := m.serializer.Deserialize(policy)
|
||||
if err != nil {
|
||||
m.logger.Warn("Failed to deserialize policy", zap.Error(err), zap.String("policy", policy[0]))
|
||||
continue
|
||||
}
|
||||
result = append(result, *permission)
|
||||
}
|
||||
|
||||
m.logger.Debug("Policies fetched successfully", mzap.ObjRef("role_ref", roleRef), zap.Int("count", len(result)))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Save persists changes to the Casbin policy store.
|
||||
func (m *CasbinPermissionManager) Save() error {
|
||||
if err := m.enforcer.enforcer.SavePolicy(); err != nil {
|
||||
m.logger.Error("Failed to save policies in Casbin", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
m.logger.Info("Policies successfully saved in Casbin")
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewPermissionManager(logger mlogger.Logger, enforcer *CasbinEnforcer) *CasbinPermissionManager {
|
||||
return &CasbinPermissionManager{
|
||||
logger: logger.Named("permission"),
|
||||
enforcer: enforcer,
|
||||
serializer: serialization.NewPolicySerializer(),
|
||||
}
|
||||
}
|
||||
209
api/pkg/auth/internal/casbin/role.go
Normal file
209
api/pkg/auth/internal/casbin/role.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package casbin
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/role"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RoleManager manages roles using Casbin.
|
||||
type RoleManager struct {
|
||||
logger mlogger.Logger
|
||||
enforcer *CasbinEnforcer
|
||||
rdb role.DB
|
||||
rolePermissionRef primitive.ObjectID
|
||||
}
|
||||
|
||||
// NewRoleManager creates a new RoleManager.
|
||||
func NewRoleManager(logger mlogger.Logger, enforcer *CasbinEnforcer, rolePermissionRef primitive.ObjectID, rdb role.DB) *RoleManager {
|
||||
return &RoleManager{
|
||||
logger: logger.Named("role"),
|
||||
enforcer: enforcer,
|
||||
rdb: rdb,
|
||||
rolePermissionRef: rolePermissionRef,
|
||||
}
|
||||
}
|
||||
|
||||
// validateObjectIDs ensures that all provided ObjectIDs are non-zero.
|
||||
func (rm *RoleManager) validateObjectIDs(ids ...primitive.ObjectID) error {
|
||||
for _, id := range ids {
|
||||
if id.IsZero() {
|
||||
return merrors.InvalidArgument("Object references cannot be zero")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePolicies removes policies based on the provided filter and logs the results.
|
||||
func (rm *RoleManager) removePolicies(policyType, role string, roleRef primitive.ObjectID) error {
|
||||
filterIndex := 1
|
||||
if policyType == "permission" {
|
||||
filterIndex = 0
|
||||
}
|
||||
policies, err := rm.enforcer.enforcer.GetFilteredPolicy(filterIndex, role)
|
||||
if err != nil {
|
||||
rm.logger.Warn("Failed to fetch "+policyType+" policies", zap.Error(err), mzap.ObjRef("role_ref", roleRef))
|
||||
return err
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
args := make([]any, len(policy))
|
||||
for i, v := range policy {
|
||||
args[i] = v
|
||||
}
|
||||
var removed bool
|
||||
var removeErr error
|
||||
if policyType == "grouping" {
|
||||
removed, removeErr = rm.enforcer.enforcer.RemoveGroupingPolicy(args...)
|
||||
} else {
|
||||
removed, removeErr = rm.enforcer.enforcer.RemovePolicy(args...)
|
||||
}
|
||||
|
||||
if removeErr != nil {
|
||||
rm.logger.Warn("Failed to remove "+policyType+" policy for role", zap.Error(removeErr), mzap.ObjRef("role_ref", roleRef), zap.Strings("policy", policy))
|
||||
return removeErr
|
||||
}
|
||||
if removed {
|
||||
rm.logger.Info("Removed "+policyType+" policy for role", mzap.ObjRef("role_ref", roleRef), zap.Strings("policy", policy))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchRolesFromPolicies retrieves and converts policies to roles.
|
||||
func (rm *RoleManager) fetchRolesFromPolicies(policies [][]string, orgRef primitive.ObjectID) []model.RoleDescription {
|
||||
roles := make([]model.RoleDescription, 0, len(policies))
|
||||
for _, policy := range policies {
|
||||
if len(policy) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
roleID, err := primitive.ObjectIDFromHex(policy[1])
|
||||
if err != nil {
|
||||
rm.logger.Warn("Invalid role ID", zap.String("roleID", policy[1]))
|
||||
continue
|
||||
}
|
||||
roles = append(roles, model.RoleDescription{Base: storable.Base{ID: roleID}, OrganizationRef: orgRef})
|
||||
}
|
||||
return roles
|
||||
}
|
||||
|
||||
// Create creates a new role in an organization.
|
||||
func (rm *RoleManager) Create(ctx context.Context, orgRef primitive.ObjectID, description *model.Describable) (*model.RoleDescription, error) {
|
||||
if err := rm.validateObjectIDs(orgRef); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
role := &model.RoleDescription{
|
||||
Describable: *description,
|
||||
OrganizationRef: orgRef,
|
||||
}
|
||||
if err := rm.rdb.Create(ctx, role); err != nil {
|
||||
rm.logger.Warn("Failed to create role", zap.Error(err), mzap.ObjRef("organiztion_ref", orgRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rm.logger.Info("Role created successfully", mzap.StorableRef(role), mzap.ObjRef("organization_ref", orgRef))
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// Assign assigns a role to a user in the given organization.
|
||||
func (rm *RoleManager) Assign(ctx context.Context, role *model.Role) error {
|
||||
if err := rm.validateObjectIDs(role.DescriptionRef, role.AccountRef, role.OrganizationRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sub := role.AccountRef.Hex()
|
||||
roleID := role.DescriptionRef.Hex()
|
||||
domain := role.OrganizationRef.Hex()
|
||||
|
||||
added, err := rm.enforcer.enforcer.AddGroupingPolicy(sub, roleID, domain)
|
||||
return rm.logPolicyResult("assign", added, err, role.DescriptionRef, role.AccountRef, role.OrganizationRef)
|
||||
}
|
||||
|
||||
// Delete removes a role entirely and cleans up associated Casbin policies.
|
||||
func (rm *RoleManager) Delete(ctx context.Context, roleRef primitive.ObjectID) error {
|
||||
if err := rm.validateObjectIDs(roleRef); err != nil {
|
||||
rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef))
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rm.rdb.Delete(ctx, roleRef); err != nil {
|
||||
rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef))
|
||||
return err
|
||||
}
|
||||
|
||||
role := roleRef.Hex()
|
||||
|
||||
// Remove grouping policies
|
||||
if err := rm.removePolicies("grouping", role, roleRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove permission policies
|
||||
if err := rm.removePolicies("permission", role, roleRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// // Save changes
|
||||
// if err := rm.enforcer.enforcer.SavePolicy(); err != nil {
|
||||
// rm.logger.Warn("Failed to save Casbin policies after role deletion",
|
||||
// zap.Error(err),
|
||||
// mzap.ObjRef("role_ref", roleRef),
|
||||
// )
|
||||
// return err
|
||||
// }
|
||||
|
||||
rm.logger.Info("Role deleted successfully along with associated policies", mzap.ObjRef("role_ref", roleRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke removes a role from a user.
|
||||
func (rm *RoleManager) Revoke(ctx context.Context, roleRef, accountRef, orgRef primitive.ObjectID) error {
|
||||
if err := rm.validateObjectIDs(roleRef, accountRef, orgRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sub := accountRef.Hex()
|
||||
role := roleRef.Hex()
|
||||
domain := orgRef.Hex()
|
||||
|
||||
removed, err := rm.enforcer.enforcer.RemoveGroupingPolicy(sub, role, domain)
|
||||
return rm.logPolicyResult("revoke", removed, err, roleRef, accountRef, orgRef)
|
||||
}
|
||||
|
||||
// logPolicyResult logs results for Assign and Revoke.
|
||||
func (rm *RoleManager) logPolicyResult(action string, result bool, err error, roleRef, accountRef, orgRef primitive.ObjectID) error {
|
||||
if err != nil {
|
||||
rm.logger.Warn("Failed to "+action+" role", zap.Error(err), mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", orgRef))
|
||||
return err
|
||||
}
|
||||
msg := "Role " + action + "ed successfully"
|
||||
if !result {
|
||||
msg = "Role already " + action + "ed"
|
||||
}
|
||||
rm.logger.Info(msg, mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", orgRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
// List retrieves all roles in an organization or all roles if orgRef is zero.
|
||||
func (rm *RoleManager) List(ctx context.Context, orgRef primitive.ObjectID) ([]model.RoleDescription, error) {
|
||||
domain := orgRef.Hex()
|
||||
groupingPolicies, err := rm.enforcer.enforcer.GetFilteredGroupingPolicy(2, domain)
|
||||
if err != nil {
|
||||
rm.logger.Warn("Failed to fetch grouping policies", zap.Error(err), mzap.ObjRef("organization_ref", orgRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roles := rm.fetchRolesFromPolicies(groupingPolicies, orgRef)
|
||||
|
||||
rm.logger.Info("Retrieved roles for organization", mzap.ObjRef("organization_ref", orgRef), zap.Int("count", len(roles)))
|
||||
return roles, nil
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package serializationimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/auth/anyobject"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// PolicySerializer implements CasbinSerializer for Permission.
|
||||
type PolicySerializer struct{}
|
||||
|
||||
// Serialize converts a Permission object into a Casbin policy.
|
||||
func (s *PolicySerializer) Serialize(entity *model.RolePolicy) ([]any, error) {
|
||||
if entity.RoleDescriptionRef.IsZero() ||
|
||||
entity.OrganizationRef.IsZero() ||
|
||||
entity.DescriptionRef.IsZero() || // Ensure permissionRef is valid
|
||||
entity.Effect.Action == "" || // Ensure action is not empty
|
||||
entity.Effect.Effect == "" { // Ensure effect (eft) is not empty
|
||||
return nil, merrors.InvalidArgument("permission contains invalid object references or missing fields")
|
||||
}
|
||||
|
||||
objectRef := anyobject.ID
|
||||
if entity.ObjectRef != nil {
|
||||
objectRef = entity.ObjectRef.Hex()
|
||||
}
|
||||
|
||||
return []any{
|
||||
entity.RoleDescriptionRef.Hex(), // Maps to p.roleRef
|
||||
entity.OrganizationRef.Hex(), // Maps to p.organizationRef
|
||||
entity.DescriptionRef.Hex(), // Maps to p.permissionRef
|
||||
objectRef, // Maps to p.objectRef (wildcard if empty)
|
||||
string(entity.Effect.Action), // Maps to p.action
|
||||
string(entity.Effect.Effect), // Maps to p.eft
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Deserialize converts a Casbin policy into a Permission object.
|
||||
func (s *PolicySerializer) Deserialize(policy []string) (*model.RolePolicy, error) {
|
||||
if len(policy) != 6 { // Ensure policy has the correct number of fields
|
||||
return nil, merrors.Internal("invalid policy format")
|
||||
}
|
||||
|
||||
roleRef, err := primitive.ObjectIDFromHex(policy[0])
|
||||
if err != nil {
|
||||
return nil, merrors.InvalidArgument("invalid roleRef in policy")
|
||||
}
|
||||
|
||||
organizationRef, err := primitive.ObjectIDFromHex(policy[1])
|
||||
if err != nil {
|
||||
return nil, merrors.InvalidArgument("invalid organizationRef in policy")
|
||||
}
|
||||
|
||||
permissionRef, err := primitive.ObjectIDFromHex(policy[2])
|
||||
if err != nil {
|
||||
return nil, merrors.InvalidArgument("invalid permissionRef in policy")
|
||||
}
|
||||
|
||||
// Handle wildcard for ObjectRef
|
||||
var objectRef *primitive.ObjectID
|
||||
if policy[3] != anyobject.ID {
|
||||
ref, err := primitive.ObjectIDFromHex(policy[3])
|
||||
if err != nil {
|
||||
return nil, merrors.InvalidArgument("invalid objectRef in policy")
|
||||
}
|
||||
objectRef = &ref
|
||||
}
|
||||
|
||||
return &model.RolePolicy{
|
||||
RoleDescriptionRef: roleRef,
|
||||
Policy: model.Policy{
|
||||
OrganizationRef: organizationRef,
|
||||
DescriptionRef: permissionRef,
|
||||
ObjectRef: objectRef,
|
||||
Effect: model.ActionEffect{
|
||||
Action: model.Action(policy[4]),
|
||||
Effect: model.Effect(policy[5]),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
57
api/pkg/auth/internal/casbin/serialization/internal/role.go
Normal file
57
api/pkg/auth/internal/casbin/serialization/internal/role.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package serializationimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// RoleSerializer implements CasbinSerializer for Role.
|
||||
type RoleSerializer struct{}
|
||||
|
||||
// Serialize converts a Role object into a Casbin grouping policy.
|
||||
func (s *RoleSerializer) Serialize(entity *model.Role) ([]any, error) {
|
||||
// Validate required fields
|
||||
if entity.AccountRef.IsZero() || entity.DescriptionRef.IsZero() || entity.OrganizationRef.IsZero() {
|
||||
return nil, merrors.InvalidArgument("role contains invalid object references")
|
||||
}
|
||||
|
||||
return []any{
|
||||
entity.AccountRef.Hex(), // Maps to g(_, _, _) accountRef
|
||||
entity.DescriptionRef.Hex(), // Maps to g(_, _, _) roleRef
|
||||
entity.OrganizationRef.Hex(), // Maps to g(_, _, _) organizationRef
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Deserialize converts a Casbin grouping policy into a Role object.
|
||||
func (s *RoleSerializer) Deserialize(policy []string) (*model.Role, error) {
|
||||
// Ensure the policy has exactly 3 fields
|
||||
if len(policy) != 3 {
|
||||
return nil, merrors.Internal("invalid grouping policy format")
|
||||
}
|
||||
|
||||
// Parse accountRef
|
||||
accountRef, err := primitive.ObjectIDFromHex(policy[0])
|
||||
if err != nil {
|
||||
return nil, merrors.InvalidArgument("invalid accountRef in grouping policy")
|
||||
}
|
||||
|
||||
// Parse roleDescriptionRef (roleRef)
|
||||
roleDescriptionRef, err := primitive.ObjectIDFromHex(policy[1])
|
||||
if err != nil {
|
||||
return nil, merrors.InvalidArgument("invalid roleRef in grouping policy")
|
||||
}
|
||||
|
||||
// Parse organizationRef
|
||||
organizationRef, err := primitive.ObjectIDFromHex(policy[2])
|
||||
if err != nil {
|
||||
return nil, merrors.InvalidArgument("invalid organizationRef in grouping policy")
|
||||
}
|
||||
|
||||
// Return the constructed Role object
|
||||
return &model.Role{
|
||||
AccountRef: accountRef,
|
||||
DescriptionRef: roleDescriptionRef,
|
||||
OrganizationRef: organizationRef,
|
||||
}, nil
|
||||
}
|
||||
12
api/pkg/auth/internal/casbin/serialization/policy.go
Normal file
12
api/pkg/auth/internal/casbin/serialization/policy.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package serialization
|
||||
|
||||
import (
|
||||
serializationimp "github.com/tech/sendico/pkg/auth/internal/casbin/serialization/internal"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
type Policy = CasbinSerializer[model.RolePolicy]
|
||||
|
||||
func NewPolicySerializer() Policy {
|
||||
return &serializationimp.PolicySerializer{}
|
||||
}
|
||||
12
api/pkg/auth/internal/casbin/serialization/role.go
Normal file
12
api/pkg/auth/internal/casbin/serialization/role.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package serialization
|
||||
|
||||
import (
|
||||
serializationimp "github.com/tech/sendico/pkg/auth/internal/casbin/serialization/internal"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
type Role = CasbinSerializer[model.Role]
|
||||
|
||||
func NewRoleSerializer() Role {
|
||||
return &serializationimp.RoleSerializer{}
|
||||
}
|
||||
10
api/pkg/auth/internal/casbin/serialization/serializer.go
Normal file
10
api/pkg/auth/internal/casbin/serialization/serializer.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package serialization
|
||||
|
||||
// CasbinSerializer defines methods for serializing and deserializing any Casbin-compatible entity.
|
||||
type CasbinSerializer[T any] interface {
|
||||
// Serialize converts an entity (Role or Permission) into a Casbin policy.
|
||||
Serialize(entity *T) ([]any, error)
|
||||
|
||||
// Deserialize converts a Casbin policy into an entity (Role or Permission).
|
||||
Deserialize(policy []string) (*T, error)
|
||||
}
|
||||
151
api/pkg/auth/internal/native/db/policies.go
Normal file
151
api/pkg/auth/internal/native/db/policies.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type PermissionsDBImp struct {
|
||||
template.DBImp[*nstructures.PolicyAssignment]
|
||||
}
|
||||
|
||||
func (db *PermissionsDBImp) Policies(ctx context.Context, object model.PermissionBoundStorable, action model.Action) ([]nstructures.PolicyAssignment, error) {
|
||||
return mutil.GetObjects[nstructures.PolicyAssignment](
|
||||
ctx,
|
||||
db.Logger,
|
||||
repository.Query().And(
|
||||
repository.Filter("policy.organizationRef", object.GetOrganizationRef()),
|
||||
repository.Filter("policy.descriptionRef", object.GetPermissionRef()),
|
||||
repository.Filter("policy.effect.action", action),
|
||||
repository.Query().Or(
|
||||
repository.Filter("policy.objectRef", *object.GetID()),
|
||||
repository.Filter("policy.objectRef", nil),
|
||||
),
|
||||
),
|
||||
nil,
|
||||
db.Repository,
|
||||
)
|
||||
}
|
||||
|
||||
func (db *PermissionsDBImp) PoliciesForPermissionAction(ctx context.Context, roleRef, permissionRef primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error) {
|
||||
return mutil.GetObjects[nstructures.PolicyAssignment](
|
||||
ctx,
|
||||
db.Logger,
|
||||
repository.Query().And(
|
||||
repository.Filter("roleRef", roleRef),
|
||||
repository.Filter("policy.descriptionRef", permissionRef),
|
||||
repository.Filter("policy.effect.action", action),
|
||||
),
|
||||
nil,
|
||||
db.Repository,
|
||||
)
|
||||
}
|
||||
|
||||
func (db *PermissionsDBImp) Remove(ctx context.Context, policy *model.RolePolicy) error {
|
||||
objRefFilter := repository.Query().Or(
|
||||
repository.Filter("policy.objectRef", nil),
|
||||
repository.Filter("policy.objectRef", primitive.NilObjectID),
|
||||
)
|
||||
if policy.ObjectRef != nil {
|
||||
objRefFilter = repository.Filter("policy.objectRef", *policy.ObjectRef)
|
||||
}
|
||||
return db.Repository.DeleteMany(
|
||||
ctx,
|
||||
repository.Query().And(
|
||||
repository.Filter("roleRef", policy.RoleDescriptionRef),
|
||||
repository.Filter("policy.organizationRef", policy.OrganizationRef),
|
||||
repository.Filter("policy.descriptionRef", policy.DescriptionRef),
|
||||
objRefFilter,
|
||||
repository.Filter("policy.effect.action", policy.Effect.Action),
|
||||
repository.Filter("policy.effect.effect", policy.Effect.Effect),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (db *PermissionsDBImp) PoliciesForRole(ctx context.Context, roleRef primitive.ObjectID) ([]nstructures.PolicyAssignment, error) {
|
||||
return mutil.GetObjects[nstructures.PolicyAssignment](
|
||||
ctx,
|
||||
db.Logger,
|
||||
repository.Filter("roleRef", roleRef),
|
||||
nil,
|
||||
db.Repository,
|
||||
)
|
||||
}
|
||||
|
||||
func (db *PermissionsDBImp) PoliciesForRoles(ctx context.Context, roleRefs []primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error) {
|
||||
if len(roleRefs) == 0 {
|
||||
db.Logger.Debug("Empty role references list provided, returning empty resposnse")
|
||||
return []nstructures.PolicyAssignment{}, nil
|
||||
}
|
||||
return mutil.GetObjects[nstructures.PolicyAssignment](
|
||||
ctx,
|
||||
db.Logger,
|
||||
repository.Query().And(
|
||||
repository.Query().In(repository.Field("roleRef"), roleRefs),
|
||||
repository.Filter("policy.effect.action", action),
|
||||
),
|
||||
nil,
|
||||
db.Repository,
|
||||
)
|
||||
}
|
||||
|
||||
func NewPoliciesDB(logger mlogger.Logger, db *mongo.Database) (*PermissionsDBImp, error) {
|
||||
p := &PermissionsDBImp{
|
||||
DBImp: *template.Create[*nstructures.PolicyAssignment](logger, mservice.PolicyAssignements, db),
|
||||
}
|
||||
|
||||
// faster
|
||||
// harder
|
||||
// index
|
||||
policiesQueryIndex := &ri.Definition{
|
||||
Keys: []ri.Key{
|
||||
{Field: "policy.organizationRef", Sort: ri.Asc},
|
||||
{Field: "policy.descriptionRef", Sort: ri.Asc},
|
||||
{Field: "policy.effect.action", Sort: ri.Asc},
|
||||
{Field: "policy.objectRef", Sort: ri.Asc},
|
||||
},
|
||||
}
|
||||
if err := p.DBImp.Repository.CreateIndex(policiesQueryIndex); err != nil {
|
||||
p.Logger.Warn("Failed to prepare policies query index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roleBasedQueriesIndex := &ri.Definition{
|
||||
Keys: []ri.Key{
|
||||
{Field: "roleRef", Sort: ri.Asc},
|
||||
{Field: "policy.effect.action", Sort: ri.Asc},
|
||||
},
|
||||
}
|
||||
if err := p.DBImp.Repository.CreateIndex(roleBasedQueriesIndex); err != nil {
|
||||
p.Logger.Warn("Failed to prepare role based query index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uniquePolicyConstaint := &ri.Definition{
|
||||
Keys: []ri.Key{
|
||||
{Field: "policy.organizationRef", Sort: ri.Asc},
|
||||
{Field: "roleRef", Sort: ri.Asc},
|
||||
{Field: "policy.descriptionRef", Sort: ri.Asc},
|
||||
{Field: "policy.effect.action", Sort: ri.Asc},
|
||||
{Field: "policy.objectRef", Sort: ri.Asc},
|
||||
},
|
||||
Unique: true,
|
||||
}
|
||||
if err := p.DBImp.Repository.CreateIndex(uniquePolicyConstaint); err != nil {
|
||||
p.Logger.Warn("Failed to unique policy assignment index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
99
api/pkg/auth/internal/native/db/roles.go
Normal file
99
api/pkg/auth/internal/native/db/roles.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type RolesDBImp struct {
|
||||
template.DBImp[*nstructures.RoleAssignment]
|
||||
}
|
||||
|
||||
func (db *RolesDBImp) Roles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
|
||||
return mutil.GetObjects[nstructures.RoleAssignment](
|
||||
ctx,
|
||||
db.Logger,
|
||||
repository.Query().And(
|
||||
repository.Filter("role.accountRef", accountRef),
|
||||
repository.Filter("role.organizationRef", organizationRef),
|
||||
),
|
||||
nil,
|
||||
db.Repository,
|
||||
)
|
||||
}
|
||||
|
||||
func (db *RolesDBImp) RolesForVenue(ctx context.Context, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
|
||||
return mutil.GetObjects[nstructures.RoleAssignment](
|
||||
ctx,
|
||||
db.Logger,
|
||||
repository.Query().And(
|
||||
repository.Filter("role.organizationRef", organizationRef),
|
||||
),
|
||||
nil,
|
||||
db.Repository,
|
||||
)
|
||||
}
|
||||
|
||||
func (db *RolesDBImp) DeleteRole(ctx context.Context, roleRef primitive.ObjectID) error {
|
||||
return db.DeleteMany(
|
||||
ctx,
|
||||
repository.Query().And(
|
||||
repository.Filter("role.descriptionRef", roleRef),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (db *RolesDBImp) RemoveRole(ctx context.Context, roleRef, organizationRef, accountRef primitive.ObjectID) error {
|
||||
return db.DeleteMany(
|
||||
ctx,
|
||||
repository.Query().And(
|
||||
repository.Filter("role.accountRef", accountRef),
|
||||
repository.Filter("role.organizationRef", organizationRef),
|
||||
repository.Filter("role.descriptionRef", roleRef),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func NewRolesDB(logger mlogger.Logger, db *mongo.Database) (*RolesDBImp, error) {
|
||||
p := &RolesDBImp{
|
||||
DBImp: *template.Create[*nstructures.RoleAssignment](logger, "role_assignments", db),
|
||||
}
|
||||
|
||||
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "role.organizationRef", Sort: ri.Asc}},
|
||||
}); err != nil {
|
||||
p.Logger.Warn("Failed to prepare venue index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "role.descriptionRef", Sort: ri.Asc}},
|
||||
}); err != nil {
|
||||
p.Logger.Warn("Failed to prepare role description index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uniqueRoleConstaint := &ri.Definition{
|
||||
Keys: []ri.Key{
|
||||
{Field: "role.organizationRef", Sort: ri.Asc},
|
||||
{Field: "role.accountRef", Sort: ri.Asc},
|
||||
{Field: "role.descriptionRef", Sort: ri.Asc},
|
||||
},
|
||||
Unique: true,
|
||||
}
|
||||
if err := p.DBImp.Repository.CreateIndex(uniqueRoleConstaint); err != nil {
|
||||
p.Logger.Warn("Failed to prepare role assignment index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
27
api/pkg/auth/internal/native/dbpolicies.go
Normal file
27
api/pkg/auth/internal/native/dbpolicies.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/db"
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type PoliciesDB interface {
|
||||
template.DB[*nstructures.PolicyAssignment]
|
||||
// plenty of interfaces for performance reasons
|
||||
Policies(ctx context.Context, object model.PermissionBoundStorable, action model.Action) ([]nstructures.PolicyAssignment, error)
|
||||
PoliciesForPermissionAction(ctx context.Context, roleRef, permissionRef primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error)
|
||||
PoliciesForRole(ctx context.Context, roleRef primitive.ObjectID) ([]nstructures.PolicyAssignment, error)
|
||||
PoliciesForRoles(ctx context.Context, roleRefs []primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error)
|
||||
Remove(ctx context.Context, policy *model.RolePolicy) error
|
||||
}
|
||||
|
||||
func NewPoliciesDBDB(logger mlogger.Logger, conn *mongo.Database) (PoliciesDB, error) {
|
||||
return db.NewPoliciesDB(logger, conn)
|
||||
}
|
||||
24
api/pkg/auth/internal/native/dbroles.go
Normal file
24
api/pkg/auth/internal/native/dbroles.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/db"
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type RolesDB interface {
|
||||
template.DB[*nstructures.RoleAssignment]
|
||||
Roles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error)
|
||||
RolesForVenue(ctx context.Context, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error)
|
||||
RemoveRole(ctx context.Context, roleRef, organizationRef, accountRef primitive.ObjectID) error
|
||||
DeleteRole(ctx context.Context, roleRef primitive.ObjectID) error
|
||||
}
|
||||
|
||||
func NewRolesDB(logger mlogger.Logger, conn *mongo.Database) (RolesDB, error) {
|
||||
return db.NewRolesDB(logger, conn)
|
||||
}
|
||||
256
api/pkg/auth/internal/native/enforcer.go
Normal file
256
api/pkg/auth/internal/native/enforcer.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Enforcer struct {
|
||||
logger mlogger.Logger
|
||||
pdb PoliciesDB
|
||||
rdb RolesDB
|
||||
}
|
||||
|
||||
func NewEnforcer(
|
||||
logger mlogger.Logger,
|
||||
db *mongo.Database,
|
||||
) (*Enforcer, error) {
|
||||
e := &Enforcer{logger: logger.Named("enforcer")}
|
||||
|
||||
var err error
|
||||
if e.pdb, err = NewPoliciesDBDB(e.logger, db); err != nil {
|
||||
e.logger.Warn("Failed to create permission assignments database", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if e.rdb, err = NewRolesDB(e.logger, db); err != nil {
|
||||
e.logger.Warn("Failed to create role assignments database", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("Native enforcer created")
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// Enforce checks if a user has the specified action permission on an object within a domain.
|
||||
func (n *Enforcer) Enforce(
|
||||
ctx context.Context,
|
||||
permissionRef, accountRef, organizationRef, objectRef primitive.ObjectID,
|
||||
action model.Action,
|
||||
) (bool, error) {
|
||||
roleAssignments, err := n.rdb.Roles(ctx, accountRef, organizationRef)
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
n.logger.Debug("No roles defined for account", mzap.ObjRef("account_ref", accountRef))
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
n.logger.Warn("Failed to fetch roles while checking permissions", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
|
||||
mzap.ObjRef("object", objectRef), zap.String("action", string(action)))
|
||||
return false, err
|
||||
}
|
||||
if len(roleAssignments) == 0 {
|
||||
n.logger.Warn("No roles found for account", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
return false, merrors.Internal("No roles found for account " + accountRef.Hex())
|
||||
}
|
||||
allowFound := false // Track if any allow is found across roles
|
||||
|
||||
for _, roleAssignment := range roleAssignments {
|
||||
policies, err := n.pdb.PoliciesForPermissionAction(ctx, roleAssignment.DescriptionRef, permissionRef, action)
|
||||
if err != nil && !errors.Is(err, merrors.ErrNoData) {
|
||||
n.logger.Warn("Failed to fetch permissions", zap.Error(err), mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, permission := range policies {
|
||||
if permission.Effect.Effect == model.EffectDeny {
|
||||
n.logger.Debug("Found denying policy", mzap.ObjRef("account", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
return false, nil // Deny takes precedence immediately
|
||||
}
|
||||
|
||||
if permission.Effect.Effect == model.EffectAllow {
|
||||
n.logger.Debug("Allowing policy found", mzap.ObjRef("account", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
allowFound = true // At least one allow found
|
||||
} else {
|
||||
n.logger.Warn("Corrupted policy", mzap.StorableRef(&permission))
|
||||
return false, merrors.Internal("Corrupted action effect data for permissions entry " + permission.ID.Hex() + ": " + string(permission.Effect.Effect))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final decision based on whether any allow was found
|
||||
if allowFound {
|
||||
return true, nil // At least one allow and no deny
|
||||
}
|
||||
|
||||
n.logger.Debug("No allowing policy found", mzap.ObjRef("account", accountRef),
|
||||
mzap.ObjRef("organization_ref", organizationRef), mzap.ObjRef("permission_ref", permissionRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.String("action", string(action)))
|
||||
|
||||
return false, nil // No allow found, default deny
|
||||
}
|
||||
|
||||
// EnforceBatch checks a user’s permission for multiple objects at once.
|
||||
// It returns a map from objectRef -> boolean indicating whether access is granted.
|
||||
func (n *Enforcer) EnforceBatch(
|
||||
ctx context.Context,
|
||||
objectRefs []model.PermissionBoundStorable,
|
||||
accountRef primitive.ObjectID,
|
||||
action model.Action,
|
||||
) (map[primitive.ObjectID]bool, error) {
|
||||
results := make(map[primitive.ObjectID]bool, len(objectRefs))
|
||||
|
||||
// Group objectRefs by organizationRef.
|
||||
objectsByVenue := make(map[primitive.ObjectID][]model.PermissionBoundStorable)
|
||||
for _, obj := range objectRefs {
|
||||
organizationRef := obj.GetOrganizationRef()
|
||||
objectsByVenue[organizationRef] = append(objectsByVenue[organizationRef], obj)
|
||||
}
|
||||
|
||||
// Process each venue group separately.
|
||||
for organizationRef, objs := range objectsByVenue {
|
||||
// 1. Fetch roles once for this account and venue.
|
||||
roles, err := n.rdb.Roles(ctx, accountRef, organizationRef)
|
||||
if err != nil {
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
n.logger.Debug("No roles defined for account", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
|
||||
// With no roles, mark all objects in this venue as denied.
|
||||
for _, obj := range objs {
|
||||
results[*obj.GetID()] = false
|
||||
}
|
||||
// Continue to next venue
|
||||
continue
|
||||
}
|
||||
n.logger.Warn("Failed to fetch roles", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Extract role description references
|
||||
var roleRefs []primitive.ObjectID
|
||||
for _, role := range roles {
|
||||
roleRefs = append(roleRefs, role.DescriptionRef)
|
||||
}
|
||||
|
||||
// 3. Fetch all policies for these roles and the given action in one call.
|
||||
allPolicies, err := n.pdb.PoliciesForRoles(ctx, roleRefs, action)
|
||||
if err != nil {
|
||||
n.logger.Warn("Failed to fetch policies", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. Build a lookup map keyed by PermissionRef.
|
||||
policyMap := make(map[primitive.ObjectID][]nstructures.PolicyAssignment)
|
||||
for _, policy := range allPolicies {
|
||||
policyMap[policy.DescriptionRef] = append(policyMap[policy.DescriptionRef], policy)
|
||||
}
|
||||
|
||||
// 5. Evaluate permissions for each object in this venue group.
|
||||
for _, obj := range objs {
|
||||
permRef := obj.GetPermissionRef()
|
||||
allow := false
|
||||
if policies, ok := policyMap[permRef]; ok {
|
||||
for _, policy := range policies {
|
||||
// Deny takes precedence.
|
||||
if policy.Effect.Effect == model.EffectDeny {
|
||||
allow = false
|
||||
break
|
||||
}
|
||||
if policy.Effect.Effect == model.EffectAllow {
|
||||
allow = true
|
||||
// Continue checking in case a deny exists among policies.
|
||||
} else {
|
||||
// should never get here
|
||||
return nil, merrors.Internal("Corrupted permissions effect in policy assignment '" + policy.GetID().Hex() + "': " + string(policy.Effect.Effect))
|
||||
}
|
||||
}
|
||||
}
|
||||
results[*obj.GetID()] = allow
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetRoles retrieves all roles assigned to the user within the domain.
|
||||
func (n *Enforcer) GetRoles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, error) {
|
||||
n.logger.Debug("Fetching roles for user", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
|
||||
ra, err := n.rdb.Roles(ctx, accountRef, organizationRef)
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
n.logger.Debug("No roles assigned to user", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return []model.Role{}, nil
|
||||
}
|
||||
if err != nil {
|
||||
n.logger.Warn("Failed to fetch roles", zap.Error(err), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roles := make([]model.Role, len(ra))
|
||||
for i, roleAssignement := range ra {
|
||||
roles[i] = roleAssignement.Role
|
||||
}
|
||||
|
||||
n.logger.Debug("Fetched roles", zap.Int("roles_count", len(roles)))
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
func (n *Enforcer) Reload() error {
|
||||
n.logger.Info("Policies reloaded") // do nothing actually
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPermissions retrieves all effective policies for the user within the domain.
|
||||
func (n *Enforcer) GetPermissions(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, []model.Permission, error) {
|
||||
n.logger.Debug("Fetching policies for user", mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
|
||||
|
||||
roles, err := n.GetRoles(ctx, accountRef, organizationRef)
|
||||
if err != nil {
|
||||
n.logger.Warn("Failed to get roles", zap.Error(err))
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
uniquePermissions := make(map[primitive.ObjectID]model.Permission)
|
||||
for _, role := range roles {
|
||||
perms, err := n.pdb.PoliciesForRole(ctx, role.DescriptionRef)
|
||||
if err != nil {
|
||||
n.logger.Warn("Failed to get policies for role", zap.Error(err), mzap.ObjRef("role_ref", role.DescriptionRef))
|
||||
continue
|
||||
}
|
||||
n.logger.Debug("Policies fetched for role", mzap.ObjRef("role_ref", role.DescriptionRef), zap.Int("count", len(perms)))
|
||||
for _, p := range perms {
|
||||
uniquePermissions[*p.GetID()] = model.Permission{
|
||||
RolePolicy: model.RolePolicy{
|
||||
Policy: p.Policy,
|
||||
RoleDescriptionRef: p.RoleRef,
|
||||
},
|
||||
AccountRef: accountRef,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
permissionsSlice := make([]model.Permission, 0, len(uniquePermissions))
|
||||
for _, permission := range uniquePermissions {
|
||||
permissionsSlice = append(permissionsSlice, permission)
|
||||
}
|
||||
|
||||
n.logger.Debug("Policies fetched successfully", zap.Int("count", len(permissionsSlice)))
|
||||
return roles, permissionsSlice, nil
|
||||
}
|
||||
747
api/pkg/auth/internal/native/enforcer_test.go
Normal file
747
api/pkg/auth/internal/native/enforcer_test.go
Normal file
@@ -0,0 +1,747 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
factory "github.com/tech/sendico/pkg/mlogger/factory"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// Mock implementations for testing
|
||||
type MockPoliciesDB struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) PoliciesForPermissionAction(ctx context.Context, roleRef, permissionRef primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error) {
|
||||
args := m.Called(ctx, roleRef, permissionRef, action)
|
||||
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) PoliciesForRole(ctx context.Context, roleRef primitive.ObjectID) ([]nstructures.PolicyAssignment, error) {
|
||||
args := m.Called(ctx, roleRef)
|
||||
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) PoliciesForRoles(ctx context.Context, roleRefs []primitive.ObjectID, action model.Action) ([]nstructures.PolicyAssignment, error) {
|
||||
args := m.Called(ctx, roleRefs, action)
|
||||
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) Policies(ctx context.Context, object model.PermissionBoundStorable, action model.Action) ([]nstructures.PolicyAssignment, error) {
|
||||
args := m.Called(ctx, object, action)
|
||||
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) Remove(ctx context.Context, policy *model.RolePolicy) error {
|
||||
args := m.Called(ctx, policy)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// Template DB methods - implement as needed for testing
|
||||
func (m *MockPoliciesDB) Create(ctx context.Context, assignment *nstructures.PolicyAssignment) error {
|
||||
args := m.Called(ctx, assignment)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) Get(ctx context.Context, id primitive.ObjectID, assignment *nstructures.PolicyAssignment) error {
|
||||
args := m.Called(ctx, id, assignment)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) Update(ctx context.Context, assignment *nstructures.PolicyAssignment) error {
|
||||
args := m.Called(ctx, assignment)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) Patch(ctx context.Context, objectRef primitive.ObjectID, patch builder.Patch) error {
|
||||
args := m.Called(ctx, objectRef, patch)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) Delete(ctx context.Context, id primitive.ObjectID) error {
|
||||
args := m.Called(ctx, id)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) DeleteMany(ctx context.Context, query builder.Query) error {
|
||||
args := m.Called(ctx, query)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) ListPermissionBound(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.PolicyAssignment, error) {
|
||||
args := m.Called(ctx, accountRef, organizationRef)
|
||||
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) ListIDs(ctx context.Context, query interface{}) ([]primitive.ObjectID, error) {
|
||||
args := m.Called(ctx, query)
|
||||
return args.Get(0).([]primitive.ObjectID), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) FindOne(ctx context.Context, query builder.Query, assignment *nstructures.PolicyAssignment) error {
|
||||
args := m.Called(ctx, query, assignment)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) List(ctx context.Context, query builder.Query) ([]nstructures.PolicyAssignment, error) {
|
||||
args := m.Called(ctx, query)
|
||||
return args.Get(0).([]nstructures.PolicyAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) Name() string {
|
||||
return "mock_policies"
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) DeleteCascade(ctx context.Context, id primitive.ObjectID) error {
|
||||
args := m.Called(ctx, id)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPoliciesDB) InsertMany(ctx context.Context, objects []*nstructures.PolicyAssignment) error {
|
||||
args := m.Called(ctx, objects)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockRolesDB struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) Roles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
|
||||
args := m.Called(ctx, accountRef, organizationRef)
|
||||
return args.Get(0).([]nstructures.RoleAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) RolesForVenue(ctx context.Context, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
|
||||
args := m.Called(ctx, organizationRef)
|
||||
return args.Get(0).([]nstructures.RoleAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) RemoveRole(ctx context.Context, roleRef, organizationRef, accountRef primitive.ObjectID) error {
|
||||
args := m.Called(ctx, roleRef, organizationRef, accountRef)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) DeleteRole(ctx context.Context, roleRef primitive.ObjectID) error {
|
||||
args := m.Called(ctx, roleRef)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// Template DB methods - implement as needed for testing
|
||||
func (m *MockRolesDB) Create(ctx context.Context, assignment *nstructures.RoleAssignment) error {
|
||||
args := m.Called(ctx, assignment)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) Get(ctx context.Context, id primitive.ObjectID, assignment *nstructures.RoleAssignment) error {
|
||||
args := m.Called(ctx, id, assignment)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) Update(ctx context.Context, assignment *nstructures.RoleAssignment) error {
|
||||
args := m.Called(ctx, assignment)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) Patch(ctx context.Context, objectRef primitive.ObjectID, patch builder.Patch) error {
|
||||
args := m.Called(ctx, objectRef, patch)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) Delete(ctx context.Context, id primitive.ObjectID) error {
|
||||
args := m.Called(ctx, id)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) DeleteMany(ctx context.Context, query builder.Query) error {
|
||||
args := m.Called(ctx, query)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) ListPermissionBound(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]nstructures.RoleAssignment, error) {
|
||||
args := m.Called(ctx, accountRef, organizationRef)
|
||||
return args.Get(0).([]nstructures.RoleAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) ListIDs(ctx context.Context, query interface{}) ([]primitive.ObjectID, error) {
|
||||
args := m.Called(ctx, query)
|
||||
return args.Get(0).([]primitive.ObjectID), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) FindOne(ctx context.Context, query builder.Query, assignment *nstructures.RoleAssignment) error {
|
||||
args := m.Called(ctx, query, assignment)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) List(ctx context.Context, query builder.Query) ([]nstructures.RoleAssignment, error) {
|
||||
args := m.Called(ctx, query)
|
||||
return args.Get(0).([]nstructures.RoleAssignment), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) Name() string {
|
||||
return "mock_roles"
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) DeleteCascade(ctx context.Context, id primitive.ObjectID) error {
|
||||
args := m.Called(ctx, id)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRolesDB) InsertMany(ctx context.Context, objects []*nstructures.RoleAssignment) error {
|
||||
args := m.Called(ctx, objects)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func createTestObjectID() primitive.ObjectID {
|
||||
return primitive.NewObjectID()
|
||||
}
|
||||
|
||||
func createTestRoleAssignment(roleRef, accountRef, organizationRef primitive.ObjectID) nstructures.RoleAssignment {
|
||||
return nstructures.RoleAssignment{
|
||||
Role: model.Role{
|
||||
AccountRef: accountRef,
|
||||
DescriptionRef: roleRef,
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPolicyAssignment(roleRef primitive.ObjectID, action model.Action, effect model.Effect, organizationRef, descriptionRef primitive.ObjectID, objectRef *primitive.ObjectID) nstructures.PolicyAssignment {
|
||||
return nstructures.PolicyAssignment{
|
||||
Policy: model.Policy{
|
||||
OrganizationRef: organizationRef,
|
||||
DescriptionRef: descriptionRef,
|
||||
ObjectRef: objectRef,
|
||||
Effect: model.ActionEffect{
|
||||
Action: action,
|
||||
Effect: effect,
|
||||
},
|
||||
},
|
||||
RoleRef: roleRef,
|
||||
}
|
||||
}
|
||||
|
||||
func createTestEnforcer(pdb PoliciesDB, rdb RolesDB) *Enforcer {
|
||||
logger := factory.NewLogger(true)
|
||||
enforcer := &Enforcer{
|
||||
logger: logger.Named("test"),
|
||||
pdb: pdb,
|
||||
rdb: rdb,
|
||||
}
|
||||
return enforcer
|
||||
}
|
||||
|
||||
func TestEnforcer_Enforce(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test data
|
||||
accountRef := createTestObjectID()
|
||||
organizationRef := createTestObjectID()
|
||||
permissionRef := createTestObjectID()
|
||||
objectRef := createTestObjectID()
|
||||
roleRef := createTestObjectID()
|
||||
|
||||
t.Run("Allow_SingleRole_SinglePolicy", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock role assignment
|
||||
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
|
||||
|
||||
// Mock policy assignment with ALLOW effect
|
||||
policyAssignment := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectAllow, organizationRef, permissionRef, &objectRef)
|
||||
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{policyAssignment}, nil)
|
||||
|
||||
// Create enforcer
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.NoError(t, err)
|
||||
assert.True(t, allowed)
|
||||
mockRDB.AssertExpectations(t)
|
||||
mockPDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Deny_SingleRole_SinglePolicy", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock role assignment
|
||||
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
|
||||
|
||||
// Mock policy assignment with DENY effect
|
||||
policyAssignment := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectDeny, organizationRef, permissionRef, &objectRef)
|
||||
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{policyAssignment}, nil)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.NoError(t, err)
|
||||
assert.False(t, allowed)
|
||||
mockRDB.AssertExpectations(t)
|
||||
mockPDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("DenyTakesPrecedence_MultipleRoles", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
role1Ref := createTestObjectID()
|
||||
role2Ref := createTestObjectID()
|
||||
|
||||
// Mock multiple role assignments
|
||||
roleAssignment1 := createTestRoleAssignment(role1Ref, accountRef, organizationRef)
|
||||
roleAssignment2 := createTestRoleAssignment(role2Ref, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment1, roleAssignment2}, nil)
|
||||
|
||||
// First role has ALLOW policy
|
||||
allowPolicy := createTestPolicyAssignment(role1Ref, model.ActionRead, model.EffectAllow, organizationRef, permissionRef, &objectRef)
|
||||
mockPDB.On("PoliciesForPermissionAction", ctx, role1Ref, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{allowPolicy}, nil)
|
||||
|
||||
// Second role has DENY policy - should take precedence
|
||||
denyPolicy := createTestPolicyAssignment(role2Ref, model.ActionRead, model.EffectDeny, organizationRef, permissionRef, &objectRef)
|
||||
mockPDB.On("PoliciesForPermissionAction", ctx, role2Ref, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{denyPolicy}, nil)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify - DENY should take precedence
|
||||
require.NoError(t, err)
|
||||
assert.False(t, allowed)
|
||||
mockRDB.AssertExpectations(t)
|
||||
mockPDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("NoRoles_ReturnsFalse", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock no roles found
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, merrors.ErrNoData)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.NoError(t, err)
|
||||
assert.False(t, allowed)
|
||||
mockRDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("EmptyRoles_ReturnsError", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock empty roles list (not NoData error)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, nil)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.Error(t, err)
|
||||
assert.False(t, allowed)
|
||||
assert.Contains(t, err.Error(), "No roles found for account")
|
||||
mockRDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("DatabaseError_RolesDB", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock database error
|
||||
dbError := errors.New("database connection failed")
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, dbError)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.Error(t, err)
|
||||
assert.False(t, allowed)
|
||||
assert.Equal(t, dbError, err)
|
||||
mockRDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("DatabaseError_PoliciesDB", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock role assignment
|
||||
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
|
||||
|
||||
// Mock database error in policies
|
||||
dbError := errors.New("policies database error")
|
||||
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{}, dbError)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.Error(t, err)
|
||||
assert.False(t, allowed)
|
||||
assert.Equal(t, dbError, err)
|
||||
mockRDB.AssertExpectations(t)
|
||||
mockPDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("NoPolicies_ReturnsFalse", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock role assignment
|
||||
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
|
||||
|
||||
// Mock no policies found
|
||||
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{}, merrors.ErrNoData)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.NoError(t, err)
|
||||
assert.False(t, allowed)
|
||||
mockRDB.AssertExpectations(t)
|
||||
mockPDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("CorruptedPolicy_ReturnsError", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock role assignment
|
||||
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
|
||||
|
||||
// Mock corrupted policy with invalid effect
|
||||
corruptedPolicy := createTestPolicyAssignment(roleRef, model.ActionRead, "invalid_effect", organizationRef, permissionRef, &objectRef)
|
||||
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{corruptedPolicy}, nil)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.Error(t, err)
|
||||
assert.False(t, allowed)
|
||||
assert.Contains(t, err.Error(), "Corrupted action effect data")
|
||||
mockRDB.AssertExpectations(t)
|
||||
mockPDB.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
// Mock implementation for PermissionBoundStorable
|
||||
type MockPermissionBoundStorable struct {
|
||||
id primitive.ObjectID
|
||||
permissionRef primitive.ObjectID
|
||||
organizationRef primitive.ObjectID
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) GetID() *primitive.ObjectID {
|
||||
return &m.id
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) GetPermissionRef() primitive.ObjectID {
|
||||
return m.permissionRef
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) GetOrganizationRef() primitive.ObjectID {
|
||||
return m.organizationRef
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) Collection() string {
|
||||
return "test_objects"
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) SetID(objID primitive.ObjectID) {
|
||||
m.id = objID
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) Update() {
|
||||
// Do nothing for mock
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) SetPermissionRef(permissionRef primitive.ObjectID) {
|
||||
m.permissionRef = permissionRef
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) SetOrganizationRef(organizationRef primitive.ObjectID) {
|
||||
m.organizationRef = organizationRef
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) IsArchived() bool {
|
||||
return false // Default to not archived for testing
|
||||
}
|
||||
|
||||
func (m *MockPermissionBoundStorable) SetArchived(archived bool) {
|
||||
// No-op for testing
|
||||
}
|
||||
|
||||
func TestEnforcer_EnforceBatch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test data
|
||||
accountRef := createTestObjectID()
|
||||
organizationRef := createTestObjectID()
|
||||
permissionRef := createTestObjectID()
|
||||
roleRef := createTestObjectID()
|
||||
|
||||
// Create test objects
|
||||
object1 := &MockPermissionBoundStorable{
|
||||
id: createTestObjectID(),
|
||||
permissionRef: permissionRef,
|
||||
organizationRef: organizationRef,
|
||||
}
|
||||
object2 := &MockPermissionBoundStorable{
|
||||
id: createTestObjectID(),
|
||||
permissionRef: permissionRef,
|
||||
organizationRef: organizationRef,
|
||||
}
|
||||
|
||||
t.Run("BatchEnforce_MultipleObjects_SameVenue", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock role assignment
|
||||
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
|
||||
|
||||
// Mock policy assignment with ALLOW effect
|
||||
policyAssignment := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectAllow, organizationRef, permissionRef, nil)
|
||||
mockPDB.On("PoliciesForRoles", ctx, []primitive.ObjectID{roleRef}, model.ActionRead).Return([]nstructures.PolicyAssignment{policyAssignment}, nil)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute batch enforcement
|
||||
objects := []model.PermissionBoundStorable{object1, object2}
|
||||
results, err := enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
assert.True(t, results[object1.id])
|
||||
assert.True(t, results[object2.id])
|
||||
mockRDB.AssertExpectations(t)
|
||||
mockPDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("BatchEnforce_NoRoles_AllObjectsDenied", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock no roles found
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, merrors.ErrNoData)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute batch enforcement
|
||||
objects := []model.PermissionBoundStorable{object1, object2}
|
||||
results, err := enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
assert.False(t, results[object1.id])
|
||||
assert.False(t, results[object2.id])
|
||||
mockRDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("BatchEnforce_DatabaseError", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock database error
|
||||
dbError := errors.New("database connection failed")
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, dbError)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute batch enforcement
|
||||
objects := []model.PermissionBoundStorable{object1, object2}
|
||||
results, err := enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
|
||||
|
||||
// Verify
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, results)
|
||||
assert.Equal(t, dbError, err)
|
||||
mockRDB.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnforcer_GetRoles(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test data
|
||||
accountRef := createTestObjectID()
|
||||
organizationRef := createTestObjectID()
|
||||
roleRef := createTestObjectID()
|
||||
|
||||
t.Run("GetRoles_Success", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock role assignment
|
||||
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
roles, err := enforcer.GetRoles(ctx, accountRef, organizationRef)
|
||||
|
||||
// Verify
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, roles, 1)
|
||||
assert.Equal(t, roleRef, roles[0].DescriptionRef)
|
||||
mockRDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("GetRoles_NoRoles", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock no roles found
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, merrors.ErrNoData)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
roles, err := enforcer.GetRoles(ctx, accountRef, organizationRef)
|
||||
|
||||
// Verify
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, roles, 0)
|
||||
mockRDB.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnforcer_GetPermissions(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test data
|
||||
accountRef := createTestObjectID()
|
||||
organizationRef := createTestObjectID()
|
||||
roleRef := createTestObjectID()
|
||||
|
||||
t.Run("GetPermissions_Success", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock role assignment
|
||||
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
|
||||
|
||||
// Mock policy assignment
|
||||
policyAssignment := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectAllow, organizationRef, createTestObjectID(), nil)
|
||||
mockPDB.On("PoliciesForRole", ctx, roleRef).Return([]nstructures.PolicyAssignment{policyAssignment}, nil)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
roles, permissions, err := enforcer.GetPermissions(ctx, accountRef, organizationRef)
|
||||
|
||||
// Verify
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, roles, 1)
|
||||
assert.Len(t, permissions, 1)
|
||||
assert.Equal(t, accountRef, permissions[0].AccountRef)
|
||||
mockRDB.AssertExpectations(t)
|
||||
mockPDB.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
// Security-focused test scenarios
|
||||
func TestEnforcer_SecurityScenarios(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test data
|
||||
accountRef := createTestObjectID()
|
||||
organizationRef := createTestObjectID()
|
||||
permissionRef := createTestObjectID()
|
||||
objectRef := createTestObjectID()
|
||||
roleRef := createTestObjectID()
|
||||
|
||||
t.Run("Security_DenyAlwaysWins", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock role assignment
|
||||
roleAssignment := createTestRoleAssignment(roleRef, accountRef, organizationRef)
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{roleAssignment}, nil)
|
||||
|
||||
// Mock multiple policies: both ALLOW and DENY
|
||||
allowPolicy := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectAllow, organizationRef, permissionRef, &objectRef)
|
||||
denyPolicy := createTestPolicyAssignment(roleRef, model.ActionRead, model.EffectDeny, organizationRef, permissionRef, &objectRef)
|
||||
mockPDB.On("PoliciesForPermissionAction", ctx, roleRef, permissionRef, model.ActionRead).Return([]nstructures.PolicyAssignment{allowPolicy, denyPolicy}, nil)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify - DENY should always win
|
||||
require.NoError(t, err)
|
||||
assert.False(t, allowed)
|
||||
mockRDB.AssertExpectations(t)
|
||||
mockPDB.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Security_InvalidObjectID", func(t *testing.T) {
|
||||
mockPDB := &MockPoliciesDB{}
|
||||
mockRDB := &MockRolesDB{}
|
||||
|
||||
// Mock database error for invalid ObjectID
|
||||
dbError := errors.New("invalid ObjectID")
|
||||
mockRDB.On("Roles", ctx, accountRef, organizationRef).Return([]nstructures.RoleAssignment{}, dbError)
|
||||
|
||||
enforcer := createTestEnforcer(mockPDB, mockRDB)
|
||||
|
||||
// Execute with invalid ObjectID
|
||||
allowed, err := enforcer.Enforce(ctx, permissionRef, accountRef, organizationRef, objectRef, model.ActionRead)
|
||||
|
||||
// Verify - should fail securely
|
||||
require.Error(t, err)
|
||||
assert.False(t, allowed)
|
||||
mockRDB.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
// Note: This test provides comprehensive coverage of the native enforcer including:
|
||||
// 1. Basic enforcement logic with deny-takes-precedence
|
||||
// 2. Batch operations for performance
|
||||
// 3. Role and permission retrieval
|
||||
// 4. Security scenarios and edge cases
|
||||
// 5. Error handling and database failures
|
||||
// 6. All critical security paths are tested
|
||||
51
api/pkg/auth/internal/native/manager.go
Normal file
51
api/pkg/auth/internal/native/manager.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/management"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/role"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// NativeManager implements the auth.Manager interface by aggregating Role and Permission managers.
|
||||
type NativeManager struct {
|
||||
logger mlogger.Logger
|
||||
roleManager management.Role
|
||||
permManager management.Permission
|
||||
}
|
||||
|
||||
// NewManager creates a new CasbinManager with specified domains and role-domain mappings.
|
||||
func NewManager(
|
||||
l mlogger.Logger,
|
||||
pdb policy.DB,
|
||||
rdb role.DB,
|
||||
enforcer *Enforcer,
|
||||
) (*NativeManager, error) {
|
||||
logger := l.Named("manager")
|
||||
|
||||
var pdesc model.PolicyDescription
|
||||
if err := pdb.GetBuiltInPolicy(context.Background(), "roles", &pdesc); err != nil {
|
||||
logger.Warn("Failed to fetch roles permission reference", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &NativeManager{
|
||||
logger: logger,
|
||||
roleManager: NewRoleManager(logger, enforcer, pdesc.ID, rdb),
|
||||
permManager: NewPermissionManager(logger, enforcer),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Permission returns the Permission manager.
|
||||
func (m *NativeManager) Permission() management.Permission {
|
||||
return m.permManager
|
||||
}
|
||||
|
||||
// Role returns the Role manager.
|
||||
func (m *NativeManager) Role() management.Role {
|
||||
return m.roleManager
|
||||
}
|
||||
BIN
api/pkg/auth/internal/native/native.test
Executable file
BIN
api/pkg/auth/internal/native/native.test
Executable file
Binary file not shown.
17
api/pkg/auth/internal/native/nstructures/policies.go
Normal file
17
api/pkg/auth/internal/native/nstructures/policies.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package nstructures
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type PolicyAssignment struct {
|
||||
storable.Base `bson:",inline" json:",inline"`
|
||||
model.Policy `bson:"policy" json:"policy"`
|
||||
RoleRef primitive.ObjectID `bson:"roleRef" json:"roleRef"`
|
||||
}
|
||||
|
||||
func (*PolicyAssignment) Collection() string {
|
||||
return "permission_assignments"
|
||||
}
|
||||
15
api/pkg/auth/internal/native/nstructures/role.go
Normal file
15
api/pkg/auth/internal/native/nstructures/role.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package nstructures
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
type RoleAssignment struct {
|
||||
storable.Base `bson:",inline" json:",inline"`
|
||||
model.Role `bson:"role" json:"role"`
|
||||
}
|
||||
|
||||
func (*RoleAssignment) Collection() string {
|
||||
return "role_assignments"
|
||||
}
|
||||
101
api/pkg/auth/internal/native/permission.go
Normal file
101
api/pkg/auth/internal/native/permission.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PermissionManager manages permissions using Casbin.
|
||||
type PermissionManager struct {
|
||||
logger mlogger.Logger
|
||||
enforcer *Enforcer
|
||||
}
|
||||
|
||||
// GrantToRole adds a permission to a role in Casbin.
|
||||
func (m *PermissionManager) GrantToRole(ctx context.Context, policy *model.RolePolicy) error {
|
||||
objRef := "any"
|
||||
if (policy.ObjectRef != nil) && (*policy.ObjectRef != primitive.NilObjectID) {
|
||||
objRef = policy.ObjectRef.Hex()
|
||||
}
|
||||
|
||||
m.logger.Debug("Granting permission to role", mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef), zap.String("object_ref", objRef),
|
||||
zap.String("action", string(policy.Effect.Action)), zap.String("effect", string(policy.Effect.Effect)),
|
||||
)
|
||||
|
||||
assignment := nstructures.PolicyAssignment{
|
||||
Policy: policy.Policy,
|
||||
RoleRef: policy.RoleDescriptionRef,
|
||||
}
|
||||
if err := m.enforcer.pdb.Create(ctx, &assignment); err != nil {
|
||||
m.logger.Warn("Failed to grant policy", zap.Error(err), mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef), zap.String("object_ref", objRef),
|
||||
zap.String("action", string(policy.Effect.Action)), zap.String("effect", string(policy.Effect.Effect)))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeFromRole removes a permission from a role in Casbin.
|
||||
func (m *PermissionManager) RevokeFromRole(ctx context.Context, policy *model.RolePolicy) error {
|
||||
objRef := "*"
|
||||
if policy.ObjectRef != nil {
|
||||
objRef = policy.ObjectRef.Hex()
|
||||
}
|
||||
m.logger.Debug("Revoking permission from role", mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef), zap.String("object_ref", objRef),
|
||||
zap.String("action", string(policy.Effect.Action)), zap.String("effect", string(policy.Effect.Effect)),
|
||||
)
|
||||
if err := m.enforcer.pdb.Remove(ctx, policy); err != nil {
|
||||
m.logger.Warn("Failed to revoke policy", zap.Error(err), mzap.ObjRef("role_ref", policy.RoleDescriptionRef),
|
||||
mzap.ObjRef("permission_ref", policy.DescriptionRef), zap.String("object_ref", objRef),
|
||||
zap.String("action", string(policy.Effect.Action)), zap.String("effect", string(policy.Effect.Effect)))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPolicies retrieves all policies for a specific role.
|
||||
func (m *PermissionManager) GetPolicies(
|
||||
ctx context.Context,
|
||||
roleRef primitive.ObjectID,
|
||||
) ([]model.RolePolicy, error) {
|
||||
m.logger.Debug("Fetching policies for role", mzap.ObjRef("role_ref", roleRef))
|
||||
|
||||
assinments, err := m.enforcer.pdb.PoliciesForRole(ctx, roleRef)
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
m.logger.Debug("No policies found", mzap.ObjRef("role_ref", roleRef))
|
||||
return []model.RolePolicy{}, nil
|
||||
}
|
||||
policies := make([]model.RolePolicy, len(assinments))
|
||||
for i, assinment := range assinments {
|
||||
policies[i] = model.RolePolicy{
|
||||
Policy: assinment.Policy,
|
||||
RoleDescriptionRef: assinment.RoleRef,
|
||||
}
|
||||
}
|
||||
m.logger.Debug("Policies fetched successfully", mzap.ObjRef("role_ref", roleRef), zap.Int("count", len(policies)))
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
// Save persists changes to the Casbin policy store.
|
||||
func (m *PermissionManager) Save() error {
|
||||
m.logger.Info("Policies successfully saved") // do nothing
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewPermissionManager(logger mlogger.Logger, enforcer *Enforcer) *PermissionManager {
|
||||
return &PermissionManager{
|
||||
logger: logger.Named("permission"),
|
||||
enforcer: enforcer,
|
||||
}
|
||||
}
|
||||
142
api/pkg/auth/internal/native/role.go
Normal file
142
api/pkg/auth/internal/native/role.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package native
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth/internal/native/nstructures"
|
||||
"github.com/tech/sendico/pkg/db/role"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RoleManager manages roles using Casbin.
|
||||
type RoleManager struct {
|
||||
logger mlogger.Logger
|
||||
enforcer *Enforcer
|
||||
rdb role.DB
|
||||
rolePermissionRef primitive.ObjectID
|
||||
}
|
||||
|
||||
// NewRoleManager creates a new RoleManager.
|
||||
func NewRoleManager(logger mlogger.Logger, enforcer *Enforcer, rolePermissionRef primitive.ObjectID, rdb role.DB) *RoleManager {
|
||||
return &RoleManager{
|
||||
logger: logger.Named("role"),
|
||||
enforcer: enforcer,
|
||||
rdb: rdb,
|
||||
rolePermissionRef: rolePermissionRef,
|
||||
}
|
||||
}
|
||||
|
||||
// validateObjectIDs ensures that all provided ObjectIDs are non-zero.
|
||||
func (rm *RoleManager) validateObjectIDs(ids ...primitive.ObjectID) error {
|
||||
for _, id := range ids {
|
||||
if id.IsZero() {
|
||||
return merrors.InvalidArgument("Object references cannot be zero")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchRolesFromPolicies retrieves and converts policies to roles.
|
||||
func (rm *RoleManager) fetchRolesFromPolicies(roles []nstructures.RoleAssignment, organizationRef primitive.ObjectID) []model.RoleDescription {
|
||||
result := make([]model.RoleDescription, len(roles))
|
||||
for i, role := range roles {
|
||||
result[i] = model.RoleDescription{
|
||||
Base: storable.Base{ID: *role.GetID()},
|
||||
OrganizationRef: organizationRef,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Create creates a new role in an organization.
|
||||
func (rm *RoleManager) Create(ctx context.Context, organizationRef primitive.ObjectID, description *model.Describable) (*model.RoleDescription, error) {
|
||||
if err := rm.validateObjectIDs(organizationRef); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
role := &model.RoleDescription{
|
||||
OrganizationRef: organizationRef,
|
||||
Describable: *description,
|
||||
}
|
||||
if err := rm.rdb.Create(ctx, role); err != nil {
|
||||
rm.logger.Warn("Failed to create role", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rm.logger.Info("Role created successfully", mzap.StorableRef(role), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// Assign assigns a role to a user in the given organization.
|
||||
func (rm *RoleManager) Assign(ctx context.Context, role *model.Role) error {
|
||||
if err := rm.validateObjectIDs(role.DescriptionRef, role.AccountRef, role.OrganizationRef); err != nil {
|
||||
return err
|
||||
}
|
||||
assogment := nstructures.RoleAssignment{Role: *role}
|
||||
err := rm.enforcer.rdb.Create(ctx, &assogment)
|
||||
return rm.logPolicyResult("assign", err == nil, err, role.DescriptionRef, role.AccountRef, role.OrganizationRef)
|
||||
}
|
||||
|
||||
// Delete removes a role entirely and cleans up associated Casbin policies.
|
||||
func (rm *RoleManager) Delete(ctx context.Context, roleRef primitive.ObjectID) error {
|
||||
if err := rm.validateObjectIDs(roleRef); err != nil {
|
||||
rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef))
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rm.rdb.Delete(ctx, roleRef); err != nil {
|
||||
rm.logger.Warn("Failed to delete role", mzap.ObjRef("role_ref", roleRef))
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rm.enforcer.rdb.DeleteRole(ctx, roleRef); err != nil {
|
||||
rm.logger.Warn("Failed to remove role", zap.Error(err), mzap.ObjRef("role_ref", roleRef))
|
||||
return err
|
||||
}
|
||||
|
||||
rm.logger.Info("Role deleted successfully along with associated policies", mzap.ObjRef("role_ref", roleRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Revoke removes a role from a user.
|
||||
func (rm *RoleManager) Revoke(ctx context.Context, roleRef, accountRef, organizationRef primitive.ObjectID) error {
|
||||
if err := rm.validateObjectIDs(roleRef, accountRef, organizationRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := rm.enforcer.rdb.RemoveRole(ctx, roleRef, organizationRef, accountRef)
|
||||
return rm.logPolicyResult("revoke", err == nil, err, roleRef, accountRef, organizationRef)
|
||||
}
|
||||
|
||||
// logPolicyResult logs results for Assign and Revoke.
|
||||
func (rm *RoleManager) logPolicyResult(action string, result bool, err error, roleRef, accountRef, organizationRef primitive.ObjectID) error {
|
||||
if err != nil {
|
||||
rm.logger.Warn("Failed to "+action+" role", zap.Error(err), mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return err
|
||||
}
|
||||
msg := "Role " + action + "ed successfully"
|
||||
if !result {
|
||||
msg = "Role already " + action + "ed"
|
||||
}
|
||||
rm.logger.Info(msg, mzap.ObjRef("role_ref", roleRef), mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
// List retrieves all roles in an organization or all roles if organizationRef is zero.
|
||||
func (rm *RoleManager) List(ctx context.Context, organizationRef primitive.ObjectID) ([]model.RoleDescription, error) {
|
||||
roles4Venues, err := rm.enforcer.rdb.RolesForVenue(ctx, organizationRef)
|
||||
if err != nil {
|
||||
rm.logger.Warn("Failed to fetch grouping policies", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
roles := rm.fetchRolesFromPolicies(roles4Venues, organizationRef)
|
||||
rm.logger.Info("Retrieved roles for organization", mzap.ObjRef("organization_ref", organizationRef), zap.Int("count", len(roles)))
|
||||
return roles, nil
|
||||
}
|
||||
27
api/pkg/auth/management/permission.go
Normal file
27
api/pkg/auth/management/permission.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type Permission interface {
|
||||
// Grant a permission to a role with an optional object scope and specified effect.
|
||||
// Use primitive.NilObjectID for 'any' objectRef.
|
||||
GrantToRole(ctx context.Context, policy *model.RolePolicy) error
|
||||
|
||||
// Revoke a permission from a role with an optional object scope and specified effect.
|
||||
// Use primitive.NilObjectID for 'any' objectRef.
|
||||
RevokeFromRole(ctx context.Context, policy *model.RolePolicy) error
|
||||
|
||||
// Retrieve all policies assigned to a specific role, including scope and effects.
|
||||
GetPolicies(
|
||||
ctx context.Context,
|
||||
roleRef primitive.ObjectID,
|
||||
) ([]model.RolePolicy, error)
|
||||
|
||||
// Persist any changes made to permissions.
|
||||
Save() error
|
||||
}
|
||||
41
api/pkg/auth/management/role.go
Normal file
41
api/pkg/auth/management/role.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type Role interface {
|
||||
// Create a new role in an organization (returns the created Role with its ID).
|
||||
Create(
|
||||
ctx context.Context,
|
||||
orgRef primitive.ObjectID,
|
||||
description *model.Describable,
|
||||
) (*model.RoleDescription, error)
|
||||
|
||||
// Delete a role entirely. This will cascade and remove all associated
|
||||
Delete(
|
||||
ctx context.Context,
|
||||
roleRef primitive.ObjectID,
|
||||
) error
|
||||
|
||||
// Assign a role to a user in a specific organization.
|
||||
Assign(
|
||||
ctx context.Context,
|
||||
role *model.Role,
|
||||
) error
|
||||
|
||||
// Revoke a role from a user in a specific organization.
|
||||
Revoke(
|
||||
ctx context.Context,
|
||||
roleRef, accountRef, orgRef primitive.ObjectID,
|
||||
) error
|
||||
|
||||
// List all roles in an organization or globally if orgRef is primitive.NilObjectID.
|
||||
List(
|
||||
ctx context.Context,
|
||||
orgRef primitive.ObjectID,
|
||||
) ([]model.RoleDescription, error)
|
||||
}
|
||||
15
api/pkg/auth/manager.go
Normal file
15
api/pkg/auth/manager.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/auth/management"
|
||||
)
|
||||
|
||||
// Manager provides access to domain-aware Permission and Role managers.
|
||||
type Manager interface {
|
||||
// Permission returns a manager that handles permission grants/revokes
|
||||
// for a specific resource type. (You might add domainRef here if desired.)
|
||||
Permission() management.Permission
|
||||
|
||||
// Role returns the domain-aware Role manager.
|
||||
Role() management.Role
|
||||
}
|
||||
14
api/pkg/auth/provider.go
Normal file
14
api/pkg/auth/provider.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
Enforcer() Enforcer
|
||||
Manager() Manager
|
||||
GetPolicyDescription(ctx context.Context, resource mservice.Type) (*model.PolicyDescription, error)
|
||||
}
|
||||
43
api/pkg/auth/taggable.go
Normal file
43
api/pkg/auth/taggable.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// TaggableDB implements tag operations with permission checking
|
||||
type TaggableDB[T model.PermissionBoundStorable] interface {
|
||||
// AddTag adds a tag to an entity with permission checking
|
||||
AddTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error
|
||||
// RemoveTagd removes a tags from the collection using organizationRef with permission checking
|
||||
RemoveTags(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error
|
||||
// RemoveTag removes a tag from an entity with permission checking
|
||||
RemoveTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error
|
||||
// AddTags adds multiple tags to an entity with permission checking
|
||||
AddTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error
|
||||
// SetTags sets the tags for an entity with permission checking
|
||||
SetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error
|
||||
// RemoveAllTags removes all tags from an entity with permission checking
|
||||
RemoveAllTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) error
|
||||
// GetTags gets the tags for an entity with permission checking
|
||||
GetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) ([]primitive.ObjectID, error)
|
||||
// HasTag checks if an entity has a specific tag with permission checking
|
||||
HasTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) (bool, error)
|
||||
// FindByTag finds all entities that have a specific tag with permission checking
|
||||
FindByTag(ctx context.Context, accountRef, tagRef primitive.ObjectID) ([]T, error)
|
||||
// FindByTags finds all entities that have any of the specified tags with permission checking
|
||||
FindByTags(ctx context.Context, accountRef primitive.ObjectID, tagRefs []primitive.ObjectID) ([]T, error)
|
||||
}
|
||||
|
||||
// NewTaggableDBImp creates a new auth.TaggableDB instance
|
||||
func NewTaggableDB[T model.PermissionBoundStorable](
|
||||
dbImp *template.DBImp[T],
|
||||
enforcer Enforcer,
|
||||
createEmpty func() T,
|
||||
getTaggable func(T) *model.Taggable,
|
||||
) TaggableDB[T] {
|
||||
return newTaggableDBImp(dbImp, enforcer, createEmpty, getTaggable)
|
||||
}
|
||||
302
api/pkg/auth/taggableimp.go
Normal file
302
api/pkg/auth/taggableimp.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// taggableDBImp implements tag operations with permission checking
|
||||
type taggableDBImp[T model.PermissionBoundStorable] struct {
|
||||
dbImp *template.DBImp[T]
|
||||
logger mlogger.Logger
|
||||
enforcer Enforcer
|
||||
createEmpty func() T
|
||||
getTaggable func(T) *model.Taggable
|
||||
}
|
||||
|
||||
func newTaggableDBImp[T model.PermissionBoundStorable](
|
||||
dbImp *template.DBImp[T],
|
||||
enforcer Enforcer,
|
||||
createEmpty func() T,
|
||||
getTaggable func(T) *model.Taggable,
|
||||
) TaggableDB[T] {
|
||||
return &taggableDBImp[T]{
|
||||
dbImp: dbImp,
|
||||
logger: dbImp.Logger.Named("taggable"),
|
||||
enforcer: enforcer,
|
||||
createEmpty: createEmpty,
|
||||
getTaggable: getTaggable,
|
||||
}
|
||||
}
|
||||
|
||||
func (db *taggableDBImp[T]) AddTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
|
||||
// Check permissions using enforceObject helper
|
||||
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the tag
|
||||
patch := repository.Patch().AddToSet(repository.TagRefsField(), tagRef)
|
||||
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
|
||||
db.logger.Warn("Failed to add tag to object", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully added tag to object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *taggableDBImp[T]) removeTag(ctx context.Context, accountRef, targetRef, tagRef primitive.ObjectID, query builder.Query) error {
|
||||
// Check permissions using enforceObject helper
|
||||
if err := enforceObject(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, query); err != nil {
|
||||
db.logger.Debug("Error enforcing permissions for removing tag", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("target_ref", targetRef), mzap.ObjRef("tag_ref", tagRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the tag
|
||||
patch := repository.Patch().Pull(repository.TagRefsField(), tagRef)
|
||||
patched, err := db.dbImp.PatchMany(ctx, query, patch)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to remove tag from object", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("target_ref", targetRef), mzap.ObjRef("tag_ref", tagRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully removed tag from object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("target_ref", targetRef), mzap.ObjRef("tag_ref", tagRef), zap.Int("patched_count", patched))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *taggableDBImp[T]) RemoveTags(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error {
|
||||
return db.removeTag(ctx, accountRef, primitive.NilObjectID, tagRef, repository.OrgFilter(organizationRef))
|
||||
}
|
||||
|
||||
func (db *taggableDBImp[T]) RemoveTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
|
||||
return db.removeTag(ctx, accountRef, objectRef, tagRef, repository.IDFilter(objectRef))
|
||||
}
|
||||
|
||||
// AddTags adds multiple tags to an entity with permission checking
|
||||
func (db *taggableDBImp[T]) AddTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
|
||||
// Check permissions using enforceObject helper
|
||||
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the tags one by one using $addToSet to avoid duplicates
|
||||
for _, tagRef := range tagRefs {
|
||||
patch := repository.Patch().AddToSet(repository.TagRefsField(), tagRef)
|
||||
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
|
||||
db.logger.Warn("Failed to add tag to object", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully added tags to object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("tag_count", len(tagRefs)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTags sets the tags for an entity with permission checking
|
||||
func (db *taggableDBImp[T]) SetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
|
||||
// Check permissions using enforceObject helper
|
||||
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the tags
|
||||
patch := repository.Patch().Set(repository.TagRefsField(), tagRefs)
|
||||
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
|
||||
db.logger.Warn("Failed to set tags for object", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully set tags for object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("tag_count", len(tagRefs)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAllTags removes all tags from an entity with permission checking
|
||||
func (db *taggableDBImp[T]) RemoveAllTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
|
||||
// Check permissions using enforceObject helper
|
||||
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionUpdate, accountRef, objectRef); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove all tags by setting to empty array
|
||||
patch := repository.Patch().Set(repository.TagRefsField(), []primitive.ObjectID{})
|
||||
if err := db.dbImp.Patch(ctx, objectRef, patch); err != nil {
|
||||
db.logger.Warn("Failed to remove all tags from object", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully removed all tags from object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTags gets the tags for an entity with permission checking
|
||||
func (db *taggableDBImp[T]) GetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) ([]primitive.ObjectID, error) {
|
||||
// Check permissions using enforceObject helper
|
||||
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionRead, accountRef, objectRef); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the object and extract tags
|
||||
obj := db.createEmpty()
|
||||
if err := db.dbImp.Get(ctx, objectRef, obj); err != nil {
|
||||
db.logger.Warn("Failed to get object for retrieving tags", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the tags
|
||||
taggable := db.getTaggable(obj)
|
||||
db.logger.Debug("Successfully retrieved tags for object", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), zap.Int("tag_count", len(taggable.TagRefs)))
|
||||
return taggable.TagRefs, nil
|
||||
}
|
||||
|
||||
// HasTag checks if an entity has a specific tag with permission checking
|
||||
func (db *taggableDBImp[T]) HasTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) (bool, error) {
|
||||
// Check permissions using enforceObject helper
|
||||
if err := enforceObjectByRef(ctx, db.dbImp, db.enforcer, model.ActionRead, accountRef, objectRef); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Get the object and check if the tag exists
|
||||
obj := db.createEmpty()
|
||||
if err := db.dbImp.Get(ctx, objectRef, obj); err != nil {
|
||||
db.logger.Warn("Failed to get object for checking tag", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Check if the tag exists
|
||||
taggable := db.getTaggable(obj)
|
||||
for _, existingTag := range taggable.TagRefs {
|
||||
if existingTag == tagRef {
|
||||
db.logger.Debug("Object has tag", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
db.logger.Debug("Object does not have tag", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("object_ref", objectRef), mzap.ObjRef("tag_ref", tagRef))
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// FindByTag finds all entities that have a specific tag with permission checking
|
||||
func (db *taggableDBImp[T]) FindByTag(ctx context.Context, accountRef, tagRef primitive.ObjectID) ([]T, error) {
|
||||
// Create filter to find objects with the tag
|
||||
filter := repository.Filter(model.TagRefsField, tagRef)
|
||||
|
||||
// Get all objects with the tag using ListPermissionBound
|
||||
objects, err := db.dbImp.ListPermissionBound(ctx, filter)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to get objects with tag", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("tag_ref", tagRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check permissions for all objects using EnforceBatch
|
||||
db.logger.Debug("Checking permissions for objects with tag", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("tag_ref", tagRef), zap.Int("object_count", len(objects)))
|
||||
|
||||
permissions, err := db.enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to check permissions for objects with tag", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("tag_ref", tagRef), zap.Int("object_count", len(objects)))
|
||||
return nil, merrors.Internal("failed to check permissions for objects with tag")
|
||||
}
|
||||
|
||||
// Filter objects based on permissions and decode them
|
||||
var results []T
|
||||
for _, obj := range objects {
|
||||
objID := *obj.GetID()
|
||||
if hasPermission, exists := permissions[objID]; exists && hasPermission {
|
||||
// Decode the object
|
||||
decodedObj := db.createEmpty()
|
||||
if err := db.dbImp.Get(ctx, objID, decodedObj); err != nil {
|
||||
db.logger.Warn("Failed to decode object with tag", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objID), mzap.ObjRef("tag_ref", tagRef))
|
||||
continue
|
||||
}
|
||||
results = append(results, decodedObj)
|
||||
}
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully found objects with tag", mzap.ObjRef("account_ref", accountRef),
|
||||
mzap.ObjRef("tag_ref", tagRef), zap.Int("total_objects", len(objects)), zap.Int("accessible_objects", len(results)))
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// FindByTags finds all entities that have any of the specified tags with permission checking
|
||||
func (db *taggableDBImp[T]) FindByTags(ctx context.Context, accountRef primitive.ObjectID, tagRefs []primitive.ObjectID) ([]T, error) {
|
||||
if len(tagRefs) == 0 {
|
||||
return []T{}, nil
|
||||
}
|
||||
|
||||
// Convert []primitive.ObjectID to []any for the In method
|
||||
values := make([]any, len(tagRefs))
|
||||
for i, tagRef := range tagRefs {
|
||||
values[i] = tagRef
|
||||
}
|
||||
|
||||
// Create filter to find objects with any of the tags
|
||||
filter := repository.Query().In(repository.TagRefsField(), values...)
|
||||
|
||||
// Get all objects with any of the tags using ListPermissionBound
|
||||
objects, err := db.dbImp.ListPermissionBound(ctx, filter)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to get objects with tags", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check permissions for all objects using EnforceBatch
|
||||
db.logger.Debug("Checking permissions for objects with tags", mzap.ObjRef("account_ref", accountRef),
|
||||
zap.Int("object_count", len(objects)), zap.Int("tag_count", len(tagRefs)))
|
||||
|
||||
permissions, err := db.enforcer.EnforceBatch(ctx, objects, accountRef, model.ActionRead)
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to check permissions for objects with tags", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), zap.Int("object_count", len(objects)))
|
||||
return nil, merrors.Internal("failed to check permissions for objects with tags")
|
||||
}
|
||||
|
||||
// Filter objects based on permissions and decode them
|
||||
var results []T
|
||||
for _, obj := range objects {
|
||||
objID := *obj.GetID()
|
||||
if hasPermission, exists := permissions[objID]; exists && hasPermission {
|
||||
// Decode the object
|
||||
decodedObj := db.createEmpty()
|
||||
if err := db.dbImp.Get(ctx, objID, decodedObj); err != nil {
|
||||
db.logger.Warn("Failed to decode object with tags", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("object_ref", objID))
|
||||
continue
|
||||
}
|
||||
results = append(results, decodedObj)
|
||||
}
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully found objects with tags", mzap.ObjRef("account_ref", accountRef),
|
||||
zap.Int("total_objects", len(objects)), zap.Int("accessible_objects", len(results)), zap.Int("tag_count", len(tagRefs)))
|
||||
return results, nil
|
||||
}
|
||||
21
api/pkg/clock/clock.go
Normal file
21
api/pkg/clock/clock.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package clock
|
||||
|
||||
import "time"
|
||||
|
||||
// Clock exposes basic time operations, primarily for test overrides.
|
||||
type Clock interface {
|
||||
Now() time.Time
|
||||
}
|
||||
|
||||
// System implements Clock using the system wall clock.
|
||||
type System struct{}
|
||||
|
||||
// Now returns the current UTC time.
|
||||
func (System) Now() time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
// NewSystem returns a system-backed clock instance.
|
||||
func NewSystem() Clock {
|
||||
return System{}
|
||||
}
|
||||
17
api/pkg/db/account/account.go
Executable file
17
api/pkg/db/account/account.go
Executable file
@@ -0,0 +1,17 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// DB is the interface which must be implemented by all db drivers
|
||||
type DB interface {
|
||||
template.DB[*model.Account]
|
||||
GetByEmail(ctx context.Context, email string) (*model.Account, error)
|
||||
GetByToken(ctx context.Context, email string) (*model.Account, error)
|
||||
GetAccountsByRefs(ctx context.Context, orgRef primitive.ObjectID, refs []primitive.ObjectID) ([]model.Account, error)
|
||||
}
|
||||
11
api/pkg/db/config.go
Normal file
11
api/pkg/db/config.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package db
|
||||
|
||||
import "github.com/tech/sendico/pkg/model"
|
||||
|
||||
type DBDriver string
|
||||
|
||||
const (
|
||||
Mongo DBDriver = "mongodb"
|
||||
)
|
||||
|
||||
type Config = model.DriverConfig[DBDriver]
|
||||
65
api/pkg/db/connection.go
Normal file
65
api/pkg/db/connection.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
mongoimpl "github.com/tech/sendico/pkg/db/internal/mongo"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
mongoDriver "go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/readpref"
|
||||
)
|
||||
|
||||
// Connection represents a low-level database connection lifecycle.
|
||||
type Connection interface {
|
||||
Disconnect(ctx context.Context) error
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
// MongoConnection provides direct access to the underlying mongo client.
|
||||
type MongoConnection struct {
|
||||
client *mongoDriver.Client
|
||||
database string
|
||||
}
|
||||
|
||||
func (c *MongoConnection) Client() *mongoDriver.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (c *MongoConnection) Database() *mongoDriver.Database {
|
||||
return c.client.Database(c.database)
|
||||
}
|
||||
|
||||
func (c *MongoConnection) Disconnect(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.client.Disconnect(ctx)
|
||||
}
|
||||
|
||||
func (c *MongoConnection) Ping(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.client.Ping(ctx, readpref.Primary())
|
||||
}
|
||||
|
||||
// ConnectMongo returns a low-level MongoDB connection without constructing repositories.
|
||||
func ConnectMongo(logger mlogger.Logger, config *Config) (*MongoConnection, error) {
|
||||
if config == nil {
|
||||
return nil, merrors.InvalidArgument("database configuration is nil")
|
||||
}
|
||||
if config.Driver != Mongo {
|
||||
return nil, merrors.InvalidArgument("unsupported database driver: " + string(config.Driver))
|
||||
}
|
||||
|
||||
client, _, settings, err := mongoimpl.ConnectClient(logger, config.Settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MongoConnection{
|
||||
client: client,
|
||||
database: settings.Database,
|
||||
}, nil
|
||||
}
|
||||
41
api/pkg/db/factory.go
Normal file
41
api/pkg/db/factory.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/db/account"
|
||||
mongoimpl "github.com/tech/sendico/pkg/db/internal/mongo"
|
||||
"github.com/tech/sendico/pkg/db/invitation"
|
||||
"github.com/tech/sendico/pkg/db/organization"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/refreshtokens"
|
||||
"github.com/tech/sendico/pkg/db/role"
|
||||
"github.com/tech/sendico/pkg/db/transaction"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
)
|
||||
|
||||
// Factory exposes high-level repositories used by application services.
|
||||
type Factory interface {
|
||||
NewRefreshTokensDB() (refreshtokens.DB, error)
|
||||
|
||||
NewAccountDB() (account.DB, error)
|
||||
NewOrganizationDB() (organization.DB, error)
|
||||
NewInvitationsDB() (invitation.DB, error)
|
||||
|
||||
NewRolesDB() (role.DB, error)
|
||||
NewPoliciesDB() (policy.DB, error)
|
||||
|
||||
TransactionFactory() transaction.Factory
|
||||
|
||||
Permissions() auth.Provider
|
||||
|
||||
CloseConnection()
|
||||
}
|
||||
|
||||
// NewConnection builds a Factory backed by the configured driver.
|
||||
func NewConnection(logger mlogger.Logger, config *Config) (Factory, error) {
|
||||
if config.Driver == Mongo {
|
||||
return mongoimpl.NewConnection(logger, config.Settings)
|
||||
}
|
||||
return nil, merrors.InvalidArgument("unknown database driver: " + string(config.Driver))
|
||||
}
|
||||
12
api/pkg/db/indexable/indexable.go
Normal file
12
api/pkg/db/indexable/indexable.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package indexable
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type DB interface {
|
||||
Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
|
||||
}
|
||||
30
api/pkg/db/internal/mongo/accountdb/db.go
Normal file
30
api/pkg/db/internal/mongo/accountdb/db.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package accountdb
|
||||
|
||||
import (
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AccountDB struct {
|
||||
template.DBImp[*model.Account]
|
||||
}
|
||||
|
||||
func Create(logger mlogger.Logger, db *mongo.Database) (*AccountDB, error) {
|
||||
p := &AccountDB{
|
||||
DBImp: *template.Create[*model.Account](logger, mservice.Accounts, db),
|
||||
}
|
||||
|
||||
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "login", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create account database", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
13
api/pkg/db/internal/mongo/accountdb/token.go
Normal file
13
api/pkg/db/internal/mongo/accountdb/token.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package accountdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
func (db *AccountDB) GetByToken(ctx context.Context, email string) (*model.Account, error) {
|
||||
var account model.Account
|
||||
return &account, db.FindOne(ctx, repository.Query().Filter(repository.Field("verifyToken"), email), &account)
|
||||
}
|
||||
21
api/pkg/db/internal/mongo/accountdb/user.go
Executable file
21
api/pkg/db/internal/mongo/accountdb/user.go
Executable file
@@ -0,0 +1,21 @@
|
||||
package accountdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *AccountDB) GetAccountsByRefs(ctx context.Context, orgRef primitive.ObjectID, refs []primitive.ObjectID) ([]model.Account, error) {
|
||||
filter := repository.Query().Comparison(repository.IDField(), builder.In, refs)
|
||||
return mutil.GetObjects[model.Account](ctx, db.Logger, filter, nil, db.Repository)
|
||||
}
|
||||
|
||||
func (db *AccountDB) GetByEmail(ctx context.Context, email string) (*model.Account, error) {
|
||||
var account model.Account
|
||||
return &account, db.FindOne(ctx, repository.Filter("login", email), &account)
|
||||
}
|
||||
99
api/pkg/db/internal/mongo/archivable/archivable.go
Normal file
99
api/pkg/db/internal/mongo/archivable/archivable.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package archivable
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ArchivableDB implements archive management for entities with model.Archivable embedded
|
||||
type ArchivableDB[T storable.Storable] struct {
|
||||
repo repository.Repository
|
||||
logger mlogger.Logger
|
||||
createEmpty func() T
|
||||
getArchivable func(T) model.Archivable
|
||||
}
|
||||
|
||||
// NewArchivableDB creates a new ArchivableDB instance
|
||||
func NewArchivableDB[T storable.Storable](
|
||||
repo repository.Repository,
|
||||
logger mlogger.Logger,
|
||||
createEmpty func() T,
|
||||
getArchivable func(T) model.Archivable,
|
||||
) *ArchivableDB[T] {
|
||||
return &ArchivableDB[T]{
|
||||
repo: repo,
|
||||
logger: logger,
|
||||
createEmpty: createEmpty,
|
||||
getArchivable: getArchivable,
|
||||
}
|
||||
}
|
||||
|
||||
// SetArchived sets the archived status of an entity
|
||||
func (db *ArchivableDB[T]) SetArchived(ctx context.Context, objectRef primitive.ObjectID, archived bool) error {
|
||||
// Get current object to check current archived status
|
||||
obj := db.createEmpty()
|
||||
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
|
||||
db.logger.Warn("Failed to get object for setting archived status",
|
||||
zap.Error(err),
|
||||
mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Bool("archived", archived))
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract archivable from the object
|
||||
archivable := db.getArchivable(obj)
|
||||
currentArchived := archivable.IsArchived()
|
||||
if currentArchived == archived {
|
||||
db.logger.Debug("No change needed - same archived status",
|
||||
mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Bool("archived", archived))
|
||||
return nil // No change needed
|
||||
}
|
||||
|
||||
// Set the archived status
|
||||
patch := repository.Patch().Set(repository.IsArchivedField(), archived)
|
||||
if err := db.repo.Patch(ctx, objectRef, patch); err != nil {
|
||||
db.logger.Warn("Failed to set archived status on object",
|
||||
zap.Error(err),
|
||||
mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Bool("archived", archived))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully set archived status on object",
|
||||
mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Bool("archived", archived))
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsArchived checks if an entity is archived
|
||||
func (db *ArchivableDB[T]) IsArchived(ctx context.Context, objectRef primitive.ObjectID) (bool, error) {
|
||||
obj := db.createEmpty()
|
||||
|
||||
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
|
||||
db.logger.Warn("Failed to get object for checking archived status",
|
||||
zap.Error(err),
|
||||
mzap.ObjRef("object_ref", objectRef))
|
||||
return false, err
|
||||
}
|
||||
|
||||
archivable := db.getArchivable(obj)
|
||||
return archivable.IsArchived(), nil
|
||||
}
|
||||
|
||||
// Archive archives an entity (sets archived to true)
|
||||
func (db *ArchivableDB[T]) Archive(ctx context.Context, objectRef primitive.ObjectID) error {
|
||||
return db.SetArchived(ctx, objectRef, true)
|
||||
}
|
||||
|
||||
// Unarchive unarchives an entity (sets archived to false)
|
||||
func (db *ArchivableDB[T]) Unarchive(ctx context.Context, objectRef primitive.ObjectID) error {
|
||||
return db.SetArchived(ctx, objectRef, false)
|
||||
}
|
||||
175
api/pkg/db/internal/mongo/archivable/archivable_test.go
Normal file
175
api/pkg/db/internal/mongo/archivable/archivable_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package archivable
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestArchivableObject represents a test object with archivable functionality
|
||||
type TestArchivableObject struct {
|
||||
storable.Base `bson:",inline" json:",inline"`
|
||||
model.ArchivableBase `bson:",inline" json:",inline"`
|
||||
Name string `bson:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (t *TestArchivableObject) Collection() string {
|
||||
return "testArchivableObject"
|
||||
}
|
||||
|
||||
func (t *TestArchivableObject) GetArchivable() model.Archivable {
|
||||
return &t.ArchivableBase
|
||||
}
|
||||
|
||||
func TestArchivableDB(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Start MongoDB container (stable)
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("test"),
|
||||
mongodb.WithPassword("test"),
|
||||
testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := mongoContainer.Terminate(termCtx); err != nil {
|
||||
t.Logf("Failed to terminate container: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Get MongoDB connection string
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect to MongoDB
|
||||
client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI))
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Disconnect(context.Background()); err != nil {
|
||||
t.Logf("Failed to disconnect from MongoDB: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Ping the database
|
||||
err = client.Ping(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create repository
|
||||
repo := repositoryimp.NewMongoRepository(client.Database("test_"+t.Name()), "testArchivableCollection")
|
||||
|
||||
// Create archivable DB
|
||||
archivableDB := NewArchivableDB(
|
||||
repo,
|
||||
zap.NewNop(),
|
||||
func() *TestArchivableObject { return &TestArchivableObject{} },
|
||||
func(obj *TestArchivableObject) model.Archivable { return obj.GetArchivable() },
|
||||
)
|
||||
|
||||
t.Run("SetArchived_Success", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.SetArchived(ctx, obj.ID, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result.IsArchived())
|
||||
})
|
||||
|
||||
t.Run("SetArchived_NoChange", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.SetArchived(ctx, obj.ID, true)
|
||||
require.NoError(t, err) // Should not error, just not change anything
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result.IsArchived())
|
||||
})
|
||||
|
||||
t.Run("SetArchived_Unarchive", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.SetArchived(ctx, obj.ID, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result.IsArchived())
|
||||
})
|
||||
|
||||
t.Run("IsArchived_True", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, isArchived)
|
||||
})
|
||||
|
||||
t.Run("IsArchived_False", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, isArchived)
|
||||
})
|
||||
|
||||
t.Run("Archive_Success", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.Archive(ctx, obj.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result.IsArchived())
|
||||
})
|
||||
|
||||
t.Run("Unarchive_Success", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.Unarchive(ctx, obj.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result.IsArchived())
|
||||
})
|
||||
}
|
||||
257
api/pkg/db/internal/mongo/db.go
Executable file
257
api/pkg/db/internal/mongo/db.go
Executable file
@@ -0,0 +1,257 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/db/account"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/accountdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/invitationdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/organizationdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/policiesdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/rolesdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/transactionimp"
|
||||
"github.com/tech/sendico/pkg/db/invitation"
|
||||
"github.com/tech/sendico/pkg/db/organization"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/refreshtokens"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/role"
|
||||
"github.com/tech/sendico/pkg/db/transaction"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/config"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/mongo/readpref"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Config represents configuration
|
||||
type Config struct {
|
||||
Port *string `mapstructure:"port"`
|
||||
PortEnv *string `mapstructure:"port_env"`
|
||||
User *string `mapstructure:"user"`
|
||||
UserEnv *string `mapstructure:"user_env"`
|
||||
PasswordEnv string `mapstructure:"password_env"`
|
||||
Database *string `mapstructure:"database"`
|
||||
DatabaseEnv *string `mapstructure:"database_env"`
|
||||
Host *string `mapstructure:"host"`
|
||||
HostEnv *string `mapstructure:"host_env"`
|
||||
AuthSource *string `mapstructure:"auth_source,omitempty"`
|
||||
AuthSourceEnv *string `mapstructure:"auth_source_env,omitempty"`
|
||||
AuthMechanism *string `mapstructure:"auth_mechanism,omitempty"`
|
||||
AuthMechanismEnv *string `mapstructure:"auth_mechanism_env,omitempty"`
|
||||
ReplicaSet *string `mapstructure:"replica_set,omitempty"`
|
||||
ReplicaSetEnv *string `mapstructure:"replica_set_env,omitempty"`
|
||||
Enforcer *auth.Config `mapstructure:"enforcer"`
|
||||
}
|
||||
|
||||
type DBSettings struct {
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
AuthSource string
|
||||
AuthMechanism string
|
||||
ReplicaSet string
|
||||
}
|
||||
|
||||
func newProtectedDB[T any](
|
||||
db *DB,
|
||||
create func(ctx context.Context, logger mlogger.Logger, enforcer auth.Enforcer, pdb policy.DB, client *mongo.Database) (T, error),
|
||||
) (T, error) {
|
||||
pdb, err := db.NewPoliciesDB()
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to create policies database", zap.Error(err))
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
return create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
|
||||
}
|
||||
|
||||
func Config2DBSettings(logger mlogger.Logger, config *Config) *DBSettings {
|
||||
p := new(DBSettings)
|
||||
p.Port = mutil.GetConfigValue(logger, "port", "port_env", config.Port, config.PortEnv)
|
||||
p.Database = mutil.GetConfigValue(logger, "database", "database_env", config.Database, config.DatabaseEnv)
|
||||
p.Password = os.Getenv(config.PasswordEnv)
|
||||
p.User = mutil.GetConfigValue(logger, "user", "user_env", config.User, config.UserEnv)
|
||||
p.Host = mutil.GetConfigValue(logger, "host", "host_env", config.Host, config.HostEnv)
|
||||
p.AuthSource = mutil.GetConfigValue(logger, "auth_source", "auth_source_env", config.AuthSource, config.AuthSourceEnv)
|
||||
p.AuthMechanism = mutil.GetConfigValue(logger, "auth_mechanism", "auth_mechanism_env", config.AuthMechanism, config.AuthMechanismEnv)
|
||||
p.ReplicaSet = mutil.GetConfigValue(logger, "replica_set", "replica_set_env", config.ReplicaSet, config.ReplicaSetEnv)
|
||||
return p
|
||||
}
|
||||
|
||||
func decodeConfig(logger mlogger.Logger, settings model.SettingsT) (*Config, *DBSettings, error) {
|
||||
var config Config
|
||||
if err := mapstructure.Decode(settings, &config); err != nil {
|
||||
logger.Warn("Failed to decode settings", zap.Error(err), zap.Any("settings", settings))
|
||||
return nil, nil, err
|
||||
}
|
||||
dbSettings := Config2DBSettings(logger, &config)
|
||||
return &config, dbSettings, nil
|
||||
}
|
||||
|
||||
func dialMongo(logger mlogger.Logger, dbSettings *DBSettings) (*mongo.Client, error) {
|
||||
cred := options.Credential{
|
||||
AuthMechanism: dbSettings.AuthMechanism,
|
||||
AuthSource: dbSettings.AuthSource,
|
||||
Username: dbSettings.User,
|
||||
Password: dbSettings.Password,
|
||||
}
|
||||
dbURI := buildURI(dbSettings)
|
||||
|
||||
client, err := mongo.Connect(context.Background(), options.Client().ApplyURI(dbURI).SetAuth(cred))
|
||||
if err != nil {
|
||||
logger.Error("Unable to connect to database", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("Connected successfully", zap.String("uri", dbURI))
|
||||
|
||||
if err := client.Ping(context.Background(), readpref.Primary()); err != nil {
|
||||
logger.Error("Unable to ping database", zap.Error(err))
|
||||
_ = client.Disconnect(context.Background())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func ConnectClient(logger mlogger.Logger, settings model.SettingsT) (*mongo.Client, *Config, *DBSettings, error) {
|
||||
config, dbSettings, err := decodeConfig(logger, settings)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
client, err := dialMongo(logger, dbSettings)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
return client, config, dbSettings, nil
|
||||
}
|
||||
|
||||
// DB represents the structure of the database
|
||||
type DB struct {
|
||||
logger mlogger.Logger
|
||||
config *DBSettings
|
||||
client *mongo.Client
|
||||
enforcer auth.Enforcer
|
||||
manager auth.Manager
|
||||
pdb policy.DB
|
||||
}
|
||||
|
||||
func (db *DB) db() *mongo.Database {
|
||||
return db.client.Database(db.config.Database)
|
||||
}
|
||||
|
||||
func (db *DB) NewAccountDB() (account.DB, error) {
|
||||
return accountdb.Create(db.logger, db.db())
|
||||
}
|
||||
|
||||
func (db *DB) NewOrganizationDB() (organization.DB, error) {
|
||||
pdb, err := db.NewPoliciesDB()
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to create policies database", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
organizationDB, err := organizationdb.Create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the concrete type - interface mismatch will be handled at runtime
|
||||
// TODO: Update organization.DB interface to match implementation signatures
|
||||
return organizationDB, nil
|
||||
}
|
||||
|
||||
func (db *DB) NewRefreshTokensDB() (refreshtokens.DB, error) {
|
||||
return refreshtokensdb.Create(db.logger, db.db())
|
||||
}
|
||||
|
||||
func (db *DB) NewInvitationsDB() (invitation.DB, error) {
|
||||
return newProtectedDB(db, invitationdb.Create)
|
||||
}
|
||||
|
||||
func (db *DB) NewPoliciesDB() (policy.DB, error) {
|
||||
return db.pdb, nil
|
||||
}
|
||||
|
||||
func (db *DB) NewRolesDB() (role.DB, error) {
|
||||
return rolesdb.Create(db.logger, db.db())
|
||||
}
|
||||
|
||||
func (db *DB) TransactionFactory() transaction.Factory {
|
||||
return transactionimp.CreateFactory(db.client)
|
||||
}
|
||||
|
||||
func (db *DB) Permissions() auth.Provider {
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) Manager() auth.Manager {
|
||||
return db.manager
|
||||
}
|
||||
|
||||
func (db *DB) Enforcer() auth.Enforcer {
|
||||
return db.enforcer
|
||||
}
|
||||
|
||||
func (db *DB) GetPolicyDescription(ctx context.Context, resource mservice.Type) (*model.PolicyDescription, error) {
|
||||
var policyDescription model.PolicyDescription
|
||||
return &policyDescription, db.pdb.FindOne(ctx, repository.Filter("resourceTypes", resource), &policyDescription)
|
||||
}
|
||||
|
||||
func (db *DB) CloseConnection() {
|
||||
if err := db.client.Disconnect(context.Background()); err != nil {
|
||||
db.logger.Warn("Failed to close connection", zap.Error(err))
|
||||
}
|
||||
db.logger.Info("Database connection closed")
|
||||
}
|
||||
|
||||
// NewConnection creates a new database connection
|
||||
func NewConnection(logger mlogger.Logger, settings model.SettingsT) (*DB, error) {
|
||||
client, config, dbSettings, err := ConnectClient(logger, settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := &DB{
|
||||
logger: logger.Named("db"),
|
||||
config: dbSettings,
|
||||
client: client,
|
||||
}
|
||||
|
||||
cleanup := func(ctx context.Context) {
|
||||
if err := client.Disconnect(ctx); err != nil {
|
||||
logger.Warn("Failed to close MongoDB connection", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
rdb, err := db.NewRolesDB()
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to create roles database", zap.Error(err))
|
||||
cleanup(context.Background())
|
||||
return nil, err
|
||||
}
|
||||
if db.pdb, err = policiesdb.Create(db.logger, db.db()); err != nil {
|
||||
db.logger.Warn("Failed to create policies database", zap.Error(err))
|
||||
cleanup(context.Background())
|
||||
return nil, err
|
||||
}
|
||||
if db.enforcer, db.manager, err = auth.CreateAuth(logger, db.client, db.db(), db.pdb, rdb, config.Enforcer); err != nil {
|
||||
db.logger.Warn("Failed to create permissions enforcer", zap.Error(err))
|
||||
cleanup(context.Background())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
144
api/pkg/db/internal/mongo/indexable/README.md
Normal file
144
api/pkg/db/internal/mongo/indexable/README.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# Indexable Implementation (Refactored)
|
||||
|
||||
## Overview
|
||||
|
||||
This package provides a refactored implementation of the `indexable.DB` interface that uses `mutil.GetObjects` for better consistency with the existing codebase. The implementation has been moved to the mongo folder and includes a factory for project indexable in the pkg/db folder.
|
||||
|
||||
## Structure
|
||||
|
||||
### 1. `api/pkg/db/internal/mongo/indexable/indexable.go`
|
||||
- **`ReorderTemplate[T]`**: Generic template function that uses `mutil.GetObjects` for fetching objects
|
||||
- **`IndexableDB`**: Base struct for creating concrete implementations
|
||||
- **Type-safe implementation**: Uses Go generics with proper type constraints
|
||||
|
||||
### 2. `api/pkg/db/project_indexable.go`
|
||||
- **`ProjectIndexableDB`**: Factory implementation for Project objects
|
||||
- **`NewProjectIndexableDB`**: Constructor function
|
||||
- **`ReorderTemplate`**: Duplicate of the mongo version for convenience
|
||||
|
||||
## Key Changes from Previous Implementation
|
||||
|
||||
### 1. **Uses `mutil.GetObjects`**
|
||||
```go
|
||||
// Old implementation (manual cursor handling)
|
||||
err = repo.FindManyByFilter(ctx, filter, func(cursor *mongo.Cursor) error {
|
||||
var obj T
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
objects = append(objects, obj)
|
||||
return nil
|
||||
})
|
||||
|
||||
// New implementation (using mutil.GetObjects)
|
||||
objects, err := mutil.GetObjects[T](
|
||||
ctx,
|
||||
logger,
|
||||
filterFunc().
|
||||
And(
|
||||
repository.IndexOpFilter(minIdx, builder.Gte),
|
||||
repository.IndexOpFilter(maxIdx, builder.Lte),
|
||||
),
|
||||
nil, nil, nil, // limit, offset, isArchived
|
||||
repo,
|
||||
)
|
||||
```
|
||||
|
||||
### 2. **Moved to Mongo Folder**
|
||||
- Location: `api/pkg/db/internal/mongo/indexable/`
|
||||
- Consistent with other mongo implementations
|
||||
- Better organization within the codebase
|
||||
|
||||
### 3. **Added Factory in pkg/db**
|
||||
- Location: `api/pkg/db/project_indexable.go`
|
||||
- Provides easy access to project indexable functionality
|
||||
- Includes logger parameter for better error handling
|
||||
|
||||
## Usage
|
||||
|
||||
### Using the Factory (Recommended)
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/db"
|
||||
|
||||
// Create a project indexable DB
|
||||
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)
|
||||
|
||||
// Reorder a project
|
||||
err := projectDB.Reorder(ctx, projectID, newIndex)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
```
|
||||
|
||||
### Using the Template Directly
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"
|
||||
|
||||
// Define helper functions
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
updateIndexable := func(p *model.Project, newIndex int) {
|
||||
p.Index = newIndex
|
||||
}
|
||||
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
filterFunc := func() builder.Query {
|
||||
return repository.OrgFilter(organizationRef)
|
||||
}
|
||||
|
||||
// Use the template
|
||||
err := indexable.ReorderTemplate(
|
||||
ctx,
|
||||
logger,
|
||||
repo,
|
||||
objectRef,
|
||||
newIndex,
|
||||
filterFunc,
|
||||
getIndexable,
|
||||
updateIndexable,
|
||||
createEmpty,
|
||||
)
|
||||
```
|
||||
|
||||
## Benefits of Refactoring
|
||||
|
||||
1. **Consistency**: Uses `mutil.GetObjects` like other parts of the codebase
|
||||
2. **Better Error Handling**: Includes logger parameter for proper error logging
|
||||
3. **Organization**: Moved to appropriate folder structure
|
||||
4. **Factory Pattern**: Easy-to-use factory for common use cases
|
||||
5. **Type Safety**: Maintains compile-time type checking
|
||||
6. **Performance**: Leverages existing optimized `mutil.GetObjects` implementation
|
||||
|
||||
## Testing
|
||||
|
||||
### Mongo Implementation Tests
|
||||
```bash
|
||||
go test ./db/internal/mongo/indexable -v
|
||||
```
|
||||
|
||||
### Factory Tests
|
||||
```bash
|
||||
go test ./db -v
|
||||
```
|
||||
|
||||
## Integration
|
||||
|
||||
The refactored implementation is ready for integration with existing project reordering APIs. The factory pattern makes it easy to add reordering functionality to any service that needs to reorder projects within an organization.
|
||||
|
||||
## Migration from Old Implementation
|
||||
|
||||
If you were using the old implementation:
|
||||
|
||||
1. **Update imports**: Change from `api/pkg/db/internal/indexable` to `api/pkg/db`
|
||||
2. **Use factory**: Replace manual template usage with `NewProjectIndexableDB`
|
||||
3. **Add logger**: Include a logger parameter in your constructor calls
|
||||
4. **Update tests**: Use the new test structure if needed
|
||||
|
||||
The API remains the same, so existing code should work with minimal changes.
|
||||
174
api/pkg/db/internal/mongo/indexable/USAGE.md
Normal file
174
api/pkg/db/internal/mongo/indexable/USAGE.md
Normal file
@@ -0,0 +1,174 @@
|
||||
# Indexable Usage Guide
|
||||
|
||||
## Generic Implementation for Any Indexable Struct
|
||||
|
||||
The implementation is now **generic** and supports **any struct that embeds `model.Indexable`**!
|
||||
|
||||
- **Interface**: `api/pkg/db/indexable.go` - defines the contract
|
||||
- **Implementation**: `api/pkg/db/internal/mongo/indexable/` - generic implementation
|
||||
- **Factory**: `api/pkg/db/project_indexable.go` - convenient factory for projects
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Using the Generic Implementation Directly
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"
|
||||
|
||||
// For any type that embeds model.Indexable, define helper functions:
|
||||
createEmpty := func() *YourType {
|
||||
return &YourType{}
|
||||
}
|
||||
|
||||
getIndexable := func(obj *YourType) *model.Indexable {
|
||||
return &obj.Indexable
|
||||
}
|
||||
|
||||
// Create generic IndexableDB
|
||||
indexableDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
// Use with single filter parameter
|
||||
err := indexableDB.Reorder(ctx, objectID, newIndex, filter)
|
||||
```
|
||||
|
||||
### 2. Using the Project Factory (Recommended for Projects)
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/db"
|
||||
|
||||
// Create project indexable DB (automatically applies org filter)
|
||||
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)
|
||||
|
||||
// Reorder project (org filter applied automatically)
|
||||
err := projectDB.Reorder(ctx, projectID, newIndex, repository.Query())
|
||||
|
||||
// Reorder with additional filters (combined with org filter)
|
||||
additionalFilter := repository.Query().Comparison(repository.Field("state"), builder.Eq, "active")
|
||||
err := projectDB.Reorder(ctx, projectID, newIndex, additionalFilter)
|
||||
```
|
||||
|
||||
## Examples for Different Types
|
||||
|
||||
### Project IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
projectDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
orgFilter := repository.OrgFilter(organizationRef)
|
||||
projectDB.Reorder(ctx, projectID, 2, orgFilter)
|
||||
```
|
||||
|
||||
### Status IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Status {
|
||||
return &model.Status{}
|
||||
}
|
||||
|
||||
getIndexable := func(s *model.Status) *model.Indexable {
|
||||
return &s.Indexable
|
||||
}
|
||||
|
||||
statusDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
projectFilter := repository.Query().Comparison(repository.Field("projectRef"), builder.Eq, projectRef)
|
||||
statusDB.Reorder(ctx, statusID, 1, projectFilter)
|
||||
```
|
||||
|
||||
### Task IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Task {
|
||||
return &model.Task{}
|
||||
}
|
||||
|
||||
getIndexable := func(t *model.Task) *model.Indexable {
|
||||
return &t.Indexable
|
||||
}
|
||||
|
||||
taskDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
|
||||
taskDB.Reorder(ctx, taskID, 3, statusFilter)
|
||||
```
|
||||
|
||||
### Priority IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Priority {
|
||||
return &model.Priority{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Priority) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
priorityDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
orgFilter := repository.OrgFilter(organizationRef)
|
||||
priorityDB.Reorder(ctx, priorityID, 0, orgFilter)
|
||||
```
|
||||
|
||||
### Global Reordering (No Filter)
|
||||
```go
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
globalDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
// Reorders all items globally (empty filter)
|
||||
globalDB.Reorder(ctx, objectID, 5, repository.Query())
|
||||
```
|
||||
|
||||
## Key Features
|
||||
|
||||
### ✅ **Generic Support**
|
||||
- Works with **any struct** that embeds `model.Indexable`
|
||||
- Type-safe with compile-time checking
|
||||
- No hardcoded types
|
||||
|
||||
### ✅ **Single Filter Parameter**
|
||||
- **Simple**: Single `builder.Query` parameter instead of variadic `interface{}`
|
||||
- **Flexible**: Can incorporate any combination of filters
|
||||
- **Type-safe**: No runtime type assertions needed
|
||||
|
||||
### ✅ **Clean Architecture**
|
||||
- Interface separated from implementation
|
||||
- Generic implementation in internal package
|
||||
- Easy-to-use factories for common types
|
||||
|
||||
## How It Works
|
||||
|
||||
### Generic Algorithm
|
||||
1. **Get current index** using type-specific helper function
|
||||
2. **If no change needed** → return early
|
||||
3. **Apply filter** to scope affected items
|
||||
4. **Shift affected items** using `PatchMany` with `$inc`
|
||||
5. **Update target object** using `Patch` with `$set`
|
||||
|
||||
### Type-Safe Implementation
|
||||
```go
|
||||
type IndexableDB[T storable.Storable] struct {
|
||||
repo repository.Repository
|
||||
logger mlogger.Logger
|
||||
createEmpty func() T
|
||||
getIndexable func(T) *model.Indexable
|
||||
}
|
||||
|
||||
// Single filter parameter - clean and simple
|
||||
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
✅ **Generic** - Works with any Indexable struct
|
||||
✅ **Type Safe** - Compile-time type checking
|
||||
✅ **Simple** - Single filter parameter instead of variadic interface{}
|
||||
✅ **Efficient** - Uses patches, not full updates
|
||||
✅ **Clean** - Interface separated from implementation
|
||||
|
||||
That's it! **Generic, type-safe, and simple** reordering for any Indexable struct with a single filter parameter.
|
||||
69
api/pkg/db/internal/mongo/indexable/examples.go
Normal file
69
api/pkg/db/internal/mongo/indexable/examples.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package indexable
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// Example usage of the generic IndexableDB with different types
|
||||
|
||||
// Example 1: Using with Project
|
||||
func ExampleProjectIndexableDB(repo repository.Repository, logger mlogger.Logger, organizationRef primitive.ObjectID) {
|
||||
// Define helper functions for Project
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
// Create generic IndexableDB for Project
|
||||
projectDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
// Use with organization filter
|
||||
orgFilter := repository.OrgFilter(organizationRef)
|
||||
projectDB.Reorder(context.Background(), primitive.NewObjectID(), 2, orgFilter)
|
||||
}
|
||||
|
||||
// Example 3: Using with Task
|
||||
func ExampleTaskIndexableDB(repo repository.Repository, logger mlogger.Logger, statusRef primitive.ObjectID) {
|
||||
// Define helper functions for Task
|
||||
createEmpty := func() *model.Task {
|
||||
return &model.Task{}
|
||||
}
|
||||
|
||||
getIndexable := func(t *model.Task) *model.Indexable {
|
||||
return &t.Indexable
|
||||
}
|
||||
|
||||
// Create generic IndexableDB for Task
|
||||
taskDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
// Use with status filter
|
||||
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
|
||||
taskDB.Reorder(context.Background(), primitive.NewObjectID(), 3, statusFilter)
|
||||
}
|
||||
|
||||
// Example 5: Using without any filter (global reordering)
|
||||
func ExampleGlobalIndexableDB(repo repository.Repository, logger mlogger.Logger) {
|
||||
// Define helper functions for any Indexable type
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
// Create generic IndexableDB without filters
|
||||
globalDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
// Use without any filter - reorders all items globally
|
||||
globalDB.Reorder(context.Background(), primitive.NewObjectID(), 5, repository.Query())
|
||||
}
|
||||
122
api/pkg/db/internal/mongo/indexable/indexable.go
Normal file
122
api/pkg/db/internal/mongo/indexable/indexable.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package indexable
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// IndexableDB implements db.IndexableDB interface with generic support
|
||||
type IndexableDB[T storable.Storable] struct {
|
||||
repo repository.Repository
|
||||
logger mlogger.Logger
|
||||
createEmpty func() T
|
||||
getIndexable func(T) *model.Indexable
|
||||
}
|
||||
|
||||
// NewIndexableDB creates a new IndexableDB instance
|
||||
func NewIndexableDB[T storable.Storable](
|
||||
repo repository.Repository,
|
||||
logger mlogger.Logger,
|
||||
createEmpty func() T,
|
||||
getIndexable func(T) *model.Indexable,
|
||||
) *IndexableDB[T] {
|
||||
return &IndexableDB[T]{
|
||||
repo: repo,
|
||||
logger: logger,
|
||||
createEmpty: createEmpty,
|
||||
getIndexable: getIndexable,
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder implements the db.IndexableDB interface with single filter parameter
|
||||
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
|
||||
// Get current object to find its index
|
||||
obj := db.createEmpty()
|
||||
err := db.repo.Get(ctx, objectRef, obj)
|
||||
if err != nil {
|
||||
db.logger.Error("Failed to get object for reordering",
|
||||
zap.Error(err),
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("new_index", newIndex))
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract index from the object
|
||||
indexable := db.getIndexable(obj)
|
||||
currentIndex := indexable.Index
|
||||
if currentIndex == newIndex {
|
||||
db.logger.Debug("No reordering needed - same index",
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("current_index", currentIndex),
|
||||
zap.Int("new_index", newIndex))
|
||||
return nil // No change needed
|
||||
}
|
||||
|
||||
// Simple reordering logic
|
||||
if currentIndex < newIndex {
|
||||
// Moving down: shift items between currentIndex+1 and newIndex up by -1
|
||||
patch := repository.Patch().Inc(repository.IndexField(), -1)
|
||||
reorderFilter := filter.
|
||||
And(repository.IndexOpFilter(currentIndex+1, builder.Gte)).
|
||||
And(repository.IndexOpFilter(newIndex, builder.Lte))
|
||||
|
||||
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
|
||||
if err != nil {
|
||||
db.logger.Error("Failed to shift objects during reordering (moving down)",
|
||||
zap.Error(err),
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("current_index", currentIndex),
|
||||
zap.Int("new_index", newIndex),
|
||||
zap.Int("updated_count", updatedCount))
|
||||
return err
|
||||
}
|
||||
db.logger.Debug("Successfully shifted objects (moving down)",
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("updated_count", updatedCount))
|
||||
} else {
|
||||
// Moving up: shift items between newIndex and currentIndex-1 down by +1
|
||||
patch := repository.Patch().Inc(repository.IndexField(), 1)
|
||||
reorderFilter := filter.
|
||||
And(repository.IndexOpFilter(newIndex, builder.Gte)).
|
||||
And(repository.IndexOpFilter(currentIndex-1, builder.Lte))
|
||||
|
||||
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
|
||||
if err != nil {
|
||||
db.logger.Error("Failed to shift objects during reordering (moving up)",
|
||||
zap.Error(err),
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("current_index", currentIndex),
|
||||
zap.Int("new_index", newIndex),
|
||||
zap.Int("updated_count", updatedCount))
|
||||
return err
|
||||
}
|
||||
db.logger.Debug("Successfully shifted objects (moving up)",
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("updated_count", updatedCount))
|
||||
}
|
||||
|
||||
// Update the target object to new index
|
||||
patch := repository.Patch().Set(repository.IndexField(), newIndex)
|
||||
err = db.repo.Patch(ctx, objectRef, patch)
|
||||
if err != nil {
|
||||
db.logger.Error("Failed to update target object index",
|
||||
zap.Error(err),
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("current_index", currentIndex),
|
||||
zap.Int("new_index", newIndex))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Info("Successfully reordered object",
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("old_index", currentIndex),
|
||||
zap.Int("new_index", newIndex))
|
||||
return nil
|
||||
}
|
||||
314
api/pkg/db/internal/mongo/indexable/indexable_test.go
Normal file
314
api/pkg/db/internal/mongo/indexable/indexable_test.go
Normal file
@@ -0,0 +1,314 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package indexable
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) (repository.Repository, func()) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
|
||||
db := client.Database("testdb")
|
||||
repo := repository.CreateMongoRepository(db, "projects")
|
||||
|
||||
cleanup := func() {
|
||||
disconnect(ctx, t, client)
|
||||
terminate(ctx, t, mongoContainer)
|
||||
}
|
||||
|
||||
return repo, cleanup
|
||||
}
|
||||
|
||||
func disconnect(ctx context.Context, t *testing.T, client *mongo.Client) {
|
||||
if err := client.Disconnect(ctx); err != nil {
|
||||
t.Logf("failed to disconnect from MongoDB: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func terminate(ctx context.Context, t *testing.T, container testcontainers.Container) {
|
||||
if err := container.Terminate(ctx); err != nil {
|
||||
t.Logf("failed to terminate MongoDB container: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexableDB_Reorder(t *testing.T) {
|
||||
repo, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
organizationRef := primitive.NewObjectID()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Create test projects with different indices
|
||||
projects := []*model.Project{
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project A"},
|
||||
Indexable: model.Indexable{Index: 0},
|
||||
Mnemonic: "A",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project B"},
|
||||
Indexable: model.Indexable{Index: 1},
|
||||
Mnemonic: "B",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project C"},
|
||||
Indexable: model.Indexable{Index: 2},
|
||||
Mnemonic: "C",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project D"},
|
||||
Indexable: model.Indexable{Index: 3},
|
||||
Mnemonic: "D",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Insert projects into database
|
||||
for _, project := range projects {
|
||||
project.ID = primitive.NewObjectID()
|
||||
err := repo.Insert(ctx, project, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create helper functions for Project type
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
t.Run("Reorder_NoChange", func(t *testing.T) {
|
||||
// Test reordering to the same position (should be no-op)
|
||||
err := indexableDB.Reorder(ctx, projects[1].ID, 1, repository.Query())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify indices haven't changed
|
||||
var result model.Project
|
||||
err = repo.Get(ctx, projects[0].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.Index)
|
||||
|
||||
err = repo.Get(ctx, projects[1].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, result.Index)
|
||||
})
|
||||
|
||||
t.Run("Reorder_MoveDown", func(t *testing.T) {
|
||||
// Move Project A (index 0) to index 2
|
||||
err := indexableDB.Reorder(ctx, projects[0].ID, 2, repository.Query())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the reordering:
|
||||
// Project A should now be at index 2
|
||||
// Project B should be at index 0
|
||||
// Project C should be at index 1
|
||||
// Project D should remain at index 3
|
||||
|
||||
var result model.Project
|
||||
|
||||
// Check Project A (moved to index 2)
|
||||
err = repo.Get(ctx, projects[0].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, result.Index)
|
||||
|
||||
// Check Project B (shifted to index 0)
|
||||
err = repo.Get(ctx, projects[1].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.Index)
|
||||
|
||||
// Check Project C (shifted to index 1)
|
||||
err = repo.Get(ctx, projects[2].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, result.Index)
|
||||
|
||||
// Check Project D (unchanged)
|
||||
err = repo.Get(ctx, projects[3].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, result.Index)
|
||||
})
|
||||
|
||||
t.Run("Reorder_MoveUp", func(t *testing.T) {
|
||||
// Reset indices for this test
|
||||
for i, project := range projects {
|
||||
project.Index = i
|
||||
err := repo.Update(ctx, project)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Move Project C (index 2) to index 0
|
||||
err := indexableDB.Reorder(ctx, projects[2].ID, 0, repository.Query())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the reordering:
|
||||
// Project C should now be at index 0
|
||||
// Project A should be at index 1
|
||||
// Project B should be at index 2
|
||||
// Project D should remain at index 3
|
||||
|
||||
var result model.Project
|
||||
|
||||
// Check Project C (moved to index 0)
|
||||
err = repo.Get(ctx, projects[2].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.Index)
|
||||
|
||||
// Check Project A (shifted to index 1)
|
||||
err = repo.Get(ctx, projects[0].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, result.Index)
|
||||
|
||||
// Check Project B (shifted to index 2)
|
||||
err = repo.Get(ctx, projects[1].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, result.Index)
|
||||
|
||||
// Check Project D (unchanged)
|
||||
err = repo.Get(ctx, projects[3].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, result.Index)
|
||||
})
|
||||
|
||||
t.Run("Reorder_WithFilter", func(t *testing.T) {
|
||||
// Reset indices for this test
|
||||
for i, project := range projects {
|
||||
project.Index = i
|
||||
err := repo.Update(ctx, project)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test reordering with organization filter
|
||||
orgFilter := repository.OrgFilter(organizationRef)
|
||||
err := indexableDB.Reorder(ctx, projects[0].ID, 2, orgFilter)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the reordering worked with filter
|
||||
var result model.Project
|
||||
err = repo.Get(ctx, projects[0].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, result.Index)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIndexableDB_EdgeCases(t *testing.T) {
|
||||
repo, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
organizationRef := primitive.NewObjectID()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Create a single project for edge case testing
|
||||
project := &model.Project{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Test Project"},
|
||||
Indexable: model.Indexable{Index: 0},
|
||||
Mnemonic: "TEST",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
}
|
||||
project.ID = primitive.NewObjectID()
|
||||
err := repo.Insert(ctx, project, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create helper functions for Project type
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
t.Run("Reorder_SingleItem", func(t *testing.T) {
|
||||
// Test reordering a single item (should work but have no effect)
|
||||
err := indexableDB.Reorder(ctx, project.ID, 0, repository.Query())
|
||||
require.NoError(t, err)
|
||||
|
||||
var result model.Project
|
||||
err = repo.Get(ctx, project.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.Index)
|
||||
})
|
||||
|
||||
t.Run("Reorder_InvalidObjectID", func(t *testing.T) {
|
||||
// Test reordering with an invalid object ID
|
||||
invalidID := primitive.NewObjectID()
|
||||
err := indexableDB.Reorder(ctx, invalidID, 1, repository.Query())
|
||||
require.Error(t, err) // Should fail because object doesn't exist
|
||||
})
|
||||
}
|
||||
12
api/pkg/db/internal/mongo/invitationdb/accept.go
Normal file
12
api/pkg/db/internal/mongo/invitationdb/accept.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) Accept(ctx context.Context, invitationRef primitive.ObjectID) error {
|
||||
return db.updateStatus(ctx, invitationRef, model.InvitationAccepted)
|
||||
}
|
||||
49
api/pkg/db/internal/mongo/invitationdb/archived.go
Normal file
49
api/pkg/db/internal/mongo/invitationdb/archived.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SetArchived sets the archived status of an invitation
|
||||
// Invitation supports archiving through PermissionBound embedding ArchivableBase
|
||||
func (db *InvitationDB) SetArchived(ctx context.Context, accountRef, organizationRef, invitationRef primitive.ObjectID, archived, cascade bool) error {
|
||||
db.DBImp.Logger.Debug("Setting invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))
|
||||
res, err := db.Enforcer.Enforce(ctx, db.PermissionRef, accountRef, organizationRef, invitationRef, model.ActionUpdate)
|
||||
if err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to enforce archivation permission", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return err
|
||||
}
|
||||
if !res {
|
||||
db.DBImp.Logger.Debug("Permission denied for archivation", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return merrors.AccessDenied(db.Collection, string(model.ActionUpdate), invitationRef)
|
||||
}
|
||||
|
||||
// Get the invitation first
|
||||
var invitation model.Invitation
|
||||
if err := db.Get(ctx, accountRef, invitationRef, &invitation); err != nil {
|
||||
db.DBImp.Logger.Warn("Error retrieving invitation for archival", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the invitation's archived status
|
||||
invitation.SetArchived(archived)
|
||||
if err := db.Update(ctx, accountRef, &invitation); err != nil {
|
||||
db.DBImp.Logger.Warn("Error updating invitation archived status", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Note: Currently no cascade dependencies for invitations
|
||||
// If cascade is enabled, we could add logic here for any future dependencies
|
||||
if cascade {
|
||||
db.DBImp.Logger.Debug("Cascade archiving requested but no dependencies to archive for invitation", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully set invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived))
|
||||
return nil
|
||||
}
|
||||
24
api/pkg/db/internal/mongo/invitationdb/cascade.go
Normal file
24
api/pkg/db/internal/mongo/invitationdb/cascade.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DeleteCascade deletes an invitation
|
||||
// Invitations don't have cascade dependencies, so this is a simple deletion
|
||||
func (db *InvitationDB) DeleteCascade(ctx context.Context, accountRef, invitationRef primitive.ObjectID) error {
|
||||
db.DBImp.Logger.Debug("Starting invitation cascade deletion", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
|
||||
// Delete the invitation itself (no dependencies to cascade delete)
|
||||
if err := db.Delete(ctx, accountRef, invitationRef); err != nil {
|
||||
db.DBImp.Logger.Error("Error deleting invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully deleted invitation", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return nil
|
||||
}
|
||||
53
api/pkg/db/internal/mongo/invitationdb/db.go
Normal file
53
api/pkg/db/internal/mongo/invitationdb/db.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type InvitationDB struct {
|
||||
auth.ProtectedDBImp[*model.Invitation]
|
||||
}
|
||||
|
||||
func Create(
|
||||
ctx context.Context,
|
||||
logger mlogger.Logger,
|
||||
enforcer auth.Enforcer,
|
||||
pdb policy.DB,
|
||||
db *mongo.Database,
|
||||
) (*InvitationDB, error) {
|
||||
p, err := auth.CreateDBImp[*model.Invitation](ctx, logger, pdb, enforcer, mservice.Invitations, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// unique email per organization
|
||||
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: repository.OrgField().Build(), Sort: ri.Asc}, {Field: "description.email", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.DBImp.Logger.Error("Failed to create unique mnemonic index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ttl index
|
||||
ttl := int32(0) // zero ttl means expiration on date preset when inserting data
|
||||
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "expiresAt", Sort: ri.Asc}},
|
||||
TTL: &ttl,
|
||||
}); err != nil {
|
||||
p.DBImp.Logger.Warn("Failed to create ttl index in the invitations", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &InvitationDB{ProtectedDBImp: *p}, nil
|
||||
}
|
||||
12
api/pkg/db/internal/mongo/invitationdb/decline.go
Normal file
12
api/pkg/db/internal/mongo/invitationdb/decline.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) Decline(ctx context.Context, invitationRef primitive.ObjectID) error {
|
||||
return db.updateStatus(ctx, invitationRef, model.InvitationDeclined)
|
||||
}
|
||||
121
api/pkg/db/internal/mongo/invitationdb/getpublic.go
Normal file
121
api/pkg/db/internal/mongo/invitationdb/getpublic.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) GetPublic(ctx context.Context, invitationRef primitive.ObjectID) (*model.PublicInvitation, error) {
|
||||
roleField := repository.Field("role")
|
||||
orgField := repository.Field("organization")
|
||||
accField := repository.Field("account")
|
||||
empField := repository.Field("employee")
|
||||
regField := repository.Field("registrationAcc")
|
||||
descEmailField := repository.Field("description").Dot("email")
|
||||
pipeline := repository.Pipeline().
|
||||
// 0) Filter to exactly the invitation(s) you want
|
||||
Match(repository.IDFilter(invitationRef).And(repository.Filter("status", model.InvitationCreated))).
|
||||
// 1) Lookup the role document
|
||||
Lookup(
|
||||
mservice.Roles,
|
||||
repository.Field("roleRef"),
|
||||
repository.IDField(),
|
||||
roleField,
|
||||
).
|
||||
Unwind(repository.Ref(roleField)).
|
||||
// 2) Lookup the organization document
|
||||
Lookup(
|
||||
mservice.Organizations,
|
||||
repository.Field("organizationRef"),
|
||||
repository.IDField(),
|
||||
orgField,
|
||||
).
|
||||
Unwind(repository.Ref(orgField)).
|
||||
// 3) Lookup the account document
|
||||
Lookup(
|
||||
mservice.Accounts,
|
||||
repository.Field("inviterRef"),
|
||||
repository.IDField(),
|
||||
accField,
|
||||
).
|
||||
Unwind(repository.Ref(accField)).
|
||||
/* 4) do we already have an account whose login == invitation.description ? */
|
||||
Lookup(
|
||||
mservice.Accounts,
|
||||
descEmailField, // local field (invitation.description.email)
|
||||
repository.Field("login"), // foreign field (account.login)
|
||||
regField, // array: 0-length or ≥1
|
||||
).
|
||||
// 5) Projection
|
||||
Project(
|
||||
repository.SimpleAlias(
|
||||
empField.Dot("description"),
|
||||
repository.Ref(accField),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
empField.Dot("avatarUrl"),
|
||||
repository.Ref(accField.Dot("avatarUrl")),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
orgField.Dot("description"),
|
||||
repository.Ref(orgField),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
orgField.Dot("logoUrl"),
|
||||
repository.Ref(orgField.Dot("logoUrl")),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
roleField,
|
||||
repository.Ref(roleField),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
repository.Field("invitation"), // ← left-hand side
|
||||
repository.Ref(repository.Field("description")), // ← right-hand side (“$description”)
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
repository.Field("storable"), // ← left-hand side
|
||||
repository.RootRef(), // ← right-hand side (“$description”)
|
||||
),
|
||||
repository.ProjectionExpr(
|
||||
repository.Field("registrationRequired"),
|
||||
repository.Eq(
|
||||
repository.Size(repository.Value(repository.Ref(regField).Build())),
|
||||
repository.Literal(0),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
var res model.PublicInvitation
|
||||
haveResult := false
|
||||
decoder := func(cur *mongo.Cursor) error {
|
||||
if haveResult {
|
||||
// should never get here
|
||||
db.DBImp.Logger.Warn("Unexpected extra invitation", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return merrors.Internal("Unexpected extra invitation found by reference")
|
||||
}
|
||||
if e := cur.Decode(&res); e != nil {
|
||||
db.DBImp.Logger.Warn("Failed to decode entity", zap.Error(e), zap.Any("data", cur.Current.String()))
|
||||
return e
|
||||
}
|
||||
haveResult = true
|
||||
return nil
|
||||
}
|
||||
if err := db.DBImp.Repository.Aggregate(ctx, pipeline, decoder); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to execute aggregation pipeline", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return nil, err
|
||||
}
|
||||
if !haveResult {
|
||||
db.DBImp.Logger.Warn("No results fetched", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return nil, merrors.NoData(fmt.Sprintf("Invitation %s not found", invitationRef.Hex()))
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
28
api/pkg/db/internal/mongo/invitationdb/list.go
Normal file
28
api/pkg/db/internal/mongo/invitationdb/list.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mauth "github.com/tech/sendico/pkg/mutil/db/auth"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Invitation, error) {
|
||||
res, err := mauth.GetProtectedObjects[model.Invitation](
|
||||
ctx,
|
||||
db.DBImp.Logger,
|
||||
accountRef, organizationRef, model.ActionRead,
|
||||
repository.OrgFilter(organizationRef),
|
||||
cursor,
|
||||
db.Enforcer,
|
||||
db.DBImp.Repository,
|
||||
)
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
return []model.Invitation{}, nil
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
26
api/pkg/db/internal/mongo/invitationdb/updatestatus.go
Normal file
26
api/pkg/db/internal/mongo/invitationdb/updatestatus.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) updateStatus(ctx context.Context, invitationRef primitive.ObjectID, newStatus model.InvitationStatus) error {
|
||||
// db.DBImp.Up
|
||||
var inv model.Invitation
|
||||
if err := db.DBImp.FindOne(ctx, repository.IDFilter(invitationRef), &inv); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to fetch invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
|
||||
return err
|
||||
}
|
||||
inv.Status = newStatus
|
||||
if err := db.DBImp.Update(ctx, &inv); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to update invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
22
api/pkg/db/internal/mongo/mongo.go
Normal file
22
api/pkg/db/internal/mongo/mongo.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func buildURI(s *DBSettings) string {
|
||||
u := &url.URL{
|
||||
Scheme: "mongodb",
|
||||
Host: s.Host,
|
||||
Path: "/" + url.PathEscape(s.Database), // /my%20db
|
||||
}
|
||||
|
||||
q := url.Values{}
|
||||
if s.ReplicaSet != "" {
|
||||
q.Set("replicaSet", s.ReplicaSet)
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
32
api/pkg/db/internal/mongo/organizationdb/archived.go
Normal file
32
api/pkg/db/internal/mongo/organizationdb/archived.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SetArchived sets the archived status of an organization and optionally cascades to projects, tasks, comments, and reactions
|
||||
func (db *OrganizationDB) SetArchived(ctx context.Context, accountRef, organizationRef primitive.ObjectID, archived, cascade bool) error {
|
||||
db.DBImp.Logger.Debug("Setting organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))
|
||||
|
||||
// Get the organization first
|
||||
var organization model.Organization
|
||||
if err := db.Get(ctx, accountRef, organizationRef, &organization); err != nil {
|
||||
db.DBImp.Logger.Warn("Error retrieving organization for archival", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the organization's archived status
|
||||
organization.SetArchived(archived)
|
||||
if err := db.Update(ctx, accountRef, &organization); err != nil {
|
||||
db.DBImp.Logger.Warn("Error updating organization archived status", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully set organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived))
|
||||
return nil
|
||||
}
|
||||
23
api/pkg/db/internal/mongo/organizationdb/cascade.go
Normal file
23
api/pkg/db/internal/mongo/organizationdb/cascade.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DeleteCascade deletes an organization and all its related data (projects, tasks, comments, reactions, statuses)
|
||||
func (db *OrganizationDB) DeleteCascade(ctx context.Context, organizationRef primitive.ObjectID) error {
|
||||
db.DBImp.Logger.Debug("Starting organization deletion with projects", mzap.ObjRef("organization_ref", organizationRef))
|
||||
|
||||
// Delete the organization itself
|
||||
if err := db.Unprotected().Delete(ctx, organizationRef); err != nil {
|
||||
db.DBImp.Logger.Warn("Error deleting organization", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully deleted organization with projects", mzap.ObjRef("organization_ref", organizationRef))
|
||||
return nil
|
||||
}
|
||||
19
api/pkg/db/internal/mongo/organizationdb/create.go
Normal file
19
api/pkg/db/internal/mongo/organizationdb/create.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *OrganizationDB) Create(ctx context.Context, _, _ primitive.ObjectID, org *model.Organization) error {
|
||||
if org == nil {
|
||||
return merrors.InvalidArgument("Organization object is nil")
|
||||
}
|
||||
org.SetID(primitive.NewObjectID())
|
||||
// Organizaiton reference must be set to the same value as own organization reference
|
||||
org.SetOrganizationRef(*org.GetID())
|
||||
return db.DBImp.Create(ctx, org)
|
||||
}
|
||||
34
api/pkg/db/internal/mongo/organizationdb/db.go
Normal file
34
api/pkg/db/internal/mongo/organizationdb/db.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type OrganizationDB struct {
|
||||
auth.ProtectedDBImp[*model.Organization]
|
||||
}
|
||||
|
||||
func Create(ctx context.Context,
|
||||
logger mlogger.Logger,
|
||||
enforcer auth.Enforcer,
|
||||
pdb policy.DB,
|
||||
db *mongo.Database,
|
||||
) (*OrganizationDB, error) {
|
||||
p, err := auth.CreateDBImp[*model.Organization](ctx, logger, pdb, enforcer, mservice.Organizations, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := &OrganizationDB{
|
||||
ProtectedDBImp: *p,
|
||||
}
|
||||
p.DBImp.SetDeleter(res.DeleteCascade)
|
||||
return res, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user