package treasury

import (
	"context"
	"strings"
	"sync"
	"time"

	storagemodel "github.com/tech/sendico/gateway/tgsettle/storage/model"
	"github.com/tech/sendico/pkg/mlogger"
	"go.uber.org/zap"
)

const (
	// schedulerDefaultSweepInterval is used when NewScheduler receives a
	// non-positive safety sweep interval.
	schedulerDefaultSweepInterval = 30 * time.Second
	// schedulerHydrateLimit caps how many scheduled requests Start loads into
	// in-memory timers.
	schedulerHydrateLimit = 1000
	// schedulerSweepBatchLimit caps how many due requests of each status a
	// single sweep pass processes.
	schedulerSweepBatchLimit = 100
	// schedulerExecuteTimeout bounds a single Service.ExecuteRequest call.
	schedulerExecuteTimeout = 30 * time.Second
	// schedulerNotifyTimeout bounds a single notification delivery.
	schedulerNotifyTimeout = 30 * time.Second
)

// NotifyFunc delivers text to the chat identified by chatID. The Scheduler
// uses it to report the outcome of executed treasury requests.
type NotifyFunc func(ctx context.Context, chatID string, text string) error

// Scheduler executes treasury requests when they become due.
//
// Requests in the scheduled status get an in-memory timer keyed by request ID
// (see TrackScheduled). A periodic safety sweep additionally executes due
// confirmed requests and due scheduled requests.
type Scheduler struct {
	logger              mlogger.Logger
	service             *Service
	notify              NotifyFunc
	safetySweepInterval time.Duration

	// wg tracks the sweep loop and every execution goroutine the scheduler
	// starts, so Shutdown can wait for them. Add is only called under
	// timersMu while stopped is false, which keeps it ordered before Wait.
	wg sync.WaitGroup

	// timersMu guards every field below.
	timersMu sync.Mutex
	cancel   context.CancelFunc
	stopped  bool
	timers   map[string]*time.Timer
	// inFlight holds request IDs currently being executed so that a timer,
	// an overdue goroutine and a sweep never run the same request concurrently.
	inFlight map[string]struct{}
}

// NewScheduler builds a Scheduler. A nil logger is tolerated (warnings are
// then dropped), and a non-positive safetySweepInterval falls back to
// schedulerDefaultSweepInterval. Call Start to begin processing.
func NewScheduler(logger mlogger.Logger, service *Service, notify NotifyFunc, safetySweepInterval time.Duration) *Scheduler {
	if logger != nil {
		logger = logger.Named("treasury_scheduler")
	}
	if safetySweepInterval <= 0 {
		safetySweepInterval = schedulerDefaultSweepInterval
	}
	return &Scheduler{
		logger:              logger,
		service:             service,
		notify:              notify,
		safetySweepInterval: safetySweepInterval,
		timers:              make(map[string]*time.Timer),
		inFlight:            make(map[string]struct{}),
	}
}

// Start rebuilds timers from storage, runs one sweep synchronously and then
// launches the periodic safety sweep. It is a no-op on a nil scheduler, when
// no service is configured, or when the scheduler was already started.
func (s *Scheduler) Start() {
	if s == nil || s.service == nil {
		return
	}
	s.timersMu.Lock()
	if s.cancel != nil {
		s.timersMu.Unlock()
		return
	}
	ctx, cancel := context.WithCancel(context.Background())
	s.cancel = cancel
	s.timersMu.Unlock()

	// Rebuild in-memory timers from DB on startup.
	s.hydrateTimers(ctx)
	// Safety pass for overdue items at startup.
	s.sweep(ctx)

	s.goTracked(func() { s.sweepLoop(ctx) })
}

// sweepLoop runs sweep every safetySweepInterval until ctx is cancelled.
func (s *Scheduler) sweepLoop(ctx context.Context) {
	interval := s.safetySweepInterval
	if interval <= 0 {
		// Guards against a Scheduler not built by NewScheduler; NewTicker
		// panics on non-positive durations.
		interval = schedulerDefaultSweepInterval
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			s.sweep(ctx)
		}
	}
}

// Shutdown performs these steps in order:
//   - disarms all pending timers;
//   - prevents any new execution from starting;
//   - cancels the context passed to sweeps;
//   - blocks until the sweep loop and every tracked execution goroutine return.
//
// Timer-triggered and overdue executions run on a detached context, so
// waiting for them is bounded by schedulerExecuteTimeout plus
// schedulerNotifyTimeout, assuming the service and notify honour their
// contexts. Shutdown before Start, or a repeated Shutdown, is a no-op.
func (s *Scheduler) Shutdown() {
	if s == nil {
		return
	}
	s.timersMu.Lock()
	if s.cancel == nil || s.stopped {
		s.timersMu.Unlock()
		return
	}
	s.stopped = true
	cancel := s.cancel
	for requestID, timer := range s.timers {
		if timer != nil {
			timer.Stop()
		}
		delete(s.timers, requestID)
	}
	s.timersMu.Unlock()

	cancel()
	s.wg.Wait()
}

