package ledger

import (
	"context"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.uber.org/zap"

	"github.com/tech/sendico/ledger/storage"
	"github.com/tech/sendico/ledger/storage/model"
	"github.com/tech/sendico/pkg/merrors"
	ledgerv1 "github.com/tech/sendico/pkg/proto/ledger/v1"
)

// accountStoreStub is an in-memory storage.AccountsStore for the
// createAccountResponder tests. It records every account Create accepts and
// lets a test inject errors and canned lookup results.
type accountStoreStub struct {
	// createErr is returned by Create for non-settlement accounts once
	// createErrs is empty.
	createErr error
	// createErrSettlement is returned by Create for settlement accounts.
	createErrSettlement error
	// created holds, in call order, every account Create accepted.
	created []*model.Account

	// existing and existingErr are returned by GetByAccountCode.
	existing    *model.Account
	existingErr error

	// defaultSettlement and defaultErr drive GetDefaultSettlement.
	defaultSettlement *model.Account
	defaultErr        error

	// createErrs is a queue of per-call results for non-settlement creates:
	// each such Create pops one entry, and a nil entry means "succeed". While
	// non-empty it takes precedence over createErr.
	createErrs []error
}

// Compile-time check that the stub can back repositoryStub.accounts.
var _ storage.AccountsStore = (*accountStoreStub)(nil)

// Create fails with the injected error for this kind of account, if any.
// Otherwise it assigns a fresh ObjectID when the account has none, normalizes
// CreatedAt/UpdatedAt to UTC and appends the account to s.created. Failed
// calls are not recorded.
func (s *accountStoreStub) Create(_ context.Context, account *model.Account) error {
	if err := s.nextCreateErr(account.IsSettlement); err != nil {
		return err
	}
	if id := account.GetID(); id == nil || id.IsZero() {
		account.SetID(primitive.NewObjectID())
	}
	account.CreatedAt = account.CreatedAt.UTC()
	account.UpdatedAt = account.UpdatedAt.UTC()
	s.created = append(s.created, account)
	return nil
}

// nextCreateErr returns the error Create should fail with (nil to succeed).
// Settlement accounts use createErrSettlement. Other accounts consume the next
// createErrs entry, falling back to createErr when the queue is empty.
func (s *accountStoreStub) nextCreateErr(isSettlement bool) error {
	if isSettlement {
		return s.createErrSettlement
	}
	if len(s.createErrs) > 0 {
		err := s.createErrs[0]
		s.createErrs = s.createErrs[1:]
		return err
	}
	return s.createErr
}

// GetByAccountCode ignores its lookup arguments. It returns existingErr if set,
// otherwise existing (which may be nil).
func (s *accountStoreStub) GetByAccountCode(_ context.Context, _ primitive.ObjectID, _ string, _ string) (*model.Account, error) {
	if s.existingErr != nil {
		return nil, s.existingErr
	}
	return s.existing, nil
}

// Get always reports storage.ErrAccountNotFound.
func (s *accountStoreStub) Get(context.Context, primitive.ObjectID) (*model.Account, error) {
	return nil, storage.ErrAccountNotFound
}

// GetDefaultSettlement returns defaultErr if set, else defaultSettlement if
// set, else storage.ErrAccountNotFound.
func (s *accountStoreStub) GetDefaultSettlement(context.Context, primitive.ObjectID, string) (*model.Account, error) {
	if s.defaultErr != nil {
		return nil, s.defaultErr
	}
	if s.defaultSettlement != nil {
		return s.defaultSettlement, nil
	}
	return nil, storage.ErrAccountNotFound
}

// ListByOrganization always returns an empty result.
func (s *accountStoreStub) ListByOrganization(context.Context, primitive.ObjectID, int, int) ([]*model.Account, error) {
	return nil, nil
}

// UpdateStatus is a no-op that always succeeds.
func (s *accountStoreStub) UpdateStatus(context.Context, primitive.ObjectID, model.AccountStatus) error {
	return nil
}

// splitCreated returns the last settlement and the last non-settlement account
// recorded by Create. Either is nil when no such account was recorded.
func (s *accountStoreStub) splitCreated() (settlement, regular *model.Account) {
	for _, acc := range s.created {
		if acc.IsSettlement {
			settlement = acc
		} else {
			regular = acc
		}
	}
	return settlement, regular
}

// repositoryStub is a storage repository that serves only the accounts store.
// Every other store accessor returns nil, and Ping always succeeds.
type repositoryStub struct {
	accounts storage.AccountsStore
}

func (r *repositoryStub) Ping(context.Context) error {
	return nil
}

func (r *repositoryStub) Accounts() storage.AccountsStore {
	return r.accounts
}

func (r *repositoryStub) JournalEntries() storage.JournalEntriesStore {
	return nil
}

func (r *repositoryStub) PostingLines() storage.PostingLinesStore {
	return nil
}

func (r *repositoryStub) Balances() storage.BalancesStore {
	return nil
}

func (r *repositoryStub) Outbox() storage.OutboxStore {
	return nil
}

// newCreateAccountTestService returns a Service with a no-op logger whose
// storage exposes only the given accounts stub.
func newCreateAccountTestService(accounts *accountStoreStub) *Service {
	return &Service{
		logger:  zap.NewNop(),
		storage: &repositoryStub{accounts: accounts},
	}
}

