improved tgsettle messages + storage fixes

This commit is contained in:
Stephan D
2026-03-05 11:54:07 +01:00
parent 801f349aa8
commit 5e59fea7e5
16 changed files with 537 additions and 172 deletions

View File

@@ -55,8 +55,9 @@ func (s *Service) sweepExpiredConfirmations(ctx context.Context) {
s.logger.Warn("Failed to list expired pending confirmations", zap.Error(err)) s.logger.Warn("Failed to list expired pending confirmations", zap.Error(err))
return return
} }
for _, pending := range expired { for i := range expired {
if pending == nil || strings.TrimSpace(pending.RequestID) == "" { pending := &expired[i]
if strings.TrimSpace(pending.RequestID) == "" {
continue continue
} }
result := &model.ConfirmationResult{ result := &model.ConfirmationResult{

View File

@@ -162,19 +162,18 @@ func (f *fakePendingStore) DeleteByRequestID(_ context.Context, requestID string
return nil return nil
} }
func (f *fakePendingStore) ListExpired(_ context.Context, now time.Time, limit int64) ([]*storagemodel.PendingConfirmation, error) { func (f *fakePendingStore) ListExpired(_ context.Context, now time.Time, limit int64) ([]storagemodel.PendingConfirmation, error) {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
if limit <= 0 { if limit <= 0 {
limit = 100 limit = 100
} }
result := make([]*storagemodel.PendingConfirmation, 0) result := make([]storagemodel.PendingConfirmation, 0)
for _, record := range f.records { for _, record := range f.records {
if record == nil || record.ExpiresAt.IsZero() || record.ExpiresAt.After(now) { if record == nil || record.ExpiresAt.IsZero() || record.ExpiresAt.After(now) {
continue continue
} }
cp := *record result = append(result, *record)
result = append(result, &cp)
if int64(len(result)) >= limit { if int64(len(result)) >= limit {
break break
} }

View File

@@ -6,6 +6,7 @@ type Command string
const ( const (
CommandStart Command = "start" CommandStart Command = "start"
CommandHelp Command = "help"
CommandFund Command = "fund" CommandFund Command = "fund"
CommandWithdraw Command = "withdraw" CommandWithdraw Command = "withdraw"
CommandConfirm Command = "confirm" CommandConfirm Command = "confirm"
@@ -14,6 +15,7 @@ const (
var supportedCommands = []Command{ var supportedCommands = []Command{
CommandStart, CommandStart,
CommandHelp,
CommandFund, CommandFund,
CommandWithdraw, CommandWithdraw,
CommandConfirm, CommandConfirm,
@@ -56,3 +58,29 @@ func supportedCommandsMessage() string {
func confirmationCommandsMessage() string { func confirmationCommandsMessage() string {
return "Confirm operation?\n\n" + CommandConfirm.Slash() + "\n" + CommandCancel.Slash() return "Confirm operation?\n\n" + CommandConfirm.Slash() + "\n" + CommandCancel.Slash()
} }
func helpMessage(accountCode string, currency string) string {
accountCode = strings.TrimSpace(accountCode)
currency = strings.ToUpper(strings.TrimSpace(currency))
if accountCode == "" {
accountCode = "N/A"
}
if currency == "" {
currency = "N/A"
}
lines := []string{
"Treasury bot help",
"",
"Attached account: " + accountCode + " (" + currency + ")",
"",
"How to use:",
"1) Start funding with " + CommandFund.Slash() + " or withdrawal with " + CommandWithdraw.Slash(),
"2) Enter amount as decimal, dot separator, no currency (example: 1250.75)",
"3) Confirm with " + CommandConfirm.Slash() + " or abort with " + CommandCancel.Slash(),
"",
"After confirmation there is a cooldown window. You can cancel during it with " + CommandCancel.Slash() + ".",
"You will receive a follow-up message with execution success or failure.",
}
return strings.Join(lines, "\n")
}

View File

@@ -17,7 +17,7 @@ import (
const unauthorizedMessage = "Sorry, your Telegram account is not authorized to perform treasury operations." const unauthorizedMessage = "Sorry, your Telegram account is not authorized to perform treasury operations."
const unauthorizedChatMessage = "Sorry, this Telegram chat is not authorized to perform treasury operations." const unauthorizedChatMessage = "Sorry, this Telegram chat is not authorized to perform treasury operations."
var welcomeMessage = "Welcome to tgsettle treasury bot.\n\nUse " + CommandFund.Slash() + " to credit your account and " + CommandWithdraw.Slash() + " to debit it.\nAfter entering an amount, use " + CommandConfirm.Slash() + " or " + CommandCancel.Slash() + "." const amountInputHint = "Enter amount as a decimal number using a dot separator and without currency.\nExample: 1250.75"
type SendTextFunc func(ctx context.Context, chatID string, text string) error type SendTextFunc func(ctx context.Context, chatID string, text string) error
@@ -26,6 +26,12 @@ type ScheduleTracker interface {
Untrack(requestID string) Untrack(requestID string)
} }
type AccountProfile struct {
AccountID string
AccountCode string
Currency string
}
type CreateRequestInput struct { type CreateRequestInput struct {
OperationType storagemodel.TreasuryOperationType OperationType storagemodel.TreasuryOperationType
TelegramUserID string TelegramUserID string
@@ -39,6 +45,7 @@ type TreasuryService interface {
MaxPerOperationLimit() string MaxPerOperationLimit() string
GetActiveRequestForAccount(ctx context.Context, ledgerAccountID string) (*storagemodel.TreasuryRequest, error) GetActiveRequestForAccount(ctx context.Context, ledgerAccountID string) (*storagemodel.TreasuryRequest, error)
GetAccountProfile(ctx context.Context, ledgerAccountID string) (*AccountProfile, error)
CreateRequest(ctx context.Context, input CreateRequestInput) (*storagemodel.TreasuryRequest, error) CreateRequest(ctx context.Context, input CreateRequestInput) (*storagemodel.TreasuryRequest, error)
ConfirmRequest(ctx context.Context, requestID string, telegramUserID string) (*storagemodel.TreasuryRequest, error) ConfirmRequest(ctx context.Context, requestID string, telegramUserID string) (*storagemodel.TreasuryRequest, error)
CancelRequest(ctx context.Context, requestID string, telegramUserID string) (*storagemodel.TreasuryRequest, error) CancelRequest(ctx context.Context, requestID string, telegramUserID string) (*storagemodel.TreasuryRequest, error)
@@ -119,6 +126,18 @@ func (r *Router) HandleUpdate(ctx context.Context, update *model.TelegramWebhook
if chatID == "" || userID == "" { if chatID == "" || userID == "" {
return false return false
} }
command := parseCommand(text)
if r.logger != nil {
r.logger.Debug("Telegram treasury update received",
zap.Int64("update_id", update.UpdateID),
zap.String("chat_id", chatID),
zap.String("telegram_user_id", userID),
zap.String("command", strings.TrimSpace(string(command))),
zap.String("message_text", text),
zap.String("reply_to_message_id", strings.TrimSpace(message.ReplyToMessageID)),
)
}
if !r.allowAnyChat { if !r.allowAnyChat {
if _, ok := r.allowedChats[chatID]; !ok { if _, ok := r.allowedChats[chatID]; !ok {
r.logUnauthorized(update) r.logUnauthorized(update)
@@ -134,21 +153,49 @@ func (r *Router) HandleUpdate(ctx context.Context, update *model.TelegramWebhook
return true return true
} }
command := parseCommand(text)
switch command { switch command {
case CommandStart: case CommandStart:
_ = r.sendText(ctx, chatID, welcomeMessage) profile := r.resolveAccountProfile(ctx, accountID)
_ = r.sendText(ctx, chatID, welcomeMessage(profile))
return true
case CommandHelp:
profile := r.resolveAccountProfile(ctx, accountID)
_ = r.sendText(ctx, chatID, helpMessage(displayAccountCode(profile, accountID), profile.Currency))
return true return true
case CommandFund: case CommandFund:
if r.logger != nil {
r.logger.Info("Treasury funding dialog requested",
zap.String("chat_id", chatID),
zap.String("telegram_user_id", userID),
zap.String("ledger_account_id", accountID))
}
r.startAmountDialog(ctx, userID, accountID, chatID, storagemodel.TreasuryOperationFund) r.startAmountDialog(ctx, userID, accountID, chatID, storagemodel.TreasuryOperationFund)
return true return true
case CommandWithdraw: case CommandWithdraw:
if r.logger != nil {
r.logger.Info("Treasury withdrawal dialog requested",
zap.String("chat_id", chatID),
zap.String("telegram_user_id", userID),
zap.String("ledger_account_id", accountID))
}
r.startAmountDialog(ctx, userID, accountID, chatID, storagemodel.TreasuryOperationWithdraw) r.startAmountDialog(ctx, userID, accountID, chatID, storagemodel.TreasuryOperationWithdraw)
return true return true
case CommandConfirm: case CommandConfirm:
if r.logger != nil {
r.logger.Info("Treasury confirmation requested",
zap.String("chat_id", chatID),
zap.String("telegram_user_id", userID),
zap.String("ledger_account_id", accountID))
}
r.confirm(ctx, userID, accountID, chatID) r.confirm(ctx, userID, accountID, chatID)
return true return true
case CommandCancel: case CommandCancel:
if r.logger != nil {
r.logger.Info("Treasury cancellation requested",
zap.String("chat_id", chatID),
zap.String("telegram_user_id", userID),
zap.String("ledger_account_id", accountID))
}
r.cancel(ctx, userID, accountID, chatID) r.cancel(ctx, userID, accountID, chatID)
return true return true
} }
@@ -182,7 +229,10 @@ func (r *Router) HandleUpdate(ctx context.Context, update *model.TelegramWebhook
func (r *Router) startAmountDialog(ctx context.Context, userID, accountID, chatID string, operation storagemodel.TreasuryOperationType) { func (r *Router) startAmountDialog(ctx context.Context, userID, accountID, chatID string, operation storagemodel.TreasuryOperationType) {
active, err := r.service.GetActiveRequestForAccount(ctx, accountID) active, err := r.service.GetActiveRequestForAccount(ctx, accountID)
if err != nil { if err != nil {
if r.logger != nil {
r.logger.Warn("Failed to check active treasury request", zap.Error(err), zap.String("telegram_user_id", userID), zap.String("ledger_account_id", accountID)) r.logger.Warn("Failed to check active treasury request", zap.Error(err), zap.String("telegram_user_id", userID), zap.String("ledger_account_id", accountID))
}
_ = r.sendText(ctx, chatID, "Unable to check pending treasury operations right now. Please try again.")
return return
} }
if active != nil { if active != nil {
@@ -199,7 +249,8 @@ func (r *Router) startAmountDialog(ctx context.Context, userID, accountID, chatI
OperationType: operation, OperationType: operation,
LedgerAccountID: accountID, LedgerAccountID: accountID,
}) })
_ = r.sendText(ctx, chatID, "Enter amount:") profile := r.resolveAccountProfile(ctx, accountID)
_ = r.sendText(ctx, chatID, amountPromptMessage(operation, profile, accountID))
} }
func (r *Router) captureAmount(ctx context.Context, userID, accountID, chatID string, operation storagemodel.TreasuryOperationType, amount string) { func (r *Router) captureAmount(ctx context.Context, userID, accountID, chatID string, operation storagemodel.TreasuryOperationType, amount string) {
@@ -231,7 +282,7 @@ func (r *Router) captureAmount(ctx context.Context, userID, accountID, chatID st
} }
} }
if errors.Is(err, merrors.ErrInvalidArg) { if errors.Is(err, merrors.ErrInvalidArg) {
_ = r.sendText(ctx, chatID, "Invalid amount.\n\nEnter another amount or "+CommandCancel.Slash()) _ = r.sendText(ctx, chatID, "Invalid amount.\n\n"+amountInputHint+"\n\nEnter another amount or "+CommandCancel.Slash())
return return
} }
_ = r.sendText(ctx, chatID, "Failed to create treasury request.\n\nEnter another amount or "+CommandCancel.Slash()) _ = r.sendText(ctx, chatID, "Failed to create treasury request.\n\nEnter another amount or "+CommandCancel.Slash())
@@ -276,7 +327,7 @@ func (r *Router) confirm(ctx context.Context, userID string, accountID string, c
if delay < 0 { if delay < 0 {
delay = 0 delay = 0
} }
_ = r.sendText(ctx, chatID, "Operation confirmed.\n\nExecution scheduled in "+formatSeconds(delay)+".\n\nRequest ID: "+strings.TrimSpace(record.RequestID)) _ = r.sendText(ctx, chatID, "Operation confirmed.\n\nExecution scheduled in "+formatSeconds(delay)+".\nYou can cancel during this cooldown with "+CommandCancel.Slash()+".\n\nYou will receive a follow-up message with execution success or failure.\n\nRequest ID: "+strings.TrimSpace(record.RequestID))
} }
func (r *Router) cancel(ctx context.Context, userID string, accountID string, chatID string) { func (r *Router) cancel(ctx context.Context, userID string, accountID string, chatID string) {
@@ -315,7 +366,16 @@ func (r *Router) sendText(ctx context.Context, chatID string, text string) error
if chatID == "" || text == "" { if chatID == "" || text == "" {
return nil return nil
} }
return r.send(ctx, chatID, text) if err := r.send(ctx, chatID, text); err != nil {
if r.logger != nil {
r.logger.Warn("Failed to send treasury bot response",
zap.Error(err),
zap.String("chat_id", chatID),
zap.String("message_text", text))
}
return err
}
return nil
} }
func (r *Router) logUnauthorized(update *model.TelegramWebhookUpdate) { func (r *Router) logUnauthorized(update *model.TelegramWebhookUpdate) {
@@ -337,6 +397,7 @@ func pendingRequestMessage(record *storagemodel.TreasuryRequest) string {
return "You already have a pending treasury operation.\n\n" + CommandCancel.Slash() return "You already have a pending treasury operation.\n\n" + CommandCancel.Slash()
} }
return "You already have a pending treasury operation.\n\n" + return "You already have a pending treasury operation.\n\n" +
"Account: " + requestAccountDisplay(record) + "\n" +
"Request ID: " + strings.TrimSpace(record.RequestID) + "\n" + "Request ID: " + strings.TrimSpace(record.RequestID) + "\n" +
"Status: " + strings.TrimSpace(string(record.Status)) + "\n" + "Status: " + strings.TrimSpace(string(record.Status)) + "\n" +
"Amount: " + strings.TrimSpace(record.Amount) + " " + strings.TrimSpace(record.Currency) + "\n\n" + "Amount: " + strings.TrimSpace(record.Amount) + " " + strings.TrimSpace(record.Currency) + "\n\n" +
@@ -352,11 +413,89 @@ func confirmationPrompt(record *storagemodel.TreasuryRequest) string {
title = "Withdrawal request created." title = "Withdrawal request created."
} }
return title + "\n\n" + return title + "\n\n" +
"Account: " + strings.TrimSpace(record.LedgerAccountID) + "\n" + "Account: " + requestAccountDisplay(record) + "\n" +
"Amount: " + strings.TrimSpace(record.Amount) + " " + strings.TrimSpace(record.Currency) + "\n\n" + "Amount: " + strings.TrimSpace(record.Amount) + " " + strings.TrimSpace(record.Currency) + "\n\n" +
confirmationCommandsMessage() confirmationCommandsMessage()
} }
func welcomeMessage(profile *AccountProfile) string {
accountCode := displayAccountCode(profile, "")
currency := ""
if profile != nil {
currency = strings.ToUpper(strings.TrimSpace(profile.Currency))
}
if accountCode == "" {
accountCode = "N/A"
}
if currency == "" {
currency = "N/A"
}
return "Welcome to Sendico treasury bot.\n\nAttached account: " + accountCode + " (" + currency + ").\nUse " + CommandFund.Slash() + " to credit your account and " + CommandWithdraw.Slash() + " to debit it.\nAfter entering an amount, use " + CommandConfirm.Slash() + " or " + CommandCancel.Slash() + ".\nUse " + CommandHelp.Slash() + " for detailed usage."
}
func amountPromptMessage(operation storagemodel.TreasuryOperationType, profile *AccountProfile, fallbackAccountID string) string {
action := "fund"
if operation == storagemodel.TreasuryOperationWithdraw {
action = "withdraw"
}
accountCode := displayAccountCode(profile, fallbackAccountID)
currency := ""
if profile != nil {
currency = strings.ToUpper(strings.TrimSpace(profile.Currency))
}
if accountCode == "" {
accountCode = "N/A"
}
if currency == "" {
currency = "N/A"
}
return "Preparing to " + action + " account " + accountCode + " (" + currency + ").\n\n" + amountInputHint
}
func requestAccountDisplay(record *storagemodel.TreasuryRequest) string {
if record == nil {
return ""
}
if code := strings.TrimSpace(record.LedgerAccountCode); code != "" {
return code
}
return strings.TrimSpace(record.LedgerAccountID)
}
func displayAccountCode(profile *AccountProfile, fallbackAccountID string) string {
if profile != nil {
if code := strings.TrimSpace(profile.AccountCode); code != "" {
return code
}
if id := strings.TrimSpace(profile.AccountID); id != "" {
return id
}
}
return strings.TrimSpace(fallbackAccountID)
}
func (r *Router) resolveAccountProfile(ctx context.Context, ledgerAccountID string) *AccountProfile {
if r == nil || r.service == nil {
return &AccountProfile{AccountID: strings.TrimSpace(ledgerAccountID)}
}
profile, err := r.service.GetAccountProfile(ctx, ledgerAccountID)
if err != nil {
if r.logger != nil {
r.logger.Warn("Failed to resolve treasury account profile",
zap.Error(err),
zap.String("ledger_account_id", strings.TrimSpace(ledgerAccountID)))
}
return &AccountProfile{AccountID: strings.TrimSpace(ledgerAccountID)}
}
if profile == nil {
return &AccountProfile{AccountID: strings.TrimSpace(ledgerAccountID)}
}
if strings.TrimSpace(profile.AccountID) == "" {
profile.AccountID = strings.TrimSpace(ledgerAccountID)
}
return profile
}
func formatSeconds(value int64) string { func formatSeconds(value int64) string {
if value == 1 { if value == 1 {
return "1 second" return "1 second"

View File

@@ -24,6 +24,14 @@ func (fakeService) GetActiveRequestForAccount(context.Context, string) (*storage
return nil, nil return nil, nil
} }
func (fakeService) GetAccountProfile(_ context.Context, ledgerAccountID string) (*AccountProfile, error) {
return &AccountProfile{
AccountID: ledgerAccountID,
AccountCode: ledgerAccountID,
Currency: "USD",
}, nil
}
func (fakeService) CreateRequest(context.Context, CreateRequestInput) (*storagemodel.TreasuryRequest, error) { func (fakeService) CreateRequest(context.Context, CreateRequestInput) (*storagemodel.TreasuryRequest, error) {
return nil, nil return nil, nil
} }
@@ -124,7 +132,11 @@ func TestRouterEmptyAllowedChats_AllowsAnyChatForAuthorizedUser(t *testing.T) {
if len(sent) != 1 { if len(sent) != 1 {
t.Fatalf("expected one message, got %d", len(sent)) t.Fatalf("expected one message, got %d", len(sent))
} }
if sent[0] != "Enter amount:" { if sent[0] != amountPromptMessage(
storagemodel.TreasuryOperationFund,
&AccountProfile{AccountID: "acct-1", AccountCode: "acct-1", Currency: "USD"},
"acct-1",
) {
t.Fatalf("unexpected message: %q", sent[0]) t.Fatalf("unexpected message: %q", sent[0])
} }
} }
@@ -186,7 +198,38 @@ func TestRouterStartAuthorizedShowsWelcome(t *testing.T) {
if len(sent) != 1 { if len(sent) != 1 {
t.Fatalf("expected one message, got %d", len(sent)) t.Fatalf("expected one message, got %d", len(sent))
} }
if sent[0] != welcomeMessage { if sent[0] != welcomeMessage(&AccountProfile{AccountID: "acct-1", AccountCode: "acct-1", Currency: "USD"}) {
t.Fatalf("unexpected message: %q", sent[0])
}
}
func TestRouterHelpAuthorizedShowsHelp(t *testing.T) {
var sent []string
router := NewRouter(
mloggerfactory.NewLogger(false),
fakeService{},
func(_ context.Context, _ string, text string) error {
sent = append(sent, text)
return nil
},
nil,
nil,
map[string]string{"123": "acct-1"},
)
handled := router.HandleUpdate(context.Background(), &model.TelegramWebhookUpdate{
Message: &model.TelegramMessage{
ChatID: "777",
FromUserID: "123",
Text: "/help",
},
})
if !handled {
t.Fatalf("expected update to be handled")
}
if len(sent) != 1 {
t.Fatalf("expected one message, got %d", len(sent))
}
if sent[0] != helpMessage("acct-1", "USD") {
t.Fatalf("unexpected message: %q", sent[0]) t.Fatalf("unexpected message: %q", sent[0])
} }
} }

View File

@@ -28,6 +28,7 @@ type Config struct {
type Account struct { type Account struct {
AccountID string AccountID string
AccountCode string
Currency string Currency string
OrganizationRef string OrganizationRef string
} }
@@ -130,14 +131,20 @@ func (c *connectorClient) GetAccount(ctx context.Context, accountID string) (*Ac
if account == nil { if account == nil {
return nil, merrors.NoData("ledger account not found") return nil, merrors.NoData("ledger account not found")
} }
accountCode := strings.TrimSpace(account.GetLabel())
organizationRef := strings.TrimSpace(account.GetOwnerRef()) organizationRef := strings.TrimSpace(account.GetOwnerRef())
if organizationRef == "" && account.GetProviderDetails() != nil { if organizationRef == "" && account.GetProviderDetails() != nil {
if value, ok := account.GetProviderDetails().AsMap()["organization_ref"]; ok { details := account.GetProviderDetails().AsMap()
organizationRef = strings.TrimSpace(fmt.Sprint(value)) if organizationRef == "" {
organizationRef = firstDetailValue(details, "organization_ref", "organizationRef", "org_ref")
}
if accountCode == "" {
accountCode = firstDetailValue(details, "account_code", "accountCode", "code", "ledger_account_code")
} }
} }
return &Account{ return &Account{
AccountID: accountID, AccountID: accountID,
AccountCode: accountCode,
Currency: strings.ToUpper(strings.TrimSpace(account.GetAsset())), Currency: strings.ToUpper(strings.TrimSpace(account.GetAsset())),
OrganizationRef: organizationRef, OrganizationRef: organizationRef,
}, nil }, nil
@@ -285,3 +292,21 @@ func normalizeEndpoint(raw string) (string, bool) {
return raw, false return raw, false
} }
} }
func firstDetailValue(values map[string]any, keys ...string) string {
if len(values) == 0 || len(keys) == 0 {
return ""
}
for _, key := range keys {
key = strings.TrimSpace(key)
if key == "" {
continue
}
if value, ok := values[key]; ok {
if text := strings.TrimSpace(fmt.Sprint(value)); text != "" {
return text
}
}
}
return ""
}

View File

@@ -120,6 +120,24 @@ func (a *botServiceAdapter) GetActiveRequestForAccount(ctx context.Context, ledg
return a.svc.GetActiveRequestForAccount(ctx, ledgerAccountID) return a.svc.GetActiveRequestForAccount(ctx, ledgerAccountID)
} }
func (a *botServiceAdapter) GetAccountProfile(ctx context.Context, ledgerAccountID string) (*bot.AccountProfile, error) {
if a == nil || a.svc == nil {
return nil, merrors.Internal("treasury service unavailable")
}
profile, err := a.svc.GetAccountProfile(ctx, ledgerAccountID)
if err != nil {
return nil, err
}
if profile == nil {
return nil, nil
}
return &bot.AccountProfile{
AccountID: strings.TrimSpace(profile.AccountID),
AccountCode: strings.TrimSpace(profile.AccountCode),
Currency: strings.TrimSpace(profile.Currency),
}, nil
}
func (a *botServiceAdapter) CreateRequest(ctx context.Context, input bot.CreateRequestInput) (*storagemodel.TreasuryRequest, error) { func (a *botServiceAdapter) CreateRequest(ctx context.Context, input bot.CreateRequestInput) (*storagemodel.TreasuryRequest, error) {
if a == nil || a.svc == nil { if a == nil || a.svc == nil {
return nil, merrors.Internal("treasury service unavailable") return nil, merrors.Internal("treasury service unavailable")

View File

@@ -145,7 +145,7 @@ func (s *Scheduler) hydrateTimers(ctx context.Context) {
return return
} }
for _, record := range scheduled { for _, record := range scheduled {
s.TrackScheduled(record) s.TrackScheduled(&record)
} }
} }
@@ -200,17 +200,53 @@ func (s *Scheduler) executeAndNotifyByID(ctx context.Context, requestID string)
s.logger.Warn("Failed to execute treasury request", zap.Error(err), zap.String("request_id", requestID)) s.logger.Warn("Failed to execute treasury request", zap.Error(err), zap.String("request_id", requestID))
return return
} }
if result == nil || result.Request == nil || s.notify == nil { if result == nil || result.Request == nil {
s.logger.Debug("Treasury execution produced no result", zap.String("request_id", requestID))
return
}
if s.notify == nil {
s.logger.Warn("Treasury execution notifier is unavailable", zap.String("request_id", requestID))
return return
} }
text := executionMessage(result) text := executionMessage(result)
if strings.TrimSpace(text) == "" { if strings.TrimSpace(text) == "" {
s.logger.Debug("Treasury execution result has no notification text",
zap.String("request_id", strings.TrimSpace(result.Request.RequestID)),
zap.String("status", strings.TrimSpace(string(result.Request.Status))))
return return
} }
if err := s.notify(ctx, strings.TrimSpace(result.Request.ChatID), text); err != nil { chatID := strings.TrimSpace(result.Request.ChatID)
s.logger.Warn("Failed to notify treasury execution result", zap.Error(err), zap.String("request_id", strings.TrimSpace(result.Request.RequestID))) if chatID == "" {
s.logger.Warn("Treasury execution notification skipped: empty chat_id",
zap.String("request_id", strings.TrimSpace(result.Request.RequestID)))
return
} }
s.logger.Info("Sending treasury execution notification",
zap.String("request_id", strings.TrimSpace(result.Request.RequestID)),
zap.String("chat_id", chatID),
zap.String("status", strings.TrimSpace(string(result.Request.Status))))
notifyCtx := context.Background()
if ctx != nil {
notifyCtx = ctx
}
notifyCtx, notifyCancel := context.WithTimeout(notifyCtx, 15*time.Second)
defer notifyCancel()
if err := s.notify(notifyCtx, chatID, text); err != nil {
s.logger.Warn("Failed to notify treasury execution result",
zap.Error(err),
zap.String("request_id", strings.TrimSpace(result.Request.RequestID)),
zap.String("chat_id", chatID),
zap.String("status", strings.TrimSpace(string(result.Request.Status))))
return
}
s.logger.Info("Treasury execution notification sent",
zap.String("request_id", strings.TrimSpace(result.Request.RequestID)),
zap.String("chat_id", chatID),
zap.String("status", strings.TrimSpace(string(result.Request.Status))))
} }
func executionMessage(result *ExecutionResult) string { func executionMessage(result *ExecutionResult) string {
@@ -237,7 +273,7 @@ func executionMessage(result *ExecutionResult) string {
} }
} }
return op + " completed.\n\n" + return op + " completed.\n\n" +
"Account: " + strings.TrimSpace(request.LedgerAccountID) + "\n" + "Account: " + requestAccountCode(request) + "\n" +
"Amount: " + sign + strings.TrimSpace(request.Amount) + " " + strings.TrimSpace(request.Currency) + "\n" + "Amount: " + sign + strings.TrimSpace(request.Amount) + " " + strings.TrimSpace(request.Currency) + "\n" +
"New balance: " + balanceAmount + " " + balanceCurrency + "\n\n" + "New balance: " + balanceAmount + " " + balanceCurrency + "\n\n" +
"Reference: " + strings.TrimSpace(request.RequestID) "Reference: " + strings.TrimSpace(request.RequestID)
@@ -250,7 +286,7 @@ func executionMessage(result *ExecutionResult) string {
reason = "Unknown error." reason = "Unknown error."
} }
return "Execution failed.\n\n" + return "Execution failed.\n\n" +
"Account: " + strings.TrimSpace(request.LedgerAccountID) + "\n" + "Account: " + requestAccountCode(request) + "\n" +
"Amount: " + strings.TrimSpace(request.Amount) + " " + strings.TrimSpace(request.Currency) + "\n" + "Amount: " + strings.TrimSpace(request.Amount) + " " + strings.TrimSpace(request.Currency) + "\n" +
"Status: FAILED\n\n" + "Status: FAILED\n\n" +
"Reason:\n" + reason + "\n\n" + "Reason:\n" + reason + "\n\n" +
@@ -259,3 +295,13 @@ func executionMessage(result *ExecutionResult) string {
return "" return ""
} }
} }
func requestAccountCode(request *storagemodel.TreasuryRequest) string {
if request == nil {
return ""
}
if code := strings.TrimSpace(request.LedgerAccountCode); code != "" {
return code
}
return strings.TrimSpace(request.LedgerAccountID)
}

View File

@@ -27,6 +27,12 @@ type CreateRequestInput struct {
Amount string Amount string
} }
type AccountProfile struct {
AccountID string
AccountCode string
Currency string
}
type ExecutionResult struct { type ExecutionResult struct {
Request *storagemodel.TreasuryRequest Request *storagemodel.TreasuryRequest
NewBalance *ledger.Balance NewBalance *ledger.Balance
@@ -103,6 +109,29 @@ func (s *Service) GetRequest(ctx context.Context, requestID string) (*storagemod
return s.repo.FindByRequestID(ctx, requestID) return s.repo.FindByRequestID(ctx, requestID)
} }
func (s *Service) GetAccountProfile(ctx context.Context, ledgerAccountID string) (*AccountProfile, error) {
if s == nil || s.ledger == nil {
return nil, merrors.Internal("treasury service unavailable")
}
ledgerAccountID = strings.TrimSpace(ledgerAccountID)
if ledgerAccountID == "" {
return nil, merrors.InvalidArgument("ledger_account_id is required", "ledger_account_id")
}
account, err := s.ledger.GetAccount(ctx, ledgerAccountID)
if err != nil {
return nil, err
}
if account == nil {
return nil, merrors.NoData("ledger account not found")
}
return &AccountProfile{
AccountID: ledgerAccountID,
AccountCode: resolveAccountCode(account, ledgerAccountID),
Currency: strings.ToUpper(strings.TrimSpace(account.Currency)),
}, nil
}
func (s *Service) CreateRequest(ctx context.Context, input CreateRequestInput) (*storagemodel.TreasuryRequest, error) { func (s *Service) CreateRequest(ctx context.Context, input CreateRequestInput) (*storagemodel.TreasuryRequest, error) {
if s == nil || s.repo == nil || s.ledger == nil || s.validator == nil { if s == nil || s.repo == nil || s.ledger == nil || s.validator == nil {
return nil, merrors.Internal("treasury service unavailable") return nil, merrors.Internal("treasury service unavailable")
@@ -160,6 +189,7 @@ func (s *Service) CreateRequest(ctx context.Context, input CreateRequestInput) (
OperationType: input.OperationType, OperationType: input.OperationType,
TelegramUserID: input.TelegramUserID, TelegramUserID: input.TelegramUserID,
LedgerAccountID: input.LedgerAccountID, LedgerAccountID: input.LedgerAccountID,
LedgerAccountCode: resolveAccountCode(account, input.LedgerAccountID),
OrganizationRef: account.OrganizationRef, OrganizationRef: account.OrganizationRef,
ChatID: input.ChatID, ChatID: input.ChatID,
Amount: normalizedAmount, Amount: normalizedAmount,
@@ -364,14 +394,14 @@ func (s *Service) executeClaimed(ctx context.Context, record *storagemodel.Treas
}, nil }, nil
} }
func (s *Service) DueRequests(ctx context.Context, statuses []storagemodel.TreasuryRequestStatus, now time.Time, limit int64) ([]*storagemodel.TreasuryRequest, error) { func (s *Service) DueRequests(ctx context.Context, statuses []storagemodel.TreasuryRequestStatus, now time.Time, limit int64) ([]storagemodel.TreasuryRequest, error) {
if s == nil || s.repo == nil { if s == nil || s.repo == nil {
return nil, merrors.Internal("treasury service unavailable") return nil, merrors.Internal("treasury service unavailable")
} }
return s.repo.FindDueByStatus(ctx, statuses, now, limit) return s.repo.FindDueByStatus(ctx, statuses, now, limit)
} }
func (s *Service) ScheduledRequests(ctx context.Context, limit int64) ([]*storagemodel.TreasuryRequest, error) { func (s *Service) ScheduledRequests(ctx context.Context, limit int64) ([]storagemodel.TreasuryRequest, error) {
if s == nil || s.repo == nil { if s == nil || s.repo == nil {
return nil, merrors.Internal("treasury service unavailable") return nil, merrors.Internal("treasury service unavailable")
} }
@@ -395,10 +425,14 @@ func (s *Service) logRequest(record *storagemodel.TreasuryRequest, status string
zap.String("request_id", strings.TrimSpace(record.RequestID)), zap.String("request_id", strings.TrimSpace(record.RequestID)),
zap.String("telegram_user_id", strings.TrimSpace(record.TelegramUserID)), zap.String("telegram_user_id", strings.TrimSpace(record.TelegramUserID)),
zap.String("ledger_account_id", strings.TrimSpace(record.LedgerAccountID)), zap.String("ledger_account_id", strings.TrimSpace(record.LedgerAccountID)),
zap.String("ledger_account_code", strings.TrimSpace(record.LedgerAccountCode)),
zap.String("chat_id", strings.TrimSpace(record.ChatID)),
zap.String("operation_type", strings.TrimSpace(string(record.OperationType))), zap.String("operation_type", strings.TrimSpace(string(record.OperationType))),
zap.String("amount", strings.TrimSpace(record.Amount)), zap.String("amount", strings.TrimSpace(record.Amount)),
zap.String("currency", strings.TrimSpace(record.Currency)), zap.String("currency", strings.TrimSpace(record.Currency)),
zap.String("status", status), zap.String("status", status),
zap.String("ledger_reference", strings.TrimSpace(record.LedgerReference)),
zap.String("error_message", strings.TrimSpace(record.ErrorMessage)),
} }
if err != nil { if err != nil {
fields = append(fields, zap.Error(err)) fields = append(fields, zap.Error(err))
@@ -409,3 +443,15 @@ func (s *Service) logRequest(record *storagemodel.TreasuryRequest, status string
func newRequestID() string { func newRequestID() string {
return "TGSETTLE-" + strings.ToUpper(bson.NewObjectID().Hex()[:8]) return "TGSETTLE-" + strings.ToUpper(bson.NewObjectID().Hex()[:8])
} }
func resolveAccountCode(account *ledger.Account, fallbackAccountID string) string {
if account != nil {
if code := strings.TrimSpace(account.AccountCode); code != "" {
return code
}
if code := strings.TrimSpace(account.AccountID); code != "" {
return code
}
}
return strings.TrimSpace(fallbackAccountID)
}

View File

@@ -143,9 +143,6 @@ func (v *Validator) ValidateDailyLimit(ctx context.Context, ledgerAccountID stri
} }
total := new(big.Rat) total := new(big.Rat)
for _, record := range records { for _, record := range records {
if record == nil {
continue
}
next, err := parseAmountRat(record.Amount) next, err := parseAmountRat(record.Amount)
if err != nil { if err != nil {
return merrors.Internal("treasury request amount is invalid") return merrors.Internal("treasury request amount is invalid")

View File

@@ -31,6 +31,7 @@ type TreasuryRequest struct {
OperationType TreasuryOperationType `bson:"operationType,omitempty" json:"operation_type,omitempty"` OperationType TreasuryOperationType `bson:"operationType,omitempty" json:"operation_type,omitempty"`
TelegramUserID string `bson:"telegramUserId,omitempty" json:"telegram_user_id,omitempty"` TelegramUserID string `bson:"telegramUserId,omitempty" json:"telegram_user_id,omitempty"`
LedgerAccountID string `bson:"ledgerAccountId,omitempty" json:"ledger_account_id,omitempty"` LedgerAccountID string `bson:"ledgerAccountId,omitempty" json:"ledger_account_id,omitempty"`
LedgerAccountCode string `bson:"ledgerAccountCode,omitempty" json:"ledger_account_code,omitempty"`
OrganizationRef string `bson:"organizationRef,omitempty" json:"organization_ref,omitempty"` OrganizationRef string `bson:"organizationRef,omitempty" json:"organization_ref,omitempty"`
ChatID string `bson:"chatId,omitempty" json:"chat_id,omitempty"` ChatID string `bson:"chatId,omitempty" json:"chat_id,omitempty"`
Amount string `bson:"amount,omitempty" json:"amount,omitempty"` Amount string `bson:"amount,omitempty" json:"amount,omitempty"`

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"strings" "strings"
"time"
"github.com/tech/sendico/gateway/tgsettle/storage" "github.com/tech/sendico/gateway/tgsettle/storage"
"github.com/tech/sendico/gateway/tgsettle/storage/model" "github.com/tech/sendico/gateway/tgsettle/storage/model"
@@ -12,7 +11,6 @@ import (
ri "github.com/tech/sendico/pkg/db/repository/index" ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger" "github.com/tech/sendico/pkg/mlogger"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -120,31 +118,26 @@ func (p *Payments) Upsert(ctx context.Context, record *model.PaymentRecord) erro
if record.IntentRef == "" { if record.IntentRef == "" {
return merrors.InvalidArgument("intention reference key is required", "intent_ref") return merrors.InvalidArgument("intention reference key is required", "intent_ref")
} }
now := time.Now()
if record.CreatedAt.IsZero() {
record.CreatedAt = now
}
record.UpdatedAt = now
record.ID = bson.NilObjectID
filter := repository.Filter(fieldIdempotencyKey, record.IdempotencyKey) filter := repository.Filter(fieldIdempotencyKey, record.IdempotencyKey)
existing := &model.PaymentRecord{} err := p.repo.Insert(ctx, record, filter)
err := p.repo.FindOneByFilter(ctx, filter, existing)
switch {
case err == nil:
record.ID = existing.ID
err = p.repo.Update(ctx, record)
case errors.Is(err, merrors.ErrNoData):
record.ID = bson.NilObjectID
err = p.repo.Insert(ctx, record, filter)
if errors.Is(err, merrors.ErrDataConflict) { if errors.Is(err, merrors.ErrDataConflict) {
if findErr := p.repo.FindOneByFilter(ctx, filter, existing); findErr != nil { patch := repository.Patch().
err = findErr Set(repository.Field(fieldOperationRef), record.OperationRef).
break Set(repository.Field("paymentIntentId"), record.PaymentIntentID).
} Set(repository.Field("quoteRef"), record.QuoteRef).
record.ID = existing.ID Set(repository.Field("intentRef"), record.IntentRef).
err = p.repo.Update(ctx, record) Set(repository.Field("paymentRef"), record.PaymentRef).
} Set(repository.Field("outgoingLeg"), record.OutgoingLeg).
Set(repository.Field("targetChatId"), record.TargetChatID).
Set(repository.Field("requestedMoney"), record.RequestedMoney).
Set(repository.Field("executedMoney"), record.ExecutedMoney).
Set(repository.Field("status"), record.Status).
Set(repository.Field("failureReason"), record.FailureReason).
Set(repository.Field("executedAt"), record.ExecutedAt).
Set(repository.Field("expiresAt"), record.ExpiresAt).
Set(repository.Field("expiredAt"), record.ExpiredAt)
_, err = p.repo.PatchMany(ctx, filter, patch)
} }
if err != nil { if err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {

View File

@@ -13,7 +13,7 @@ import (
ri "github.com/tech/sendico/pkg/db/repository/index" ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger" "github.com/tech/sendico/pkg/mlogger"
"go.mongodb.org/mongo-driver/v2/bson" mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -86,34 +86,19 @@ func (p *PendingConfirmations) Upsert(ctx context.Context, record *model.Pending
return merrors.InvalidArgument("expires_at is required", "expires_at") return merrors.InvalidArgument("expires_at is required", "expires_at")
} }
now := time.Now()
createdAt := record.CreatedAt
if createdAt.IsZero() {
createdAt = now
}
record.UpdatedAt = now
record.CreatedAt = createdAt
filter := repository.Filter(fieldPendingRequestID, record.RequestID) filter := repository.Filter(fieldPendingRequestID, record.RequestID)
existing := &model.PendingConfirmation{} err := p.repo.Insert(ctx, record, filter)
err := p.repo.FindOneByFilter(ctx, filter, existing)
switch {
case err == nil:
record.ID = existing.ID
record.CreatedAt = existing.CreatedAt
err = p.repo.Update(ctx, record)
case errors.Is(err, merrors.ErrNoData):
record.ID = bson.NilObjectID
err = p.repo.Insert(ctx, record, filter)
if errors.Is(err, merrors.ErrDataConflict) { if errors.Is(err, merrors.ErrDataConflict) {
if findErr := p.repo.FindOneByFilter(ctx, filter, existing); findErr != nil { patch := repository.Patch().
err = findErr Set(repository.Field(fieldPendingMessageID), record.MessageID).
break Set(repository.Field("targetChatId"), record.TargetChatID).
} Set(repository.Field("acceptedUserIds"), record.AcceptedUserIDs).
record.ID = existing.ID Set(repository.Field("requestedMoney"), record.RequestedMoney).
record.CreatedAt = existing.CreatedAt Set(repository.Field("sourceService"), record.SourceService).
err = p.repo.Update(ctx, record) Set(repository.Field("rail"), record.Rail).
} Set(repository.Field("clarified"), record.Clarified).
Set(repository.Field(fieldPendingExpiresAt), record.ExpiresAt)
_, err = p.repo.PatchMany(ctx, filter, patch)
} }
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
p.logger.Warn("Failed to upsert pending confirmation", zap.Error(err), zap.String("request_id", record.RequestID)) p.logger.Warn("Failed to upsert pending confirmation", zap.Error(err), zap.String("request_id", record.RequestID))
@@ -201,7 +186,7 @@ func (p *PendingConfirmations) DeleteByRequestID(ctx context.Context, requestID
return p.repo.DeleteMany(ctx, repository.Filter(fieldPendingRequestID, requestID)) return p.repo.DeleteMany(ctx, repository.Filter(fieldPendingRequestID, requestID))
} }
func (p *PendingConfirmations) ListExpired(ctx context.Context, now time.Time, limit int64) ([]*model.PendingConfirmation, error) { func (p *PendingConfirmations) ListExpired(ctx context.Context, now time.Time, limit int64) ([]model.PendingConfirmation, error) {
if limit <= 0 { if limit <= 0 {
limit = 100 limit = 100
} }
@@ -210,19 +195,11 @@ func (p *PendingConfirmations) ListExpired(ctx context.Context, now time.Time, l
Sort(repository.Field(fieldPendingExpiresAt), true). Sort(repository.Field(fieldPendingExpiresAt), true).
Limit(&limit) Limit(&limit)
result := make([]*model.PendingConfirmation, 0) items, err := mutil.GetObjects[model.PendingConfirmation](ctx, p.logger, query, nil, p.repo)
err := p.repo.FindManyByFilter(ctx, query, func(cur *mongo.Cursor) error {
next := &model.PendingConfirmation{}
if err := cur.Decode(next); err != nil {
return err
}
result = append(result, next)
return nil
})
if err != nil && !errors.Is(err, merrors.ErrNoData) { if err != nil && !errors.Is(err, merrors.ErrNoData) {
return nil, err return nil, err
} }
return result, nil return items, nil
} }
var _ storage.PendingConfirmationsStore = (*PendingConfirmations)(nil) var _ storage.PendingConfirmationsStore = (*PendingConfirmations)(nil)

View File

@@ -12,7 +12,6 @@ import (
ri "github.com/tech/sendico/pkg/db/repository/index" ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger" "github.com/tech/sendico/pkg/mlogger"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -67,24 +66,14 @@ func (t *TelegramConfirmations) Upsert(ctx context.Context, record *model.Telegr
record.ReceivedAt = time.Now() record.ReceivedAt = time.Now()
} }
filter := repository.Filter(fieldRequestID, record.RequestID) filter := repository.Filter(fieldRequestID, record.RequestID)
existing := &model.TelegramConfirmation{} err := t.repo.Insert(ctx, record, filter)
err := t.repo.FindOneByFilter(ctx, filter, existing)
switch {
case err == nil:
record.ID = existing.ID
err = t.repo.Update(ctx, record)
case errors.Is(err, merrors.ErrNoData):
record.ID = bson.NilObjectID
err = t.repo.Insert(ctx, record, filter)
if errors.Is(err, merrors.ErrDataConflict) { if errors.Is(err, merrors.ErrDataConflict) {
if findErr := t.repo.FindOneByFilter(ctx, filter, existing); findErr != nil { patch := repository.Patch().
err = findErr Set(repository.Field("paymentIntentId"), record.PaymentIntentID).
break Set(repository.Field("quoteRef"), record.QuoteRef).
} Set(repository.Field("rawReply"), record.RawReply).
record.ID = existing.ID Set(repository.Field("receivedAt"), record.ReceivedAt)
err = t.repo.Update(ctx, record) _, err = t.repo.PatchMany(ctx, filter, patch)
}
} }
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
fields := []zap.Field{zap.String("request_id", record.RequestID)} fields := []zap.Field{zap.String("request_id", record.RequestID)}

View File

@@ -13,7 +13,7 @@ import (
ri "github.com/tech/sendico/pkg/db/repository/index" ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/mlogger" "github.com/tech/sendico/pkg/mlogger"
"go.mongodb.org/mongo-driver/v2/bson" mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -104,6 +104,7 @@ func (t *TreasuryRequests) Create(ctx context.Context, record *model.TreasuryReq
record.RequestID = strings.TrimSpace(record.RequestID) record.RequestID = strings.TrimSpace(record.RequestID)
record.TelegramUserID = strings.TrimSpace(record.TelegramUserID) record.TelegramUserID = strings.TrimSpace(record.TelegramUserID)
record.LedgerAccountID = strings.TrimSpace(record.LedgerAccountID) record.LedgerAccountID = strings.TrimSpace(record.LedgerAccountID)
record.LedgerAccountCode = strings.TrimSpace(record.LedgerAccountCode)
record.OrganizationRef = strings.TrimSpace(record.OrganizationRef) record.OrganizationRef = strings.TrimSpace(record.OrganizationRef)
record.ChatID = strings.TrimSpace(record.ChatID) record.ChatID = strings.TrimSpace(record.ChatID)
record.Amount = strings.TrimSpace(record.Amount) record.Amount = strings.TrimSpace(record.Amount)
@@ -134,20 +135,24 @@ func (t *TreasuryRequests) Create(ctx context.Context, record *model.TreasuryReq
return merrors.InvalidArgument("status is required", "status") return merrors.InvalidArgument("status is required", "status")
} }
now := time.Now()
if record.CreatedAt.IsZero() {
record.CreatedAt = now
}
record.UpdatedAt = now
record.ID = bson.NilObjectID
err := t.repo.Insert(ctx, record, repository.Filter(fieldTreasuryRequestID, record.RequestID)) err := t.repo.Insert(ctx, record, repository.Filter(fieldTreasuryRequestID, record.RequestID))
if errors.Is(err, merrors.ErrDataConflict) { if errors.Is(err, merrors.ErrDataConflict) {
return storage.ErrDuplicate return storage.ErrDuplicate
} }
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
t.logger.Warn("Failed to create treasury request", zap.Error(err), zap.String("request_id", record.RequestID)) t.logger.Warn("Failed to create treasury request", zap.Error(err), zap.String("request_id", record.RequestID))
return err
} }
t.logger.Info("Treasury request created",
zap.String("request_id", record.RequestID),
zap.String("telegram_user_id", record.TelegramUserID),
zap.String("chat_id", record.ChatID),
zap.String("ledger_account_id", record.LedgerAccountID),
zap.String("ledger_account_code", record.LedgerAccountCode),
zap.String("operation_type", strings.TrimSpace(string(record.OperationType))),
zap.String("status", strings.TrimSpace(string(record.Status))),
zap.String("amount", record.Amount),
zap.String("currency", record.Currency))
return err return err
} }
@@ -159,11 +164,17 @@ func (t *TreasuryRequests) FindByRequestID(ctx context.Context, requestID string
var result model.TreasuryRequest var result model.TreasuryRequest
err := t.repo.FindOneByFilter(ctx, repository.Filter(fieldTreasuryRequestID, requestID), &result) err := t.repo.FindOneByFilter(ctx, repository.Filter(fieldTreasuryRequestID, requestID), &result)
if errors.Is(err, merrors.ErrNoData) { if errors.Is(err, merrors.ErrNoData) {
t.logger.Debug("Treasury request not found", zap.String("request_id", requestID))
return nil, nil return nil, nil
} }
if err != nil { if err != nil {
t.logger.Warn("Failed to load treasury request", zap.Error(err), zap.String("request_id", requestID))
return nil, err return nil, err
} }
t.logger.Debug("Treasury request loaded",
zap.String("request_id", requestID),
zap.String("status", strings.TrimSpace(string(result.Status))),
zap.String("ledger_account_id", strings.TrimSpace(result.LedgerAccountID)))
return &result, nil return &result, nil
} }
@@ -178,15 +189,21 @@ func (t *TreasuryRequests) FindActiveByLedgerAccountID(ctx context.Context, ledg
Filter(repository.Field(fieldTreasuryActive), true) Filter(repository.Field(fieldTreasuryActive), true)
err := t.repo.FindOneByFilter(ctx, query, &result) err := t.repo.FindOneByFilter(ctx, query, &result)
if errors.Is(err, merrors.ErrNoData) { if errors.Is(err, merrors.ErrNoData) {
t.logger.Debug("Active treasury request not found", zap.String("ledger_account_id", ledgerAccountID))
return nil, nil return nil, nil
} }
if err != nil { if err != nil {
t.logger.Warn("Failed to load active treasury request", zap.Error(err), zap.String("ledger_account_id", ledgerAccountID))
return nil, err return nil, err
} }
t.logger.Debug("Active treasury request loaded",
zap.String("request_id", strings.TrimSpace(result.RequestID)),
zap.String("ledger_account_id", ledgerAccountID),
zap.String("status", strings.TrimSpace(string(result.Status))))
return &result, nil return &result, nil
} }
func (t *TreasuryRequests) FindDueByStatus(ctx context.Context, statuses []model.TreasuryRequestStatus, now time.Time, limit int64) ([]*model.TreasuryRequest, error) { func (t *TreasuryRequests) FindDueByStatus(ctx context.Context, statuses []model.TreasuryRequestStatus, now time.Time, limit int64) ([]model.TreasuryRequest, error) {
if len(statuses) == 0 { if len(statuses) == 0 {
return nil, nil return nil, nil
} }
@@ -210,18 +227,20 @@ func (t *TreasuryRequests) FindDueByStatus(ctx context.Context, statuses []model
Sort(repository.Field(fieldTreasuryScheduledAt), true). Sort(repository.Field(fieldTreasuryScheduledAt), true).
Limit(&limit) Limit(&limit)
result := make([]*model.TreasuryRequest, 0) result, err := mutil.GetObjects[model.TreasuryRequest](ctx, t.logger, query, nil, t.repo)
err := t.repo.FindManyByFilter(ctx, query, func(cur *mongo.Cursor) error {
next := &model.TreasuryRequest{}
if err := cur.Decode(next); err != nil {
return err
}
result = append(result, next)
return nil
})
if err != nil && !errors.Is(err, merrors.ErrNoData) { if err != nil && !errors.Is(err, merrors.ErrNoData) {
t.logger.Warn("Failed to list due treasury requests",
zap.Error(err),
zap.Any("statuses", statusValues),
zap.Time("scheduled_before", now),
zap.Int64("limit", limit))
return nil, err return nil, err
} }
t.logger.Debug("Due treasury requests loaded",
zap.Any("statuses", statusValues),
zap.Time("scheduled_before", now),
zap.Int64("limit", limit),
zap.Int("count", len(result)))
return result, nil return result, nil
} }
@@ -231,14 +250,19 @@ func (t *TreasuryRequests) ClaimScheduled(ctx context.Context, requestID string)
return false, merrors.InvalidArgument("request_id is required", "request_id") return false, merrors.InvalidArgument("request_id is required", "request_id")
} }
patch := repository.Patch(). patch := repository.Patch().
Set(repository.Field(fieldTreasuryStatus), string(model.TreasuryRequestStatusConfirmed)). Set(repository.Field(fieldTreasuryStatus), string(model.TreasuryRequestStatusConfirmed))
Set(repository.Field("updatedAt"), time.Now())
updated, err := t.repo.PatchMany(ctx, repository.Filter(fieldTreasuryRequestID, requestID).And( updated, err := t.repo.PatchMany(ctx, repository.Filter(fieldTreasuryRequestID, requestID).And(
repository.Filter(fieldTreasuryStatus, string(model.TreasuryRequestStatusScheduled)), repository.Filter(fieldTreasuryStatus, string(model.TreasuryRequestStatusScheduled)),
), patch) ), patch)
if err != nil { if err != nil {
t.logger.Warn("Failed to claim scheduled treasury request", zap.Error(err), zap.String("request_id", requestID))
return false, err return false, err
} }
if updated > 0 {
t.logger.Info("Scheduled treasury request claimed", zap.String("request_id", requestID))
} else {
t.logger.Debug("Scheduled treasury request claim skipped", zap.String("request_id", requestID))
}
return updated > 0, nil return updated > 0, nil
} }
@@ -247,6 +271,16 @@ func (t *TreasuryRequests) Update(ctx context.Context, record *model.TreasuryReq
return merrors.InvalidArgument("treasury request is nil", "record") return merrors.InvalidArgument("treasury request is nil", "record")
} }
record.RequestID = strings.TrimSpace(record.RequestID) record.RequestID = strings.TrimSpace(record.RequestID)
record.TelegramUserID = strings.TrimSpace(record.TelegramUserID)
record.LedgerAccountID = strings.TrimSpace(record.LedgerAccountID)
record.LedgerAccountCode = strings.TrimSpace(record.LedgerAccountCode)
record.OrganizationRef = strings.TrimSpace(record.OrganizationRef)
record.ChatID = strings.TrimSpace(record.ChatID)
record.Amount = strings.TrimSpace(record.Amount)
record.Currency = strings.ToUpper(strings.TrimSpace(record.Currency))
record.IdempotencyKey = strings.TrimSpace(record.IdempotencyKey)
record.LedgerReference = strings.TrimSpace(record.LedgerReference)
record.ErrorMessage = strings.TrimSpace(record.ErrorMessage)
if record.RequestID == "" { if record.RequestID == "" {
return merrors.InvalidArgument("request_id is required", "request_id") return merrors.InvalidArgument("request_id is required", "request_id")
} }
@@ -257,21 +291,46 @@ func (t *TreasuryRequests) Update(ctx context.Context, record *model.TreasuryReq
if existing == nil { if existing == nil {
return merrors.NoData("treasury request not found") return merrors.NoData("treasury request not found")
} }
record.ID = existing.ID
if record.CreatedAt.IsZero() { patch := repository.Patch().
record.CreatedAt = existing.CreatedAt Set(repository.Field("operationType"), record.OperationType).
} Set(repository.Field("telegramUserId"), record.TelegramUserID).
record.UpdatedAt = time.Now() Set(repository.Field("ledgerAccountId"), record.LedgerAccountID).
if err := t.repo.Update(ctx, record); err != nil { Set(repository.Field("ledgerAccountCode"), record.LedgerAccountCode).
Set(repository.Field("organizationRef"), record.OrganizationRef).
Set(repository.Field("chatId"), record.ChatID).
Set(repository.Field("amount"), record.Amount).
Set(repository.Field("currency"), record.Currency).
Set(repository.Field(fieldTreasuryStatus), record.Status).
Set(repository.Field("confirmedAt"), record.ConfirmedAt).
Set(repository.Field("scheduledAt"), record.ScheduledAt).
Set(repository.Field("executedAt"), record.ExecutedAt).
Set(repository.Field("cancelledAt"), record.CancelledAt).
Set(repository.Field(fieldTreasuryIdempotencyKey), record.IdempotencyKey).
Set(repository.Field("ledgerReference"), record.LedgerReference).
Set(repository.Field("errorMessage"), record.ErrorMessage).
Set(repository.Field(fieldTreasuryActive), record.Active)
if _, err := t.repo.PatchMany(ctx, repository.Filter(fieldTreasuryRequestID, record.RequestID), patch); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
t.logger.Warn("Failed to update treasury request", zap.Error(err), zap.String("request_id", record.RequestID)) t.logger.Warn("Failed to update treasury request", zap.Error(err), zap.String("request_id", record.RequestID))
} }
return err return err
} }
t.logger.Info("Treasury request updated",
zap.String("request_id", record.RequestID),
zap.String("telegram_user_id", strings.TrimSpace(record.TelegramUserID)),
zap.String("chat_id", strings.TrimSpace(record.ChatID)),
zap.String("ledger_account_id", strings.TrimSpace(record.LedgerAccountID)),
zap.String("ledger_account_code", strings.TrimSpace(record.LedgerAccountCode)),
zap.String("operation_type", strings.TrimSpace(string(record.OperationType))),
zap.String("status", strings.TrimSpace(string(record.Status))),
zap.String("amount", strings.TrimSpace(record.Amount)),
zap.String("currency", strings.TrimSpace(record.Currency)),
zap.String("error_message", strings.TrimSpace(record.ErrorMessage)))
return nil return nil
} }
func (t *TreasuryRequests) ListByAccountAndStatuses(ctx context.Context, ledgerAccountID string, statuses []model.TreasuryRequestStatus, dayStart, dayEnd time.Time) ([]*model.TreasuryRequest, error) { func (t *TreasuryRequests) ListByAccountAndStatuses(ctx context.Context, ledgerAccountID string, statuses []model.TreasuryRequestStatus, dayStart, dayEnd time.Time) ([]model.TreasuryRequest, error) {
ledgerAccountID = strings.TrimSpace(ledgerAccountID) ledgerAccountID = strings.TrimSpace(ledgerAccountID)
if ledgerAccountID == "" { if ledgerAccountID == "" {
return nil, merrors.InvalidArgument("ledger_account_id is required", "ledger_account_id") return nil, merrors.InvalidArgument("ledger_account_id is required", "ledger_account_id")
@@ -293,18 +352,22 @@ func (t *TreasuryRequests) ListByAccountAndStatuses(ctx context.Context, ledgerA
Comparison(repository.Field(fieldTreasuryCreatedAt), builder.Gte, dayStart). Comparison(repository.Field(fieldTreasuryCreatedAt), builder.Gte, dayStart).
Comparison(repository.Field(fieldTreasuryCreatedAt), builder.Lt, dayEnd) Comparison(repository.Field(fieldTreasuryCreatedAt), builder.Lt, dayEnd)
result := make([]*model.TreasuryRequest, 0) result, err := mutil.GetObjects[model.TreasuryRequest](ctx, t.logger, query, nil, t.repo)
err := t.repo.FindManyByFilter(ctx, query, func(cur *mongo.Cursor) error {
next := &model.TreasuryRequest{}
if err := cur.Decode(next); err != nil {
return err
}
result = append(result, next)
return nil
})
if err != nil && !errors.Is(err, merrors.ErrNoData) { if err != nil && !errors.Is(err, merrors.ErrNoData) {
t.logger.Warn("Failed to list treasury requests by account and statuses",
zap.Error(err),
zap.String("ledger_account_id", ledgerAccountID),
zap.Any("statuses", statusValues),
zap.Time("day_start", dayStart),
zap.Time("day_end", dayEnd))
return nil, err return nil, err
} }
t.logger.Debug("Treasury requests loaded by account and statuses",
zap.String("ledger_account_id", ledgerAccountID),
zap.Any("statuses", statusValues),
zap.Time("day_start", dayStart),
zap.Time("day_end", dayEnd),
zap.Int("count", len(result)))
return result, nil return result, nil
} }

View File

@@ -34,15 +34,15 @@ type PendingConfirmationsStore interface {
MarkClarified(ctx context.Context, requestID string) error MarkClarified(ctx context.Context, requestID string) error
AttachMessage(ctx context.Context, requestID string, messageID string) error AttachMessage(ctx context.Context, requestID string, messageID string) error
DeleteByRequestID(ctx context.Context, requestID string) error DeleteByRequestID(ctx context.Context, requestID string) error
ListExpired(ctx context.Context, now time.Time, limit int64) ([]*model.PendingConfirmation, error) ListExpired(ctx context.Context, now time.Time, limit int64) ([]model.PendingConfirmation, error)
} }
type TreasuryRequestsStore interface { type TreasuryRequestsStore interface {
Create(ctx context.Context, record *model.TreasuryRequest) error Create(ctx context.Context, record *model.TreasuryRequest) error
FindByRequestID(ctx context.Context, requestID string) (*model.TreasuryRequest, error) FindByRequestID(ctx context.Context, requestID string) (*model.TreasuryRequest, error)
FindActiveByLedgerAccountID(ctx context.Context, ledgerAccountID string) (*model.TreasuryRequest, error) FindActiveByLedgerAccountID(ctx context.Context, ledgerAccountID string) (*model.TreasuryRequest, error)
FindDueByStatus(ctx context.Context, statuses []model.TreasuryRequestStatus, now time.Time, limit int64) ([]*model.TreasuryRequest, error) FindDueByStatus(ctx context.Context, statuses []model.TreasuryRequestStatus, now time.Time, limit int64) ([]model.TreasuryRequest, error)
ClaimScheduled(ctx context.Context, requestID string) (bool, error) ClaimScheduled(ctx context.Context, requestID string) (bool, error)
Update(ctx context.Context, record *model.TreasuryRequest) error Update(ctx context.Context, record *model.TreasuryRequest) error
ListByAccountAndStatuses(ctx context.Context, ledgerAccountID string, statuses []model.TreasuryRequestStatus, dayStart, dayEnd time.Time) ([]*model.TreasuryRequest, error) ListByAccountAndStatuses(ctx context.Context, ledgerAccountID string, statuses []model.TreasuryRequestStatus, dayStart, dayEnd time.Time) ([]model.TreasuryRequest, error)
} }