package plan

import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/shopspring/decimal"
	"github.com/tech/sendico/payments/storage/model"
	"github.com/tech/sendico/pkg/merrors"
	"github.com/tech/sendico/pkg/mlogger"
	paymenttypes "github.com/tech/sendico/pkg/payments/types"
	"go.uber.org/zap"
)

// ensureGatewayForAction returns a gateway instance for rail that is eligible to
// perform action in direction dir for the given network and amount.
//
// cache memoizes the gateway chosen per rail. A cached entry is reused when
// instanceID is empty or matches the cached instance ID (case-insensitively);
// in that case it is only re-validated, and a validation failure is returned
// as-is rather than falling back to another gateway. Otherwise a gateway is
// picked via selectGateway and stored in cache. A nil cache is tolerated and
// simply disables memoization.
//
// Validation and selection failures are logged at warn level and returned.
func ensureGatewayForAction(ctx context.Context, logger mlogger.Logger, registry GatewayRegistry, cache map[model.Rail]*model.GatewayInstanceDescriptor, rail model.Rail, network string, amount *paymenttypes.Money, action model.RailOperation, instanceID string, dir sendDirection) (*model.GatewayInstanceDescriptor, error) {
	if registry == nil {
		return nil, merrors.InvalidArgument("plan builder: gateway registry is required")
	}
	// Normalise once so the cache lookup and selectGateway agree on matching.
	instanceID = strings.TrimSpace(instanceID)

	logFailure := func(msg string, err error) {
		logger.Warn(msg,
			zap.Error(err),
			zap.String("instance_id", instanceID),
			zap.String("rail", string(rail)),
			zap.String("network", network),
			zap.String("action", string(action)),
			zap.String("direction", sendDirectionLabel(dir)),
			zap.Int("rails_qty", len(cache)),
		)
	}

	// Reading from a nil map is safe in Go; only the write below needs a guard.
	if gw, ok := cache[rail]; ok && gw != nil {
		if instanceID == "" || strings.EqualFold(strings.TrimSpace(gw.InstanceID), instanceID) {
			if err := validateGatewayAction(gw, network, amount, action, dir); err != nil {
				logFailure("Failed to validate gateway", err)
				return nil, err
			}
			return gw, nil
		}
	}

	gw, err := selectGateway(ctx, registry, rail, network, amount, action, instanceID, dir)
	if err != nil {
		logFailure("Failed to select gateway", err)
		return nil, err
	}
	if cache != nil {
		cache[rail] = gw
	}
	return gw, nil
}

// validateGatewayAction checks that gw (matched against its own rail) can
// perform action in direction dir for network and amount.
//
// It returns an invalid-argument error when gw is nil or not eligible, and
// the error from amount parsing when amount cannot be converted to a decimal.
func validateGatewayAction(gw *model.GatewayInstanceDescriptor, network string, amount *paymenttypes.Money, action model.RailOperation, dir sendDirection) error {
	if gw == nil {
		return merrors.InvalidArgument("plan builder: gateway instance is required")
	}
	amt, currency, err := gatewayEligibilityAmount(amount)
	if err != nil {
		return err
	}
	// Same normalisation as selectGateway so both paths judge eligibility alike.
	network = strings.ToUpper(strings.TrimSpace(network))
	if err := isGatewayEligible(gw, gw.Rail, network, currency, action, dir, amt); err != nil {
		return merrors.InvalidArgument("plan builder: gateway instance is not eligible: " + err.Error())
	}
	return nil
}

// gatewayEligibilityAmount extracts the decimal amount and upper-cased,
// trimmed currency used for eligibility checks. A nil amount, or one whose
// amount string is blank, yields zero and an empty currency; an empty
// currency makes isGatewayEligible skip the currency and limit checks.
func gatewayEligibilityAmount(amount *paymenttypes.Money) (decimal.Decimal, string, error) {
	if amount == nil || strings.TrimSpace(amount.GetAmount()) == "" {
		return decimal.Zero, "", nil
	}
	amt, err := decimalFromMoney(amount)
	if err != nil {
		return decimal.Zero, "", err
	}
	return amt, strings.ToUpper(strings.TrimSpace(amount.GetCurrency())), nil
}

// sendDirection selects which send capability (pay-in, pay-out, or either)
// a gateway must have for a send operation.
type sendDirection int

const (
	sendDirectionAny sendDirection = iota
	sendDirectionOut
	sendDirectionIn
)

// sendDirectionForRail returns sendDirectionIn for the fiat on-ramp rail and
// sendDirectionOut for every other rail.
func sendDirectionForRail(rail model.Rail) sendDirection {
	switch rail {
	case model.RailFiatOnRamp:
		return sendDirectionIn
	default:
		return sendDirectionOut
	}
}

