Files
sendico/api/gateway/tgsettle/internal/service/treasury/scheduler.go
2026-03-05 13:24:41 +01:00

328 lines
8.9 KiB
Go

package treasury
import (
"context"
"strings"
"sync"
"time"
storagemodel "github.com/tech/sendico/gateway/tgsettle/storage/model"
"github.com/tech/sendico/pkg/mlogger"
"go.uber.org/zap"
)
// NotifyFunc delivers text to the chat identified by chatID.
// The Scheduler invokes it with the execution result message once a treasury
// request has been executed; a non-nil error is logged by the caller.
type NotifyFunc func(ctx context.Context, chatID string, text string) error
// Scheduler executes treasury requests when they become due. It keeps one
// in-memory timer per tracked scheduled request and, once started, runs a
// periodic sweep that executes due requests; execution results are reported
// through notify.
type Scheduler struct {
// logger receives warnings and diagnostics. NewScheduler stores it as given
// (after naming it), so it is nil when the caller passed nil.
logger mlogger.Logger
// service lists scheduled/due requests and executes them.
service *Service
// notify sends the execution result message to a chat. When nil, results are
// not delivered (executeAndNotifyByID logs a warning and returns).
notify NotifyFunc
// safetySweepInterval is the period of the background sweep started by Start.
safetySweepInterval time.Duration
// cancel stops the background sweep loop; it is non-nil once Start has run.
cancel context.CancelFunc
// wg tracks the background sweep goroutine so Shutdown can wait for it.
wg sync.WaitGroup
// timersMu guards timers.
timersMu sync.Mutex
// timers holds the pending timer of each tracked request, keyed by the
// trimmed request ID.
timers map[string]*time.Timer
}
// NewScheduler builds a Scheduler that is not yet running; call Start to
// begin processing.
//
// A non-nil logger is scoped under the "treasury_scheduler" name. A nil logger
// is stored as-is. NOTE(review): the logging calls elsewhere in this type do
// not guard against a nil logger — confirm callers always supply one.
// A non-positive safetySweepInterval falls back to 30 seconds.
func NewScheduler(logger mlogger.Logger, service *Service, notify NotifyFunc, safetySweepInterval time.Duration) *Scheduler {
	interval := safetySweepInterval
	if interval <= 0 {
		interval = 30 * time.Second
	}

	scheduler := &Scheduler{
		service:             service,
		notify:              notify,
		safetySweepInterval: interval,
		timers:              make(map[string]*time.Timer),
	}
	if logger != nil {
		scheduler.logger = logger.Named("treasury_scheduler")
	}
	return scheduler
}
// Start launches the scheduler. It is a no-op on a nil receiver, when no
// service is configured, or when the scheduler has already been started.
//
// Before returning it synchronously rebuilds the in-memory timers from
// storage and runs one sweep for overdue requests; afterwards a background
// goroutine repeats the sweep every safetySweepInterval until Shutdown
// cancels it.
func (s *Scheduler) Start() {
	if s == nil || s.service == nil {
		return
	}
	if s.cancel != nil {
		// Already started.
		return
	}

	runCtx, stop := context.WithCancel(context.Background())
	s.cancel = stop

	// Restore timers from persisted state, then catch up on anything overdue.
	s.hydrateTimers(runCtx)
	s.sweep(runCtx)

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()

		tick := time.NewTicker(s.safetySweepInterval)
		defer tick.Stop()

		for {
			select {
			case <-tick.C:
				s.sweep(runCtx)
			case <-runCtx.Done():
				return
			}
		}
	}()
}
// Shutdown stops the background sweep loop, waits for it to exit, and then
// stops and forgets every pending timer. It does nothing on a nil receiver or
// when Start has not been called.
func (s *Scheduler) Shutdown() {
	if s == nil || s.cancel == nil {
		return
	}

	s.cancel()
	s.wg.Wait()

	s.timersMu.Lock()
	defer s.timersMu.Unlock()
	for id, pending := range s.timers {
		if pending != nil {
			pending.Stop()
		}
		delete(s.timers, id)
	}
}
// TrackScheduled arms (or re-arms) the in-memory timer for a scheduled request.
//
// Records that are nil, carry a blank RequestID, or are not in the scheduled
// status are ignored. A zero ScheduledAt is treated as "now". A request that
// is already due is untracked and executed immediately on its own goroutine;
// otherwise a timer is registered that executes the request when it becomes
// due, replacing any timer previously registered under the same request ID.
func (s *Scheduler) TrackScheduled(record *storagemodel.TreasuryRequest) {
	if s == nil || s.service == nil || record == nil {
		return
	}
	requestID := strings.TrimSpace(record.RequestID)
	if requestID == "" {
		return
	}
	if record.Status != storagemodel.TreasuryRequestStatusScheduled {
		return
	}

	when := record.ScheduledAt
	if when.IsZero() {
		when = time.Now()
	}
	delay := time.Until(when)
	if delay <= 0 {
		s.Untrack(requestID)
		go s.executeAndNotifyByID(context.Background(), requestID)
		return
	}

	s.timersMu.Lock()
	defer s.timersMu.Unlock()

	if existing := s.timers[requestID]; existing != nil {
		existing.Stop()
	}

	// timer is assigned while timersMu is held and the callback only reads it
	// after acquiring timersMu, so the callback always observes the final value.
	var timer *time.Timer
	timer = time.AfterFunc(delay, func() {
		// A timer that has already fired cannot be stopped. If the request was
		// re-tracked or untracked between the fire and this callback taking the
		// lock, the map no longer points at this timer: leave the newer entry
		// alone and do not execute on behalf of a superseded schedule.
		s.timersMu.Lock()
		current := s.timers[requestID] == timer
		if current {
			delete(s.timers, requestID)
		}
		s.timersMu.Unlock()
		if !current {
			return
		}
		s.executeAndNotifyByID(context.Background(), requestID)
	})
	s.timers[requestID] = timer
}
// Untrack stops and removes the pending timer registered for requestID, if
// any. The ID is trimmed first; a blank ID or a nil receiver is a no-op.
func (s *Scheduler) Untrack(requestID string) {
	if s == nil {
		return
	}
	id := strings.TrimSpace(requestID)
	if id == "" {
		return
	}

	s.timersMu.Lock()
	defer s.timersMu.Unlock()

	if pending, tracked := s.timers[id]; tracked && pending != nil {
		pending.Stop()
	}
	delete(s.timers, id)
}
// hydrateTimers loads up to 1000 scheduled requests from the service and
// registers each of them through TrackScheduled. A listing failure is logged
// as a warning and leaves the timers untouched.
func (s *Scheduler) hydrateTimers(ctx context.Context) {
	if s == nil || s.service == nil {
		return
	}

	const hydrateLimit = 1000
	records, err := s.service.ScheduledRequests(ctx, hydrateLimit)
	if err != nil {
		s.logger.Warn("Failed to hydrate scheduled treasury requests", zap.Error(err))
		return
	}

	// TrackScheduled only reads the record, so pointing at the slice element
	// directly is equivalent to passing a copy.
	for i := range records {
		s.TrackScheduled(&records[i])
	}
}
// sweep is the safety pass that executes every request that is due as of now.
//
// It runs two independent passes, each limited to 100 requests: first the
// confirmed requests, then the scheduled ones (whose in-memory timer is
// dropped before execution so the timer does not fire for them as well).
// A listing failure in one pass is logged and does not prevent the other pass
// from running. The sweep stops as soon as ctx is cancelled so that a
// shutdown does not push the remaining requests into execution with a dead
// context.
func (s *Scheduler) sweep(ctx context.Context) {
	if s == nil || s.service == nil {
		return
	}

	cancelled := func() bool {
		return ctx != nil && ctx.Err() != nil
	}

	passes := []struct {
		status     storagemodel.TreasuryRequestStatus
		listFailed string
		untrack    bool
	}{
		{
			status:     storagemodel.TreasuryRequestStatusConfirmed,
			listFailed: "Failed to list confirmed treasury requests",
		},
		{
			status:     storagemodel.TreasuryRequestStatusScheduled,
			listFailed: "Failed to list scheduled treasury requests",
			untrack:    true,
		},
	}

	// Both passes share one cut-off so they see the same notion of "due".
	now := time.Now()
	for _, pass := range passes {
		if cancelled() {
			return
		}
		due, err := s.service.DueRequests(ctx, []storagemodel.TreasuryRequestStatus{pass.status}, now, 100)
		if err != nil {
			s.logger.Warn(pass.listFailed, zap.Error(err))
			continue
		}
		for _, request := range due {
			if cancelled() {
				return
			}
			requestID := strings.TrimSpace(request.RequestID)
			if pass.untrack {
				s.Untrack(requestID)
			}
			s.executeAndNotifyByID(ctx, requestID)
		}
	}
}
// executeAndNotifyByID executes the request identified by requestID and, when
// the outcome yields a message, sends it to the request's chat via notify.
//
// The execution is bounded by a 30 second timeout and the notification by a
// 15 second timeout, both derived from ctx (or from a background context when
// ctx is nil). Every failure along the way is logged and ends the call; the
// function reports nothing to its caller.
func (s *Scheduler) executeAndNotifyByID(ctx context.Context, requestID string) {
	if s == nil || s.service == nil {
		return
	}
	id := strings.TrimSpace(requestID)
	if id == "" {
		return
	}

	base := ctx
	if base == nil {
		base = context.Background()
	}

	execCtx, cancelExec := context.WithTimeout(base, 30*time.Second)
	defer cancelExec()

	result, err := s.service.ExecuteRequest(execCtx, id)
	switch {
	case err != nil:
		s.logger.Warn("Failed to execute treasury request", zap.Error(err), zap.String("request_id", id))
		return
	case result == nil || result.Request == nil:
		s.logger.Debug("Treasury execution produced no result", zap.String("request_id", id))
		return
	case s.notify == nil:
		s.logger.Warn("Treasury execution notifier is unavailable", zap.String("request_id", id))
		return
	}

	// From here on, log the identifiers as reported by the executed request.
	executedID := strings.TrimSpace(result.Request.RequestID)
	status := strings.TrimSpace(string(result.Request.Status))

	text := executionMessage(result)
	if strings.TrimSpace(text) == "" {
		s.logger.Debug("Treasury execution result has no notification text",
			zap.String("request_id", executedID),
			zap.String("status", status))
		return
	}

	chatID := strings.TrimSpace(result.Request.ChatID)
	if chatID == "" {
		s.logger.Warn("Treasury execution notification skipped: empty chat_id",
			zap.String("request_id", executedID))
		return
	}

	s.logger.Info("Sending treasury execution notification",
		zap.String("request_id", executedID),
		zap.String("chat_id", chatID),
		zap.String("status", status))

	notifyCtx, cancelNotify := context.WithTimeout(base, 15*time.Second)
	defer cancelNotify()

	if notifyErr := s.notify(notifyCtx, chatID, text); notifyErr != nil {
		s.logger.Warn("Failed to notify treasury execution result",
			zap.Error(notifyErr),
			zap.String("request_id", executedID),
			zap.String("chat_id", chatID),
			zap.String("status", status))
		return
	}

	s.logger.Info("Treasury execution notification sent",
		zap.String("request_id", executedID),
		zap.String("chat_id", chatID),
		zap.String("status", status))
}
// executionMessage renders the Markdown notification text for an execution
// result. It returns an empty string when result or its request is nil, or
// when the request status is neither executed nor failed.
//
// Executed requests produce a "Funding completed" / "Withdrawal completed"
// summary with the signed amount and the new balance ("unavailable" when the
// result carries no balance amount). Failed requests produce an "Execution
// failed" summary whose reason comes from the request's error message, then
// from the execution error, and finally defaults to "Unknown error.".
func executionMessage(result *ExecutionResult) string {
	if result == nil || result.Request == nil {
		return ""
	}
	req := result.Request
	var msg strings.Builder

	if req.Status == storagemodel.TreasuryRequestStatusExecuted {
		op, sign := "Funding", "+"
		if req.OperationType == storagemodel.TreasuryOperationWithdraw {
			op, sign = "Withdrawal", "-"
		}

		currency := strings.TrimSpace(req.Currency)
		balance := "unavailable"
		balanceCurrency := currency
		if newBalance := result.NewBalance; newBalance != nil {
			if amount := strings.TrimSpace(newBalance.Amount); amount != "" {
				balance = amount
			}
			// The balance's own currency wins over the request currency.
			if cur := strings.TrimSpace(newBalance.Currency); cur != "" {
				balanceCurrency = cur
			}
		}

		msg.WriteString("*" + op + " completed*\n\n")
		msg.WriteString("*Account:* " + markdownCode(requestAccountCode(req)) + "\n")
		msg.WriteString("*Amount:* " + markdownCode(sign+strings.TrimSpace(req.Amount)+" "+currency) + "\n")
		msg.WriteString("*New balance:* " + markdownCode(balance+" "+balanceCurrency) + "\n\n")
		msg.WriteString("*Reference:* " + markdownCode(strings.TrimSpace(req.RequestID)))
		return msg.String()
	}

	if req.Status == storagemodel.TreasuryRequestStatusFailed {
		reason := strings.TrimSpace(req.ErrorMessage)
		if reason == "" && result.ExecutionError != nil {
			reason = strings.TrimSpace(result.ExecutionError.Error())
		}
		if reason == "" {
			reason = "Unknown error."
		}

		msg.WriteString("*Execution failed*\n\n")
		msg.WriteString("*Account:* " + markdownCode(requestAccountCode(req)) + "\n")
		msg.WriteString("*Amount:* " + markdownCode(strings.TrimSpace(req.Amount)+" "+strings.TrimSpace(req.Currency)) + "\n")
		msg.WriteString("*Status:* " + markdownCode("FAILED") + "\n")
		msg.WriteString("*Reason:* " + markdownCode(compactForMarkdown(reason)) + "\n\n")
		msg.WriteString("*Request ID:* " + markdownCode(strings.TrimSpace(req.RequestID)))
		return msg.String()
	}

	return ""
}
// requestAccountCode returns the trimmed ledger account code of request,
// falling back to the trimmed ledger account ID when the code is blank.
// A nil request yields an empty string.
func requestAccountCode(request *storagemodel.TreasuryRequest) string {
	if request == nil {
		return ""
	}
	account := strings.TrimSpace(request.LedgerAccountCode)
	if account == "" {
		account = strings.TrimSpace(request.LedgerAccountID)
	}
	return account
}
// markdownCode wraps value in backticks so it renders as inline code.
// The value is trimmed first; a blank value becomes "N/A", and any backtick
// inside the value is replaced with an apostrophe so it cannot close the span.
func markdownCode(value string) string {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return "`N/A`"
	}
	return "`" + strings.ReplaceAll(trimmed, "`", "'") + "`"
}
// compactForMarkdown collapses value onto a single line: every run of
// whitespace (including CR, LF and tabs) becomes one space and leading and
// trailing whitespace is dropped. A blank value yields "Unknown error.".
func compactForMarkdown(value string) string {
	// strings.Fields already splits on all Unicode whitespace, so no separate
	// trimming or newline replacement is needed.
	words := strings.Fields(value)
	if len(words) == 0 {
		return "Unknown error."
	}
	return strings.Join(words, " ")
}