Files
sendico/api/ledger/internal/service/ledger/invariant.go
2026-01-31 00:26:42 +01:00

142 lines
3.9 KiB
Go

package ledger
import (
"context"
"errors"
"fmt"
"strings"
"github.com/shopspring/decimal"
"github.com/tech/sendico/ledger/storage"
storageMongo "github.com/tech/sendico/ledger/storage/mongo"
"github.com/tech/sendico/pkg/merrors"
pmodel "github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/v2/bson"
)
// CheckExternalInvariant validates the external_source/external_sink invariant for a currency.
//
// The currency code is trimmed and upper-cased before use; an empty code is
// rejected with an InvalidArgument error. The method resolves the system
// accounts for the external-source and external-sink purposes and reads their
// balances. It then sums the organization account balances via
// sumOrganizationBalances. The check passes when
//
//	abs(source balance) - abs(sink balance) == organization total
//
// Otherwise an InvalidArgument error reporting the three figures is returned.
// Any lookup error is returned unchanged.
//
// NOTE(review): a failed invariant is reported as InvalidArgument even though
// the caller's input was valid; consider a more fitting error kind if callers
// allow it.
func (s *Service) CheckExternalInvariant(ctx context.Context, currency string) error {
	if s == nil || s.storage == nil {
		return errStorageNotInitialized
	}

	code := strings.ToUpper(strings.TrimSpace(currency))
	if code == "" {
		return merrors.InvalidArgument("currency is required")
	}

	// Resolve both system accounts before reading any balances.
	sourceAccount, err := s.systemAccount(ctx, pmodel.SystemAccountPurposeExternalSource, code)
	if err != nil {
		return err
	}
	sinkAccount, err := s.systemAccount(ctx, pmodel.SystemAccountPurposeExternalSink, code)
	if err != nil {
		return err
	}

	absSource, err := s.balanceForAccount(ctx, sourceAccount)
	if err != nil {
		return err
	}
	absSource = absSource.Abs()

	absSink, err := s.balanceForAccount(ctx, sinkAccount)
	if err != nil {
		return err
	}
	absSink = absSink.Abs()

	orgTotal, err := s.sumOrganizationBalances(ctx, code)
	if err != nil {
		return err
	}

	if absSource.Sub(absSink).Equal(orgTotal) {
		return nil
	}
	return merrors.InvalidArgument(fmt.Sprintf(
		"external invariant failed: abs(source)=%s abs(sink)=%s org_total=%s",
		absSource.String(), absSink.String(), orgTotal.String(),
	))
}
// balanceForAccount returns the parsed balance of account.
//
// A nil account, or one without an ID, is rejected with an InvalidArgument
// error. If storage reports storage.ErrBalanceNotFound, the balance is treated
// as zero. Any other storage error is returned as-is. The stored balance string
// is converted with parseDecimal.
func (s *Service) balanceForAccount(ctx context.Context, account *pmodel.LedgerAccount) (decimal.Decimal, error) {
	if account == nil {
		return decimal.Zero, merrors.InvalidArgument("account reference is required")
	}
	accountID := account.GetID()
	if accountID == nil {
		return decimal.Zero, merrors.InvalidArgument("account reference is required")
	}

	record, err := s.storage.Balances().Get(ctx, *accountID)
	switch {
	case errors.Is(err, storage.ErrBalanceNotFound):
		// No balance record yet: the account contributes nothing.
		return decimal.Zero, nil
	case err != nil:
		return decimal.Zero, err
	}
	return parseDecimal(record.Balance)
}
// sumOrganizationBalances returns the total balance of the accounts that
// listOrganizationAccounts yields for currency and that have a non-zero
// OrganizationRef.
//
// Accounts without an OrganizationRef are skipped. Accounts without a balance
// record contribute zero; that rule lives in balanceForAccount, which this
// function delegates to. A listed account that is nil or has no ID aborts the
// sum with an Internal error rather than being ignored.
//
// NOTE(review): when the accounts store provides ListByCurrency, accounts are
// not filtered by scope (unlike the Mongo fallback query). Only the
// OrganizationRef check excludes non-organization accounts here; confirm that
// is sufficient.
func (s *Service) sumOrganizationBalances(ctx context.Context, currency string) (decimal.Decimal, error) {
	accounts, err := s.listOrganizationAccounts(ctx, currency)
	if err != nil {
		return decimal.Zero, err
	}

	total := decimal.Zero
	for _, account := range accounts {
		if account == nil || account.GetID() == nil {
			// Skipping this entry would silently under-count the total, so fail instead.
			return decimal.Zero, merrors.Internal("account missing identifier")
		}
		if account.OrganizationRef == nil || account.OrganizationRef.IsZero() {
			continue
		}
		amount, err := s.balanceForAccount(ctx, account)
		if err != nil {
			return decimal.Zero, err
		}
		total = total.Add(amount)
	}
	return total, nil
}
// accountCurrencyLister is an optional capability checked on the value returned
// by storage.Accounts(). When present, listOrganizationAccounts uses it instead
// of querying the Mongo accounts collection directly.
type accountCurrencyLister interface {
	// ListByCurrency returns the ledger accounts for the given currency code
	// (presumably matched exactly; the caller passes an upper-cased code).
	ListByCurrency(ctx context.Context, currency string) ([]*pmodel.LedgerAccount, error)
}
// listOrganizationAccounts returns the ledger accounts for currency that the
// external invariant check should consider.
//
// If the accounts store implements accountCurrencyLister, its ListByCurrency
// result is returned as-is. Otherwise storage must be the Mongo store; the
// accounts collection is then queried for documents with the given currency
// whose scope is the organization scope, empty, or absent. Any other storage
// implementation yields an Internal error. On success the returned slice is
// non-nil, possibly empty.
func (s *Service) listOrganizationAccounts(ctx context.Context, currency string) ([]*pmodel.LedgerAccount, error) {
	if lister, ok := s.storage.Accounts().(accountCurrencyLister); ok {
		return lister.ListByCurrency(ctx, currency)
	}

	store, ok := s.storage.(*storageMongo.Store)
	if !ok {
		return nil, merrors.Internal("storage does not support invariant checks")
	}

	collection := store.Database().Collection(mservice.LedgerAccounts)
	filter := bson.M{
		"currency": currency,
		// Accounts with an empty or missing scope are treated as organization accounts.
		"$or": []bson.M{
			{"scope": pmodel.LedgerAccountScopeOrganization},
			{"scope": ""},
			{"scope": bson.M{"$exists": false}},
		},
	}

	cursor, err := collection.Find(ctx, filter)
	if err != nil {
		return nil, fmt.Errorf("find ledger accounts for currency %q: %w", currency, err)
	}

	// Cursor.All drains the cursor and closes it on every return path, so no
	// separate Close call is needed. Start from a non-nil empty slice so an
	// empty result is still a non-nil slice.
	accounts := make([]*pmodel.LedgerAccount, 0)
	if err := cursor.All(ctx, &accounts); err != nil {
		return nil, fmt.Errorf("decode ledger accounts for currency %q: %w", currency, err)
	}
	return accounts, nil
}