// selectGateway lists all gateway instances from registry and returns the
// eligible one with the smallest ID. When instanceID is non-empty, only the
// instance with that ID (case-insensitive, whitespace-trimmed) is considered.
// Nil entries returned by the registry are skipped.
//
// If no instance is eligible, the returned invalid-argument error includes
// the reason the last candidate was rejected, when there was one.
func selectGateway(ctx context.Context, registry GatewayRegistry, rail model.Rail, network string, amount *paymenttypes.Money, action model.RailOperation, instanceID string, dir sendDirection) (*model.GatewayInstanceDescriptor, error) {
	if registry == nil {
		return nil, merrors.InvalidArgument("plan builder: gateway registry is required")
	}
	all, err := registry.List(ctx)
	if err != nil {
		return nil, err
	}
	if len(all) == 0 {
		return nil, merrors.InvalidArgument("plan builder: no gateway instances available")
	}
	amt, currency, err := gatewayEligibilityAmount(amount)
	if err != nil {
		return nil, err
	}
	network = strings.ToUpper(strings.TrimSpace(network))
	instanceID = strings.TrimSpace(instanceID)

	eligible := make([]*model.GatewayInstanceDescriptor, 0, len(all))
	var lastErr error
	for _, gw := range all {
		// Guard before dereferencing gw.InstanceID below.
		if gw == nil {
			continue
		}
		if instanceID != "" && !strings.EqualFold(strings.TrimSpace(gw.InstanceID), instanceID) {
			continue
		}
		if err := isGatewayEligible(gw, rail, network, currency, action, dir, amt); err != nil {
			lastErr = err
			continue
		}
		eligible = append(eligible, gw)
	}
	if len(eligible) == 0 {
		if lastErr != nil {
			return nil, merrors.InvalidArgument("plan builder: no eligible gateway instance found, last error: " + lastErr.Error())
		}
		return nil, merrors.InvalidArgument("plan builder: no eligible gateway instance found")
	}
	// Sorting by ID makes the choice deterministic regardless of registry order.
	sort.Slice(eligible, func(i, j int) bool { return eligible[i].ID < eligible[j].ID })
	return eligible[0], nil
}

// gatewayIneligibleError carries a human-readable reason why a gateway
// instance failed an eligibility check.
type gatewayIneligibleError struct {
	reason string
}

func (e gatewayIneligibleError) Error() string {
	return e.reason
}

// gatewayIneligible builds a gatewayIneligibleError prefixed with the gateway
// instance ID. A blank reason is replaced with a generic message. gw may be
// nil (isGatewayEligible reports a nil instance through here), in which case
// the ID is rendered as "<nil>".
func gatewayIneligible(gw *model.GatewayInstanceDescriptor, reason string) error {
	if strings.TrimSpace(reason) == "" {
		reason = "gateway instance is not eligible"
	}
	id := "<nil>"
	if gw != nil {
		id = gw.InstanceID
	}
	return gatewayIneligibleError{reason: fmt.Sprintf("gateway %s eligibility check error: %s", id, reason)}
}

// sendDirectionLabel returns "out", "in", or "any" for use in logs and errors.
func sendDirectionLabel(dir sendDirection) string {
	switch dir {
	case sendDirectionOut:
		return "out"
	case sendDirectionIn:
		return "in"
	default:
		return "any"
	}
}

// isGatewayEligible returns nil if gw may handle the request, or a
// gatewayIneligibleError (or the error from amountWithinLimits) describing
// the first failing check, in this order:
//   - gw is nil or disabled;
//   - gw.Rail differs from rail;
//   - network mismatch (only when both network and gw.Network are non-empty,
//     compared case-insensitively);
//   - currency not listed (only when currency is non-empty and gw lists
//     currencies);
//   - capabilities do not allow action in direction dir;
//   - amount outside gw.Limits (only when currency is non-empty).
func isGatewayEligible(gw *model.GatewayInstanceDescriptor, rail model.Rail, network, currency string, action model.RailOperation, dir sendDirection, amount decimal.Decimal) error {
	if gw == nil {
		return gatewayIneligible(gw, "gateway instance is required")
	}
	if !gw.IsEnabled {
		return gatewayIneligible(gw, "gateway instance is disabled")
	}
	if gw.Rail != rail {
		return gatewayIneligible(gw, fmt.Sprintf("rail mismatch: want %s got %s", rail, gw.Rail))
	}
	if network != "" && gw.Network != "" && !strings.EqualFold(gw.Network, network) {
		return gatewayIneligible(gw, fmt.Sprintf("network mismatch: want %s got %s", network, gw.Network))
	}
	if currency != "" && len(gw.Currencies) > 0 {
		found := false
		for _, c := range gw.Currencies {
			if strings.EqualFold(c, currency) {
				found = true
				break
			}
		}
		if !found {
			return gatewayIneligible(gw, "currency not supported: "+currency)
		}
	}
	if !capabilityAllowsAction(gw.Capabilities, action, dir) {
		return gatewayIneligible(gw, fmt.Sprintf("capability does not allow action=%s dir=%s", action, sendDirectionLabel(dir)))
	}
	if currency != "" {
		if err := amountWithinLimits(gw, gw.Limits, currency, amount, action); err != nil {
			return err
		}
	}
	return nil
}