// TrackScheduled arms, or re-arms, the in-memory timer for a request in the
// scheduled status.
//
// Requests whose ScheduledAt is zero or already past are executed right away
// on a background goroutine. The call is ignored when the record is nil, its
// RequestID is blank, or it is not in the scheduled status. It is also
// ignored after Shutdown.
func (s *Scheduler) TrackScheduled(record *storagemodel.TreasuryRequest) {
	if s == nil || s.service == nil || record == nil {
		return
	}
	if record.Status != storagemodel.TreasuryRequestStatusScheduled {
		return
	}
	requestID := strings.TrimSpace(record.RequestID)
	if requestID == "" {
		return
	}

	when := record.ScheduledAt
	if when.IsZero() {
		when = time.Now()
	}
	delay := time.Until(when)
	if delay <= 0 {
		s.Untrack(requestID)
		// Detached context, as in the original: a shutdown does not abort an
		// execution that has already begun (Shutdown waits for it instead).
		s.goTracked(func() { s.executeAndNotifyByID(context.Background(), requestID) })
		return
	}

	s.timersMu.Lock()
	defer s.timersMu.Unlock()
	if s.stopped {
		return
	}
	if existing := s.timers[requestID]; existing != nil {
		existing.Stop()
	}
	if s.timers == nil {
		s.timers = make(map[string]*time.Timer)
	}

	var timer *time.Timer
	timer = time.AfterFunc(delay, func() {
		// Take the lock before reading timer: it is assigned below while the
		// same lock is held, so this orders the read after the write.
		s.timersMu.Lock()
		current := !s.stopped && s.timers[requestID] == timer
		if current {
			delete(s.timers, requestID)
			s.wg.Add(1)
		}
		s.timersMu.Unlock()
		if !current {
			// Superseded by a newer TrackScheduled, removed by Untrack or a
			// sweep, or the scheduler was shut down. Never delete someone
			// else's timer here.
			return
		}
		defer s.wg.Done()
		s.executeAndNotifyByID(context.Background(), requestID)
	})
	s.timers[requestID] = timer
}

// Untrack disarms and forgets the timer for requestID, if one exists. If the
// timer fires after this call returns, it sees its entry is gone and does not
// execute the request. An execution that is already under way is not
// interrupted.
func (s *Scheduler) Untrack(requestID string) {
	if s == nil {
		return
	}
	requestID = strings.TrimSpace(requestID)
	if requestID == "" {
		return
	}
	s.timersMu.Lock()
	if timer := s.timers[requestID]; timer != nil {
		timer.Stop()
	}
	delete(s.timers, requestID)
	s.timersMu.Unlock()
}

// hydrateTimers loads up to schedulerHydrateLimit scheduled requests from the
// service and passes each to TrackScheduled.
func (s *Scheduler) hydrateTimers(ctx context.Context) {
	if s == nil || s.service == nil {
		return
	}
	scheduled, err := s.service.ScheduledRequests(ctx, schedulerHydrateLimit)
	if err != nil {
		s.warn("Failed to hydrate scheduled treasury requests", err, "")
		return
	}
	for _, record := range scheduled {
		s.TrackScheduled(record)
	}
}

// sweep executes requests that are due as of now: confirmed ones first, then
// scheduled ones. A scheduled request can still be due here, for example
// when its timer was lost or never armed. A listing failure for one status
// does not prevent the other from being processed.
func (s *Scheduler) sweep(ctx context.Context) {
	if s == nil || s.service == nil {
		return
	}
	now := time.Now()
	s.sweepStatus(ctx, storagemodel.TreasuryRequestStatusConfirmed, now, "Failed to list confirmed treasury requests")
	s.sweepStatus(ctx, storagemodel.TreasuryRequestStatusScheduled, now, "Failed to list scheduled treasury requests")
}

// sweepStatus executes up to schedulerSweepBatchLimit requests in the given
// status that are due at now, stopping early if ctx is cancelled.
func (s *Scheduler) sweepStatus(ctx context.Context, status storagemodel.TreasuryRequestStatus, now time.Time, listFailedMsg string) {
	if ctx.Err() != nil {
		return
	}
	due, err := s.service.DueRequests(ctx, []storagemodel.TreasuryRequestStatus{status}, now, schedulerSweepBatchLimit)
	if err != nil {
		s.warn(listFailedMsg, err, "")
		return
	}
	for _, request := range due {
		if ctx.Err() != nil {
			return
		}
		requestID := strings.TrimSpace(request.RequestID)
		if status == storagemodel.TreasuryRequestStatusScheduled {
			// The sweep takes over this request; drop its timer so the timer
			// does not execute it as well.
			s.Untrack(requestID)
		}
		s.executeAndNotifyByID(ctx, requestID)
	}
}