// requireGeneratedAccountCode asserts that code has the form
// "<accountType>:<currency>:<suffix>" with a 24-character suffix.
func requireGeneratedAccountCode(t *testing.T, code, accountType, currency string) {
	t.Helper()
	parts := strings.Split(code, ":")
	require.Len(t, parts, 3)
	require.Equal(t, accountType, parts[0])
	require.Equal(t, currency, parts[1])
	require.Len(t, parts[2], 24)
}

// TestCreateAccountResponder_Success checks that a settlement-account request
// is stored exactly once. It also checks that the request is echoed back with
// a generated lower-case code, an upper-cased currency and the request
// metadata.
func TestCreateAccountResponder_Success(t *testing.T) {
	t.Parallel()

	accountStore := &accountStoreStub{}
	svc := newCreateAccountTestService(accountStore)
	req := &ledgerv1.CreateAccountRequest{
		OrganizationRef: primitive.NewObjectID().Hex(),
		AccountType:     ledgerv1.AccountType_ACCOUNT_TYPE_ASSET,
		Currency:        "usd",
		AllowNegative:   false,
		IsSettlement:    true,
		Metadata:        map[string]string{"purpose": "primary"},
	}

	resp, err := svc.createAccountResponder(context.Background(), req)(context.Background())
	require.NoError(t, err)
	require.NotNil(t, resp)
	require.NotNil(t, resp.Account)

	requireGeneratedAccountCode(t, resp.Account.AccountCode, "asset", "usd")
	require.Equal(t, ledgerv1.AccountType_ACCOUNT_TYPE_ASSET, resp.Account.AccountType)
	require.Equal(t, "USD", resp.Account.Currency)
	require.True(t, resp.Account.IsSettlement)
	require.Contains(t, resp.Account.Metadata, "purpose")
	require.NotEmpty(t, resp.Account.LedgerAccountRef)
	// Only the requested account is stored.
	require.Len(t, accountStore.created, 1)
}

// TestCreateAccountResponder_AutoCreatesSettlementAccount checks that creating
// a non-settlement account also stores a default settlement account for the
// currency. It also checks that the response describes the requested account.
func TestCreateAccountResponder_AutoCreatesSettlementAccount(t *testing.T) {
	t.Parallel()

	accountStore := &accountStoreStub{}
	svc := newCreateAccountTestService(accountStore)
	req := &ledgerv1.CreateAccountRequest{
		OrganizationRef: primitive.NewObjectID().Hex(),
		AccountType:     ledgerv1.AccountType_ACCOUNT_TYPE_LIABILITY,
		Currency:        "usd",
	}

	resp, err := svc.createAccountResponder(context.Background(), req)(context.Background())
	require.NoError(t, err)
	require.NotNil(t, resp)
	require.NotNil(t, resp.Account)

	require.Len(t, accountStore.created, 2)
	settlement, created := accountStore.splitCreated()
	require.NotNil(t, settlement)
	require.NotNil(t, created)

	requireGeneratedAccountCode(t, created.AccountCode, "liability", "usd")
	// The response must describe the requested account, not the settlement one.
	require.Equal(t, created.AccountCode, resp.Account.AccountCode)
	require.False(t, resp.Account.IsSettlement)

	require.Equal(t, defaultSettlementAccountCode("USD"), settlement.AccountCode)
	require.Equal(t, model.AccountTypeAsset, settlement.AccountType)
	require.Equal(t, "USD", settlement.Currency)
	require.True(t, settlement.AllowNegative)
}

// TestCreateAccountResponder_RetriesOnConflict checks that a data-conflict
// error from the first non-settlement Create is retried rather than returned.
func TestCreateAccountResponder_RetriesOnConflict(t *testing.T) {
	t.Parallel()

	// The first non-settlement Create fails with a conflict; later ones succeed.
	accountStore := &accountStoreStub{
		createErrs: []error{merrors.DataConflict("duplicate")},
	}
	svc := newCreateAccountTestService(accountStore)
	req := &ledgerv1.CreateAccountRequest{
		OrganizationRef: primitive.NewObjectID().Hex(),
		AccountType:     ledgerv1.AccountType_ACCOUNT_TYPE_ASSET,
		Currency:        "usd",
	}

	resp, err := svc.createAccountResponder(context.Background(), req)(context.Background())
	require.NoError(t, err)
	require.NotNil(t, resp)
	require.NotNil(t, resp.Account)

	// The injected conflict was consumed by a Create call.
	require.Empty(t, accountStore.createErrs)
	// Failed attempts are not recorded by the stub, so this is the
	// auto-created settlement account plus the successful retry.
	require.Len(t, accountStore.created, 2)
	_, created := accountStore.splitCreated()
	require.NotNil(t, created)
	require.Equal(t, created.AccountCode, resp.Account.AccountCode)
}

// TestCreateAccountResponder_InvalidAccountType checks that a request whose
// AccountType is left at its zero value is rejected.
func TestCreateAccountResponder_InvalidAccountType(t *testing.T) {
	t.Parallel()

	svc := newCreateAccountTestService(&accountStoreStub{})
	req := &ledgerv1.CreateAccountRequest{
		OrganizationRef: primitive.NewObjectID().Hex(),
		Currency:        "USD",
	}

	_, err := svc.createAccountResponder(context.Background(), req)(context.Background())
	require.Error(t, err)
}