// capabilityAllowsAction reports whether caps permits action. For sends, the
// direction picks the pay-out or pay-in flag (either one for
// sendDirectionAny). Actions not listed here are always allowed.
func capabilityAllowsAction(caps model.RailCapabilities, action model.RailOperation, dir sendDirection) bool {
	switch action {
	case model.RailOperationSend:
		switch dir {
		case sendDirectionOut:
			return caps.CanPayOut
		case sendDirectionIn:
			return caps.CanPayIn
		default:
			return caps.CanPayIn || caps.CanPayOut
		}
	case model.RailOperationFee:
		return caps.CanSendFee
	case model.RailOperationObserveConfirm:
		return caps.RequiresObserveConfirm
	case model.RailOperationBlock:
		return caps.CanBlock
	case model.RailOperationRelease:
		return caps.CanRelease
	default:
		return true
	}
}

// amountWithinLimits checks amount against the gateway's min/max and
// per-transaction min/max limits, and, for fee actions, the max-fee limit.
//
// A per-currency entry in limits.CurrencyLimits[currency] overrides the
// generic min/max (and max fee for fee actions) when its value is non-blank.
// Per-transaction min/max have no per-currency override here.
//
// NOTE(review): a limit string that fails to parse as a decimal is silently
// ignored, so a misconfigured limit behaves as "no limit". Confirm this
// best-effort behaviour is intended.
func amountWithinLimits(gw *model.GatewayInstanceDescriptor, limits model.Limits, currency string, amount decimal.Decimal, action model.RailOperation) error {
	minAmount := firstLimitValue(limits.MinAmount, "")
	maxAmount := firstLimitValue(limits.MaxAmount, "")
	perTxMin := firstLimitValue(limits.PerTxMinAmount, "")
	perTxMax := firstLimitValue(limits.PerTxMaxAmount, "")
	maxFee := firstLimitValue(limits.PerTxMaxFee, "")

	if override, ok := limits.CurrencyLimits[currency]; ok {
		minAmount = firstLimitValue(override.MinAmount, minAmount)
		maxAmount = firstLimitValue(override.MaxAmount, maxAmount)
		if action == model.RailOperationFee {
			maxFee = firstLimitValue(override.MaxFee, maxFee)
		}
	}

	if minAmount != "" {
		if val, err := decimal.NewFromString(minAmount); err == nil && amount.LessThan(val) {
			return gatewayIneligible(gw, fmt.Sprintf("amount %s %s below min limit %s", amount.String(), currency, val.String()))
		}
	}
	if perTxMin != "" {
		if val, err := decimal.NewFromString(perTxMin); err == nil && amount.LessThan(val) {
			return gatewayIneligible(gw, fmt.Sprintf("amount %s %s below per-tx min limit %s", amount.String(), currency, val.String()))
		}
	}
	if maxAmount != "" {
		if val, err := decimal.NewFromString(maxAmount); err == nil && amount.GreaterThan(val) {
			return gatewayIneligible(gw, fmt.Sprintf("amount %s %s exceeds max limit %s", amount.String(), currency, val.String()))
		}
	}
	if perTxMax != "" {
		if val, err := decimal.NewFromString(perTxMax); err == nil && amount.GreaterThan(val) {
			return gatewayIneligible(gw, fmt.Sprintf("amount %s %s exceeds per-tx max limit %s", amount.String(), currency, val.String()))
		}
	}
	if action == model.RailOperationFee && maxFee != "" {
		if val, err := decimal.NewFromString(maxFee); err == nil && amount.GreaterThan(val) {
			return gatewayIneligible(gw, fmt.Sprintf("fee amount %s %s exceeds max fee limit %s", amount.String(), currency, val.String()))
		}
	}
	return nil
}

// firstLimitValue returns primary trimmed of surrounding whitespace, or the
// trimmed fallback when primary is blank.
func firstLimitValue(primary, fallback string) string {
	if val := strings.TrimSpace(primary); val != "" {
		return val
	}
	return strings.TrimSpace(fallback)
}