// executeAndNotifyByID executes requestID through the service. When a
// notifier is configured and the outcome yields a message, it reports the
// outcome to the request's chat. Errors are logged, not returned. While one
// call for a request ID is running, concurrent calls for the same ID return
// immediately.
func (s *Scheduler) executeAndNotifyByID(ctx context.Context, requestID string) {
	if s == nil || s.service == nil {
		return
	}
	requestID = strings.TrimSpace(requestID)
	if requestID == "" {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if !s.beginExecution(requestID) {
		return
	}
	defer s.endExecution(requestID)

	execCtx, cancelExec := context.WithTimeout(ctx, schedulerExecuteTimeout)
	defer cancelExec()
	result, err := s.service.ExecuteRequest(execCtx, requestID)
	if err != nil {
		s.warn("Failed to execute treasury request", err, requestID)
		return
	}
	if result == nil || result.Request == nil || s.notify == nil {
		return
	}
	text := executionMessage(result)
	if strings.TrimSpace(text) == "" {
		return
	}

	// Separate budget so a slow execution does not starve the notification,
	// and a hung notifier cannot block this goroutine forever.
	notifyCtx, cancelNotify := context.WithTimeout(ctx, schedulerNotifyTimeout)
	defer cancelNotify()
	if err := s.notify(notifyCtx, strings.TrimSpace(result.Request.ChatID), text); err != nil {
		s.warn("Failed to notify treasury execution result", err, requestID)
	}
}

// beginExecution marks requestID as in flight. It returns false if the
// request is already being executed.
func (s *Scheduler) beginExecution(requestID string) bool {
	s.timersMu.Lock()
	defer s.timersMu.Unlock()
	if _, busy := s.inFlight[requestID]; busy {
		return false
	}
	if s.inFlight == nil {
		s.inFlight = make(map[string]struct{})
	}
	s.inFlight[requestID] = struct{}{}
	return true
}

// endExecution clears the in-flight mark set by beginExecution.
func (s *Scheduler) endExecution(requestID string) {
	s.timersMu.Lock()
	delete(s.inFlight, requestID)
	s.timersMu.Unlock()
}

// goTracked runs fn on a new goroutine registered with s.wg, unless the
// scheduler has been shut down. It reports whether fn was started.
func (s *Scheduler) goTracked(fn func()) bool {
	s.timersMu.Lock()
	if s.stopped {
		s.timersMu.Unlock()
		return false
	}
	s.wg.Add(1)
	s.timersMu.Unlock()

	go func() {
		defer s.wg.Done()
		fn()
	}()
	return true
}

// warn logs err at warn level, adding a request_id field when requestID is
// non-empty. It is a no-op when the scheduler has no logger.
func (s *Scheduler) warn(msg string, err error, requestID string) {
	if s.logger == nil {
		return
	}
	if requestID == "" {
		s.logger.Warn(msg, zap.Error(err))
		return
	}
	s.logger.Warn(msg, zap.Error(err), zap.String("request_id", requestID))
}

// executionMessage renders the chat notification for an execution result.
// It returns "" for a nil result or request, and for any status other than
// executed or failed. Callers treat "" as "nothing to send".
func executionMessage(result *ExecutionResult) string {
	if result == nil || result.Request == nil {
		return ""
	}
	request := result.Request
	account := strings.TrimSpace(request.LedgerAccountID)
	amount := strings.TrimSpace(request.Amount)
	currency := strings.TrimSpace(request.Currency)
	requestID := strings.TrimSpace(request.RequestID)

	switch request.Status {
	case storagemodel.TreasuryRequestStatusExecuted:
		op, sign := "Funding", "+"
		if request.OperationType == storagemodel.TreasuryOperationWithdraw {
			op, sign = "Withdrawal", "-"
		}
		// Fall back to the request currency when the balance does not report one.
		balanceAmount, balanceCurrency := "unavailable", currency
		if result.NewBalance != nil {
			if v := strings.TrimSpace(result.NewBalance.Amount); v != "" {
				balanceAmount = v
			}
			if v := strings.TrimSpace(result.NewBalance.Currency); v != "" {
				balanceCurrency = v
			}
		}
		return op + " completed.\n\n" +
			"Account: " + account + "\n" +
			"Amount: " + sign + amount + " " + currency + "\n" +
			"New balance: " + balanceAmount + " " + balanceCurrency + "\n\n" +
			"Reference: " + requestID

	case storagemodel.TreasuryRequestStatusFailed:
		reason := strings.TrimSpace(request.ErrorMessage)
		if reason == "" && result.ExecutionError != nil {
			reason = strings.TrimSpace(result.ExecutionError.Error())
		}
		if reason == "" {
			reason = "Unknown error."
		}
		return "Execution failed.\n\n" +
			"Account: " + account + "\n" +
			"Amount: " + amount + " " + currency + "\n" +
			"Status: FAILED\n\n" +
			"Reason:\n" + reason + "\n\n" +
			"Request ID: " + requestID

	default:
		return ""
	}
}