142 lines
3.9 KiB
Go
142 lines
3.9 KiB
Go
package ledger
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/shopspring/decimal"
|
|
"github.com/tech/sendico/ledger/storage"
|
|
storageMongo "github.com/tech/sendico/ledger/storage/mongo"
|
|
"github.com/tech/sendico/pkg/merrors"
|
|
pmodel "github.com/tech/sendico/pkg/model"
|
|
"github.com/tech/sendico/pkg/mservice"
|
|
"go.mongodb.org/mongo-driver/v2/bson"
|
|
)
|
|
|
|
// CheckExternalInvariant validates the external_source/external_sink invariant for a currency.
|
|
func (s *Service) CheckExternalInvariant(ctx context.Context, currency string) error {
|
|
if s == nil || s.storage == nil {
|
|
return errStorageNotInitialized
|
|
}
|
|
normalized := strings.ToUpper(strings.TrimSpace(currency))
|
|
if normalized == "" {
|
|
return merrors.InvalidArgument("currency is required")
|
|
}
|
|
|
|
source, err := s.systemAccount(ctx, pmodel.SystemAccountPurposeExternalSource, normalized)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sink, err := s.systemAccount(ctx, pmodel.SystemAccountPurposeExternalSink, normalized)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sourceBalance, err := s.balanceForAccount(ctx, source)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sinkBalance, err := s.balanceForAccount(ctx, sink)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
orgTotal, err := s.sumOrganizationBalances(ctx, normalized)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
diff := sourceBalance.Abs().Sub(sinkBalance.Abs())
|
|
if !diff.Equal(orgTotal) {
|
|
return merrors.InvalidArgument(fmt.Sprintf("external invariant failed: abs(source)=%s abs(sink)=%s org_total=%s", sourceBalance.Abs().String(), sinkBalance.Abs().String(), orgTotal.String()))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) balanceForAccount(ctx context.Context, account *pmodel.LedgerAccount) (decimal.Decimal, error) {
|
|
if account == nil || account.GetID() == nil {
|
|
return decimal.Zero, merrors.InvalidArgument("account reference is required")
|
|
}
|
|
balance, err := s.storage.Balances().Get(ctx, *account.GetID())
|
|
if err != nil {
|
|
if errors.Is(err, storage.ErrBalanceNotFound) {
|
|
return decimal.Zero, nil
|
|
}
|
|
return decimal.Zero, err
|
|
}
|
|
return parseDecimal(balance.Balance)
|
|
}
|
|
|
|
func (s *Service) sumOrganizationBalances(ctx context.Context, currency string) (decimal.Decimal, error) {
|
|
sum := decimal.Zero
|
|
accounts, err := s.listOrganizationAccounts(ctx, currency)
|
|
if err != nil {
|
|
return decimal.Zero, err
|
|
}
|
|
for _, account := range accounts {
|
|
if account == nil || account.GetID() == nil {
|
|
return decimal.Zero, merrors.Internal("account missing identifier")
|
|
}
|
|
if account.OrganizationRef == nil || account.OrganizationRef.IsZero() {
|
|
continue
|
|
}
|
|
balance, err := s.storage.Balances().Get(ctx, *account.GetID())
|
|
if err != nil {
|
|
if errors.Is(err, storage.ErrBalanceNotFound) {
|
|
continue
|
|
}
|
|
return decimal.Zero, err
|
|
}
|
|
amount, err := parseDecimal(balance.Balance)
|
|
if err != nil {
|
|
return decimal.Zero, err
|
|
}
|
|
sum = sum.Add(amount)
|
|
}
|
|
return sum, nil
|
|
}
|
|
|
|
type accountCurrencyLister interface {
|
|
ListByCurrency(ctx context.Context, currency string) ([]*pmodel.LedgerAccount, error)
|
|
}
|
|
|
|
func (s *Service) listOrganizationAccounts(ctx context.Context, currency string) ([]*pmodel.LedgerAccount, error) {
|
|
if lister, ok := s.storage.Accounts().(accountCurrencyLister); ok {
|
|
return lister.ListByCurrency(ctx, currency)
|
|
}
|
|
|
|
store, ok := s.storage.(*storageMongo.Store)
|
|
if !ok {
|
|
return nil, merrors.Internal("storage does not support invariant checks")
|
|
}
|
|
collection := store.Database().Collection(mservice.LedgerAccounts)
|
|
filter := bson.M{
|
|
"currency": currency,
|
|
"$or": []bson.M{
|
|
{"scope": pmodel.LedgerAccountScopeOrganization},
|
|
{"scope": ""},
|
|
{"scope": bson.M{"$exists": false}},
|
|
},
|
|
}
|
|
cursor, err := collection.Find(ctx, filter)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer cursor.Close(ctx)
|
|
|
|
accounts := make([]*pmodel.LedgerAccount, 0)
|
|
for cursor.Next(ctx) {
|
|
account := &pmodel.LedgerAccount{}
|
|
if err := cursor.Decode(account); err != nil {
|
|
return nil, err
|
|
}
|
|
accounts = append(accounts, account)
|
|
}
|
|
if err := cursor.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return accounts, nil
|
|
}
|