package store import ( "context" "errors" "strings" "time" "github.com/tech/sendico/gateway/chain/storage" "github.com/tech/sendico/gateway/chain/storage/model" "github.com/tech/sendico/pkg/db/repository" "github.com/tech/sendico/pkg/db/repository/builder" ri "github.com/tech/sendico/pkg/db/repository/index" "github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/mlogger" "github.com/tech/sendico/pkg/mservice" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.uber.org/zap" ) const ( defaultWalletPageSize int64 = 50 maxWalletPageSize int64 = 200 ) type Wallets struct { logger mlogger.Logger walletRepo repository.Repository balanceRepo repository.Repository } // NewWallets constructs a Mongo-backed wallets store. func NewWallets(logger mlogger.Logger, db *mongo.Database) (*Wallets, error) { if db == nil { return nil, merrors.InvalidArgument("mongo database is nil") } walletRepo := repository.CreateMongoRepository(db, mservice.ChainWallets) walletIndexes := []*ri.Definition{ { Keys: []ri.Key{{Field: "walletRef", Sort: ri.Asc}}, Unique: true, }, { Keys: []ri.Key{{Field: "idempotencyKey", Sort: ri.Asc}}, Unique: true, }, { Keys: []ri.Key{{Field: "depositAddress", Sort: ri.Asc}}, Unique: true, }, { Keys: []ri.Key{{Field: "organizationRef", Sort: ri.Asc}, {Field: "ownerRef", Sort: ri.Asc}}, }, } for _, def := range walletIndexes { if err := walletRepo.CreateIndex(def); err != nil { logger.Error("failed to ensure wallet index", zap.String("collection", walletRepo.Collection()), zap.Error(err)) return nil, err } } balanceRepo := repository.CreateMongoRepository(db, mservice.ChainWalletBalances) balanceIndexes := []*ri.Definition{ { Keys: []ri.Key{{Field: "walletRef", Sort: ri.Asc}}, Unique: true, }, } for _, def := range balanceIndexes { if err := balanceRepo.CreateIndex(def); err != nil { logger.Error("failed to ensure wallet balance index", zap.String("collection", balanceRepo.Collection()), zap.Error(err)) return nil, err } } childLogger := logger.Named("wallets") childLogger.Debug("wallet stores initialised") return &Wallets{ logger: childLogger, walletRepo: walletRepo, balanceRepo: balanceRepo, }, nil } func (w *Wallets) Create(ctx context.Context, wallet *model.ManagedWallet) (*model.ManagedWallet, error) { if wallet == nil { return nil, merrors.InvalidArgument("walletsStore: nil wallet") } wallet.Normalize() if strings.TrimSpace(wallet.WalletRef) == "" { return nil, merrors.InvalidArgument("walletsStore: empty walletRef") } if wallet.Status == "" { wallet.Status = model.ManagedWalletStatusActive } if strings.TrimSpace(wallet.IdempotencyKey) == "" { return nil, merrors.InvalidArgument("walletsStore: empty idempotencyKey") } if err := w.walletRepo.Insert(ctx, wallet, repository.Filter("idempotencyKey", wallet.IdempotencyKey)); err != nil { if errors.Is(err, merrors.ErrDataConflict) { w.logger.Debug("wallet already exists", zap.String("wallet_ref", wallet.WalletRef), zap.String("idempotency_key", wallet.IdempotencyKey)) return wallet, nil } return nil, err } w.logger.Debug("wallet created", zap.String("wallet_ref", wallet.WalletRef)) return wallet, nil } func (w *Wallets) Get(ctx context.Context, walletRef string) (*model.ManagedWallet, error) { walletRef = strings.TrimSpace(walletRef) if walletRef == "" { return nil, merrors.InvalidArgument("walletsStore: empty walletRef") } wallet := &model.ManagedWallet{} if err := w.walletRepo.FindOneByFilter(ctx, repository.Filter("walletRef", walletRef), wallet); err != nil { return nil, err } return wallet, nil } func (w *Wallets) List(ctx context.Context, filter model.ManagedWalletFilter) (*model.ManagedWalletList, error) { query := repository.Query() if org := strings.TrimSpace(filter.OrganizationRef); org != "" { query = query.Filter(repository.Field("organizationRef"), org) } if owner := strings.TrimSpace(filter.OwnerRef); owner != "" { query = query.Filter(repository.Field("ownerRef"), owner) } if network := strings.TrimSpace(filter.Network); network != "" { query = query.Filter(repository.Field("network"), strings.ToLower(network)) } if token := strings.TrimSpace(filter.TokenSymbol); token != "" { query = query.Filter(repository.Field("tokenSymbol"), strings.ToUpper(token)) } if cursor := strings.TrimSpace(filter.Cursor); cursor != "" { if oid, err := primitive.ObjectIDFromHex(cursor); err == nil { query = query.Comparison(repository.IDField(), builder.Gt, oid) } else { w.logger.Warn("ignoring invalid wallet cursor", zap.String("cursor", cursor), zap.Error(err)) } } limit := sanitizeWalletLimit(filter.Limit) fetchLimit := limit + 1 query = query.Sort(repository.IDField(), true).Limit(&fetchLimit) wallets := make([]*model.ManagedWallet, 0, fetchLimit) decoder := func(cur *mongo.Cursor) error { item := &model.ManagedWallet{} if err := cur.Decode(item); err != nil { return err } wallets = append(wallets, item) return nil } if err := w.walletRepo.FindManyByFilter(ctx, query, decoder); err != nil && !errors.Is(err, merrors.ErrNoData) { return nil, err } nextCursor := "" if int64(len(wallets)) == fetchLimit { last := wallets[len(wallets)-1] nextCursor = last.ID.Hex() wallets = wallets[:len(wallets)-1] } return &model.ManagedWalletList{ Items: wallets, NextCursor: nextCursor, }, nil } func (w *Wallets) SaveBalance(ctx context.Context, balance *model.WalletBalance) error { if balance == nil { return merrors.InvalidArgument("walletsStore: nil balance") } balance.Normalize() if strings.TrimSpace(balance.WalletRef) == "" { return merrors.InvalidArgument("walletsStore: empty walletRef for balance") } if balance.CalculatedAt.IsZero() { balance.CalculatedAt = time.Now().UTC() } existing := &model.WalletBalance{} err := w.balanceRepo.FindOneByFilter(ctx, repository.Filter("walletRef", balance.WalletRef), existing) switch { case err == nil: existing.Available = balance.Available existing.PendingInbound = balance.PendingInbound existing.PendingOutbound = balance.PendingOutbound existing.CalculatedAt = balance.CalculatedAt if err := w.balanceRepo.Update(ctx, existing); err != nil { return err } return nil case errors.Is(err, merrors.ErrNoData): if err := w.balanceRepo.Insert(ctx, balance, repository.Filter("walletRef", balance.WalletRef)); err != nil { return err } return nil default: return err } } func (w *Wallets) GetBalance(ctx context.Context, walletRef string) (*model.WalletBalance, error) { walletRef = strings.TrimSpace(walletRef) if walletRef == "" { return nil, merrors.InvalidArgument("walletsStore: empty walletRef") } balance := &model.WalletBalance{} if err := w.balanceRepo.FindOneByFilter(ctx, repository.Filter("walletRef", walletRef), balance); err != nil { return nil, err } return balance, nil } func sanitizeWalletLimit(requested int32) int64 { if requested <= 0 { return defaultWalletPageSize } if requested > int32(maxWalletPageSize) { return maxWalletPageSize } return int64(requested) } var _ storage.WalletsStore = (*Wallets)(nil)