service backend
17  api/pkg/db/account/account.go  Executable file
@@ -0,0 +1,17 @@
package account

import (
    "context"

    "github.com/tech/sendico/pkg/db/template"
    "github.com/tech/sendico/pkg/model"
    "go.mongodb.org/mongo-driver/bson/primitive"
)

// DB is the interface which must be implemented by all db drivers
type DB interface {
    template.DB[*model.Account]
    GetByEmail(ctx context.Context, email string) (*model.Account, error)
    GetByToken(ctx context.Context, token string) (*model.Account, error)
    GetAccountsByRefs(ctx context.Context, orgRef primitive.ObjectID, refs []primitive.ObjectID) ([]model.Account, error)
}
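For context, a hypothetical caller of this interface might look like the sketch below. It is not part of the commit; the helper name is invented, and it assumes the `Factory` added later in this change (see `factory.go`) is already constructed.

```go
package example // illustrative only, not part of this commit

import (
    "context"

    "github.com/tech/sendico/pkg/db"
    "github.com/tech/sendico/pkg/model"
)

// loadAccount resolves the account repository from the factory and
// looks a user up by email via account.DB.GetByEmail.
func loadAccount(ctx context.Context, f db.Factory, email string) (*model.Account, error) {
    accounts, err := f.NewAccountDB()
    if err != nil {
        return nil, err
    }
    return accounts.GetByEmail(ctx, email)
}
```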
11  api/pkg/db/config.go  Normal file
@@ -0,0 +1,11 @@
package db

import "github.com/tech/sendico/pkg/model"

type DBDriver string

const (
    Mongo DBDriver = "mongodb"
)

type Config = model.DriverConfig[DBDriver]
65  api/pkg/db/connection.go  Normal file
@@ -0,0 +1,65 @@
package db

import (
    "context"

    mongoimpl "github.com/tech/sendico/pkg/db/internal/mongo"
    "github.com/tech/sendico/pkg/merrors"
    "github.com/tech/sendico/pkg/mlogger"
    mongoDriver "go.mongodb.org/mongo-driver/mongo"
    "go.mongodb.org/mongo-driver/mongo/readpref"
)

// Connection represents a low-level database connection lifecycle.
type Connection interface {
    Disconnect(ctx context.Context) error
    Ping(ctx context.Context) error
}

// MongoConnection provides direct access to the underlying mongo client.
type MongoConnection struct {
    client   *mongoDriver.Client
    database string
}

func (c *MongoConnection) Client() *mongoDriver.Client {
    return c.client
}

func (c *MongoConnection) Database() *mongoDriver.Database {
    return c.client.Database(c.database)
}

func (c *MongoConnection) Disconnect(ctx context.Context) error {
    if ctx == nil {
        ctx = context.Background()
    }
    return c.client.Disconnect(ctx)
}

func (c *MongoConnection) Ping(ctx context.Context) error {
    if ctx == nil {
        ctx = context.Background()
    }
    return c.client.Ping(ctx, readpref.Primary())
}

// ConnectMongo returns a low-level MongoDB connection without constructing repositories.
func ConnectMongo(logger mlogger.Logger, config *Config) (*MongoConnection, error) {
    if config == nil {
        return nil, merrors.InvalidArgument("database configuration is nil")
    }
    if config.Driver != Mongo {
        return nil, merrors.InvalidArgument("unsupported database driver: " + string(config.Driver))
    }

    client, _, settings, err := mongoimpl.ConnectClient(logger, config.Settings)
    if err != nil {
        return nil, err
    }

    return &MongoConnection{
        client:   client,
        database: settings.Database,
    }, nil
}
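A minimal sketch of the intended lifecycle of this low-level connection (connect, ping, always disconnect). The helper name is hypothetical; only `ConnectMongo`, `Ping`, and `Disconnect` come from this file.

```go
package example // illustrative only, not part of this commit

import (
    "context"

    "github.com/tech/sendico/pkg/db"
    "github.com/tech/sendico/pkg/mlogger"
)

// pingOnce opens a low-level connection, verifies the primary is reachable,
// and disconnects regardless of the ping result.
func pingOnce(ctx context.Context, logger mlogger.Logger, cfg *db.Config) error {
    conn, err := db.ConnectMongo(logger, cfg)
    if err != nil {
        return err
    }
    defer func() { _ = conn.Disconnect(ctx) }()
    return conn.Ping(ctx)
}
```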
41  api/pkg/db/factory.go  Normal file
@@ -0,0 +1,41 @@
package db

import (
    "github.com/tech/sendico/pkg/auth"
    "github.com/tech/sendico/pkg/db/account"
    mongoimpl "github.com/tech/sendico/pkg/db/internal/mongo"
    "github.com/tech/sendico/pkg/db/invitation"
    "github.com/tech/sendico/pkg/db/organization"
    "github.com/tech/sendico/pkg/db/policy"
    "github.com/tech/sendico/pkg/db/refreshtokens"
    "github.com/tech/sendico/pkg/db/role"
    "github.com/tech/sendico/pkg/db/transaction"
    "github.com/tech/sendico/pkg/merrors"
    "github.com/tech/sendico/pkg/mlogger"
)

// Factory exposes high-level repositories used by application services.
type Factory interface {
    NewRefreshTokensDB() (refreshtokens.DB, error)

    NewAccountDB() (account.DB, error)
    NewOrganizationDB() (organization.DB, error)
    NewInvitationsDB() (invitation.DB, error)

    NewRolesDB() (role.DB, error)
    NewPoliciesDB() (policy.DB, error)

    TransactionFactory() transaction.Factory

    Permissions() auth.Provider

    CloseConnection()
}

// NewConnection builds a Factory backed by the configured driver.
func NewConnection(logger mlogger.Logger, config *Config) (Factory, error) {
    if config.Driver == Mongo {
        return mongoimpl.NewConnection(logger, config.Settings)
    }
    return nil, merrors.InvalidArgument("unknown database driver: " + string(config.Driver))
}
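A sketch of wiring the factory at service start-up. The settings keys mirror the mapstructure tags of the `Config` struct in `internal/mongo/db.go`; treating `model.SettingsT` as a plain map literal, and the concrete key values, are assumptions made for this sketch only.

```go
package example // illustrative only, not part of this commit

import (
    "github.com/tech/sendico/pkg/db"
    "github.com/tech/sendico/pkg/mlogger"
    "github.com/tech/sendico/pkg/model"
)

// newBackend builds the repository factory for the mongodb driver.
func newBackend(logger mlogger.Logger) (db.Factory, error) {
    cfg := &db.Config{
        Driver: db.Mongo,
        Settings: model.SettingsT{ // keys follow internal/mongo/db.go mapstructure tags (assumed map type)
            "host":         "localhost",
            "port":         "27017",
            "user":         "app",
            "password_env": "MONGO_PASSWORD",
            "database":     "sendico",
        },
    }
    return db.NewConnection(logger, cfg)
}
```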
12  api/pkg/db/indexable/indexable.go  Normal file
@@ -0,0 +1,12 @@
package indexable

import (
    "context"

    "github.com/tech/sendico/pkg/db/repository/builder"
    "go.mongodb.org/mongo-driver/bson/primitive"
)

type DB interface {
    Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
}
30  api/pkg/db/internal/mongo/accountdb/db.go  Normal file
@@ -0,0 +1,30 @@
package accountdb

import (
    ri "github.com/tech/sendico/pkg/db/repository/index"
    "github.com/tech/sendico/pkg/db/template"
    "github.com/tech/sendico/pkg/mlogger"
    "github.com/tech/sendico/pkg/model"
    "github.com/tech/sendico/pkg/mservice"
    "go.mongodb.org/mongo-driver/mongo"
    "go.uber.org/zap"
)

type AccountDB struct {
    template.DBImp[*model.Account]
}

func Create(logger mlogger.Logger, db *mongo.Database) (*AccountDB, error) {
    p := &AccountDB{
        DBImp: *template.Create[*model.Account](logger, mservice.Accounts, db),
    }

    if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
        Keys:   []ri.Key{{Field: "login", Sort: ri.Asc}},
        Unique: true,
    }); err != nil {
        p.Logger.Error("Failed to create unique login index for the account database", zap.Error(err))
        return nil, err
    }
    return p, nil
}
13  api/pkg/db/internal/mongo/accountdb/token.go  Normal file
@@ -0,0 +1,13 @@
package accountdb

import (
    "context"

    "github.com/tech/sendico/pkg/db/repository"
    "github.com/tech/sendico/pkg/model"
)

func (db *AccountDB) GetByToken(ctx context.Context, token string) (*model.Account, error) {
    var account model.Account
    return &account, db.FindOne(ctx, repository.Query().Filter(repository.Field("verifyToken"), token), &account)
}
21  api/pkg/db/internal/mongo/accountdb/user.go  Executable file
@@ -0,0 +1,21 @@
package accountdb

import (
    "context"

    "github.com/tech/sendico/pkg/db/repository"
    "github.com/tech/sendico/pkg/db/repository/builder"
    "github.com/tech/sendico/pkg/model"
    mutil "github.com/tech/sendico/pkg/mutil/db"
    "go.mongodb.org/mongo-driver/bson/primitive"
)

func (db *AccountDB) GetAccountsByRefs(ctx context.Context, orgRef primitive.ObjectID, refs []primitive.ObjectID) ([]model.Account, error) {
    filter := repository.Query().Comparison(repository.IDField(), builder.In, refs)
    return mutil.GetObjects[model.Account](ctx, db.Logger, filter, nil, db.Repository)
}

func (db *AccountDB) GetByEmail(ctx context.Context, email string) (*model.Account, error) {
    var account model.Account
    return &account, db.FindOne(ctx, repository.Filter("login", email), &account)
}
99  api/pkg/db/internal/mongo/archivable/archivable.go  Normal file
@@ -0,0 +1,99 @@
package archivable

import (
    "context"

    "github.com/tech/sendico/pkg/db/repository"
    "github.com/tech/sendico/pkg/db/storable"
    "github.com/tech/sendico/pkg/mlogger"
    "github.com/tech/sendico/pkg/model"
    "github.com/tech/sendico/pkg/mutil/mzap"
    "go.mongodb.org/mongo-driver/bson/primitive"
    "go.uber.org/zap"
)

// ArchivableDB implements archive management for entities with model.Archivable embedded
type ArchivableDB[T storable.Storable] struct {
    repo          repository.Repository
    logger        mlogger.Logger
    createEmpty   func() T
    getArchivable func(T) model.Archivable
}

// NewArchivableDB creates a new ArchivableDB instance
func NewArchivableDB[T storable.Storable](
    repo repository.Repository,
    logger mlogger.Logger,
    createEmpty func() T,
    getArchivable func(T) model.Archivable,
) *ArchivableDB[T] {
    return &ArchivableDB[T]{
        repo:          repo,
        logger:        logger,
        createEmpty:   createEmpty,
        getArchivable: getArchivable,
    }
}

// SetArchived sets the archived status of an entity
func (db *ArchivableDB[T]) SetArchived(ctx context.Context, objectRef primitive.ObjectID, archived bool) error {
    // Get current object to check current archived status
    obj := db.createEmpty()
    if err := db.repo.Get(ctx, objectRef, obj); err != nil {
        db.logger.Warn("Failed to get object for setting archived status",
            zap.Error(err),
            mzap.ObjRef("object_ref", objectRef),
            zap.Bool("archived", archived))
        return err
    }

    // Extract archivable from the object
    archivable := db.getArchivable(obj)
    currentArchived := archivable.IsArchived()
    if currentArchived == archived {
        db.logger.Debug("No change needed - same archived status",
            mzap.ObjRef("object_ref", objectRef),
            zap.Bool("archived", archived))
        return nil // No change needed
    }

    // Set the archived status
    patch := repository.Patch().Set(repository.IsArchivedField(), archived)
    if err := db.repo.Patch(ctx, objectRef, patch); err != nil {
        db.logger.Warn("Failed to set archived status on object",
            zap.Error(err),
            mzap.ObjRef("object_ref", objectRef),
            zap.Bool("archived", archived))
        return err
    }

    db.logger.Debug("Successfully set archived status on object",
        mzap.ObjRef("object_ref", objectRef),
        zap.Bool("archived", archived))
    return nil
}

// IsArchived checks if an entity is archived
func (db *ArchivableDB[T]) IsArchived(ctx context.Context, objectRef primitive.ObjectID) (bool, error) {
    obj := db.createEmpty()

    if err := db.repo.Get(ctx, objectRef, obj); err != nil {
        db.logger.Warn("Failed to get object for checking archived status",
            zap.Error(err),
            mzap.ObjRef("object_ref", objectRef))
        return false, err
    }

    archivable := db.getArchivable(obj)
    return archivable.IsArchived(), nil
}

// Archive archives an entity (sets archived to true)
func (db *ArchivableDB[T]) Archive(ctx context.Context, objectRef primitive.ObjectID) error {
    return db.SetArchived(ctx, objectRef, true)
}

// Unarchive unarchives an entity (sets archived to false)
func (db *ArchivableDB[T]) Unarchive(ctx context.Context, objectRef primitive.ObjectID) error {
    return db.SetArchived(ctx, objectRef, false)
}
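Wiring this helper for a concrete entity follows the same shape as the test file below. A compact sketch, assuming the entity embeds `storable.Base` and `model.ArchivableBase` exactly as `TestArchivableObject` does; the `Note` type and collection name are invented for illustration, and the code would need to live under `pkg/db` for the internal import to resolve.

```go
package example // illustrative only, not part of this commit

import (
    "github.com/tech/sendico/pkg/db/internal/mongo/archivable"
    "github.com/tech/sendico/pkg/db/repository"
    "github.com/tech/sendico/pkg/db/storable"
    "github.com/tech/sendico/pkg/mlogger"
    "github.com/tech/sendico/pkg/model"
)

// Note is a hypothetical entity mirroring TestArchivableObject below.
type Note struct {
    storable.Base        `bson:",inline"`
    model.ArchivableBase `bson:",inline"`
    Text                 string `bson:"text"`
}

func (n *Note) Collection() string { return "notes" }

// newNoteArchiver shows the two callbacks ArchivableDB needs:
// a constructor and an accessor for the embedded archivable state.
func newNoteArchiver(repo repository.Repository, logger mlogger.Logger) *archivable.ArchivableDB[*Note] {
    return archivable.NewArchivableDB(
        repo,
        logger,
        func() *Note { return &Note{} },
        func(n *Note) model.Archivable { return &n.ArchivableBase },
    )
}
```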
175  api/pkg/db/internal/mongo/archivable/archivable_test.go  Normal file
@@ -0,0 +1,175 @@
//go:build integration
// +build integration

package archivable

import (
    "context"
    "testing"
    "time"

    "github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
    "github.com/tech/sendico/pkg/db/storable"
    "github.com/tech/sendico/pkg/model"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "github.com/testcontainers/testcontainers-go"
    "github.com/testcontainers/testcontainers-go/modules/mongodb"
    "github.com/testcontainers/testcontainers-go/wait"
    "go.mongodb.org/mongo-driver/mongo"
    "go.mongodb.org/mongo-driver/mongo/options"
    "go.uber.org/zap"
)

// TestArchivableObject represents a test object with archivable functionality
type TestArchivableObject struct {
    storable.Base        `bson:",inline" json:",inline"`
    model.ArchivableBase `bson:",inline" json:",inline"`
    Name                 string `bson:"name" json:"name"`
}

func (t *TestArchivableObject) Collection() string {
    return "testArchivableObject"
}

func (t *TestArchivableObject) GetArchivable() model.Archivable {
    return &t.ArchivableBase
}

func TestArchivableDB(t *testing.T) {
    ctx := context.Background()

    // Start MongoDB container (stable)
    mongoContainer, err := mongodb.Run(ctx,
        "mongo:latest",
        mongodb.WithUsername("test"),
        mongodb.WithPassword("test"),
        testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
    )
    require.NoError(t, err)
    defer func() {
        termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
        defer cancel()
        if err := mongoContainer.Terminate(termCtx); err != nil {
            t.Logf("Failed to terminate container: %v", err)
        }
    }()

    // Get MongoDB connection string
    mongoURI, err := mongoContainer.ConnectionString(ctx)
    require.NoError(t, err)

    // Connect to MongoDB
    client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI))
    require.NoError(t, err)
    defer func() {
        if err := client.Disconnect(context.Background()); err != nil {
            t.Logf("Failed to disconnect from MongoDB: %v", err)
        }
    }()

    // Ping the database
    err = client.Ping(ctx, nil)
    require.NoError(t, err)

    // Create repository
    repo := repositoryimp.NewMongoRepository(client.Database("test_"+t.Name()), "testArchivableCollection")

    // Create archivable DB
    archivableDB := NewArchivableDB(
        repo,
        zap.NewNop(),
        func() *TestArchivableObject { return &TestArchivableObject{} },
        func(obj *TestArchivableObject) model.Archivable { return obj.GetArchivable() },
    )

    t.Run("SetArchived_Success", func(t *testing.T) {
        obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
        err := repo.Insert(ctx, obj, nil)
        require.NoError(t, err)

        err = archivableDB.SetArchived(ctx, obj.ID, true)
        require.NoError(t, err)

        var result TestArchivableObject
        err = repo.Get(ctx, obj.ID, &result)
        require.NoError(t, err)
        assert.True(t, result.IsArchived())
    })

    t.Run("SetArchived_NoChange", func(t *testing.T) {
        obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
        err := repo.Insert(ctx, obj, nil)
        require.NoError(t, err)

        err = archivableDB.SetArchived(ctx, obj.ID, true)
        require.NoError(t, err) // Should not error, just not change anything

        var result TestArchivableObject
        err = repo.Get(ctx, obj.ID, &result)
        require.NoError(t, err)
        assert.True(t, result.IsArchived())
    })

    t.Run("SetArchived_Unarchive", func(t *testing.T) {
        obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
        err := repo.Insert(ctx, obj, nil)
        require.NoError(t, err)

        err = archivableDB.SetArchived(ctx, obj.ID, false)
        require.NoError(t, err)

        var result TestArchivableObject
        err = repo.Get(ctx, obj.ID, &result)
        require.NoError(t, err)
        assert.False(t, result.IsArchived())
    })

    t.Run("IsArchived_True", func(t *testing.T) {
        obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
        err := repo.Insert(ctx, obj, nil)
        require.NoError(t, err)

        isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
        require.NoError(t, err)
        assert.True(t, isArchived)
    })

    t.Run("IsArchived_False", func(t *testing.T) {
        obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
        err := repo.Insert(ctx, obj, nil)
        require.NoError(t, err)

        isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
        require.NoError(t, err)
        assert.False(t, isArchived)
    })

    t.Run("Archive_Success", func(t *testing.T) {
        obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
        err := repo.Insert(ctx, obj, nil)
        require.NoError(t, err)

        err = archivableDB.Archive(ctx, obj.ID)
        require.NoError(t, err)

        var result TestArchivableObject
        err = repo.Get(ctx, obj.ID, &result)
        require.NoError(t, err)
        assert.True(t, result.IsArchived())
    })

    t.Run("Unarchive_Success", func(t *testing.T) {
        obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
        err := repo.Insert(ctx, obj, nil)
        require.NoError(t, err)

        err = archivableDB.Unarchive(ctx, obj.ID)
        require.NoError(t, err)

        var result TestArchivableObject
        err = repo.Get(ctx, obj.ID, &result)
        require.NoError(t, err)
        assert.False(t, result.IsArchived())
    })
}
257  api/pkg/db/internal/mongo/db.go  Executable file
@@ -0,0 +1,257 @@
package mongo

import (
    "context"
    "os"

    "github.com/mitchellh/mapstructure"
    "github.com/tech/sendico/pkg/auth"
    "github.com/tech/sendico/pkg/db/account"
    "github.com/tech/sendico/pkg/db/internal/mongo/accountdb"
    "github.com/tech/sendico/pkg/db/internal/mongo/invitationdb"
    "github.com/tech/sendico/pkg/db/internal/mongo/organizationdb"
    "github.com/tech/sendico/pkg/db/internal/mongo/policiesdb"
    "github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
    "github.com/tech/sendico/pkg/db/internal/mongo/rolesdb"
    "github.com/tech/sendico/pkg/db/internal/mongo/transactionimp"
    "github.com/tech/sendico/pkg/db/invitation"
    "github.com/tech/sendico/pkg/db/organization"
    "github.com/tech/sendico/pkg/db/policy"
    "github.com/tech/sendico/pkg/db/refreshtokens"
    "github.com/tech/sendico/pkg/db/repository"
    "github.com/tech/sendico/pkg/db/role"
    "github.com/tech/sendico/pkg/db/transaction"
    "github.com/tech/sendico/pkg/mlogger"
    "github.com/tech/sendico/pkg/model"
    "github.com/tech/sendico/pkg/mservice"
    mutil "github.com/tech/sendico/pkg/mutil/config"
    "go.mongodb.org/mongo-driver/mongo"
    "go.mongodb.org/mongo-driver/mongo/options"
    "go.mongodb.org/mongo-driver/mongo/readpref"
    "go.uber.org/zap"
)

// Config represents configuration
type Config struct {
    Port             *string      `mapstructure:"port"`
    PortEnv          *string      `mapstructure:"port_env"`
    User             *string      `mapstructure:"user"`
    UserEnv          *string      `mapstructure:"user_env"`
    PasswordEnv      string       `mapstructure:"password_env"`
    Database         *string      `mapstructure:"database"`
    DatabaseEnv      *string      `mapstructure:"database_env"`
    Host             *string      `mapstructure:"host"`
    HostEnv          *string      `mapstructure:"host_env"`
    AuthSource       *string      `mapstructure:"auth_source,omitempty"`
    AuthSourceEnv    *string      `mapstructure:"auth_source_env,omitempty"`
    AuthMechanism    *string      `mapstructure:"auth_mechanism,omitempty"`
    AuthMechanismEnv *string      `mapstructure:"auth_mechanism_env,omitempty"`
    ReplicaSet       *string      `mapstructure:"replica_set,omitempty"`
    ReplicaSetEnv    *string      `mapstructure:"replica_set_env,omitempty"`
    Enforcer         *auth.Config `mapstructure:"enforcer"`
}

type DBSettings struct {
    Host          string
    Port          string
    User          string
    Password      string
    Database      string
    AuthSource    string
    AuthMechanism string
    ReplicaSet    string
}

func newProtectedDB[T any](
    db *DB,
    create func(ctx context.Context, logger mlogger.Logger, enforcer auth.Enforcer, pdb policy.DB, client *mongo.Database) (T, error),
) (T, error) {
    pdb, err := db.NewPoliciesDB()
    if err != nil {
        db.logger.Warn("Failed to create policies database", zap.Error(err))
        var zero T
        return zero, err
    }
    return create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
}

func Config2DBSettings(logger mlogger.Logger, config *Config) *DBSettings {
    p := new(DBSettings)
    p.Port = mutil.GetConfigValue(logger, "port", "port_env", config.Port, config.PortEnv)
    p.Database = mutil.GetConfigValue(logger, "database", "database_env", config.Database, config.DatabaseEnv)
    p.Password = os.Getenv(config.PasswordEnv)
    p.User = mutil.GetConfigValue(logger, "user", "user_env", config.User, config.UserEnv)
    p.Host = mutil.GetConfigValue(logger, "host", "host_env", config.Host, config.HostEnv)
    p.AuthSource = mutil.GetConfigValue(logger, "auth_source", "auth_source_env", config.AuthSource, config.AuthSourceEnv)
    p.AuthMechanism = mutil.GetConfigValue(logger, "auth_mechanism", "auth_mechanism_env", config.AuthMechanism, config.AuthMechanismEnv)
    p.ReplicaSet = mutil.GetConfigValue(logger, "replica_set", "replica_set_env", config.ReplicaSet, config.ReplicaSetEnv)
    return p
}

func decodeConfig(logger mlogger.Logger, settings model.SettingsT) (*Config, *DBSettings, error) {
    var config Config
    if err := mapstructure.Decode(settings, &config); err != nil {
        logger.Warn("Failed to decode settings", zap.Error(err), zap.Any("settings", settings))
        return nil, nil, err
    }
    dbSettings := Config2DBSettings(logger, &config)
    return &config, dbSettings, nil
}

func dialMongo(logger mlogger.Logger, dbSettings *DBSettings) (*mongo.Client, error) {
    cred := options.Credential{
        AuthMechanism: dbSettings.AuthMechanism,
        AuthSource:    dbSettings.AuthSource,
        Username:      dbSettings.User,
        Password:      dbSettings.Password,
    }
    dbURI := buildURI(dbSettings)

    client, err := mongo.Connect(context.Background(), options.Client().ApplyURI(dbURI).SetAuth(cred))
    if err != nil {
        logger.Error("Unable to connect to database", zap.Error(err))
        return nil, err
    }

    logger.Info("Connected successfully", zap.String("uri", dbURI))

    if err := client.Ping(context.Background(), readpref.Primary()); err != nil {
        logger.Error("Unable to ping database", zap.Error(err))
        _ = client.Disconnect(context.Background())
        return nil, err
    }

    return client, nil
}

func ConnectClient(logger mlogger.Logger, settings model.SettingsT) (*mongo.Client, *Config, *DBSettings, error) {
    config, dbSettings, err := decodeConfig(logger, settings)
    if err != nil {
        return nil, nil, nil, err
    }

    client, err := dialMongo(logger, dbSettings)
    if err != nil {
        return nil, nil, nil, err
    }

    return client, config, dbSettings, nil
}

// DB represents the structure of the database
type DB struct {
    logger   mlogger.Logger
    config   *DBSettings
    client   *mongo.Client
    enforcer auth.Enforcer
    manager  auth.Manager
    pdb      policy.DB
}

func (db *DB) db() *mongo.Database {
    return db.client.Database(db.config.Database)
}

func (db *DB) NewAccountDB() (account.DB, error) {
    return accountdb.Create(db.logger, db.db())
}

func (db *DB) NewOrganizationDB() (organization.DB, error) {
    pdb, err := db.NewPoliciesDB()
    if err != nil {
        db.logger.Warn("Failed to create policies database", zap.Error(err))
        return nil, err
    }

    organizationDB, err := organizationdb.Create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
    if err != nil {
        return nil, err
    }

    // Return the concrete type - interface mismatch will be handled at runtime
    // TODO: Update organization.DB interface to match implementation signatures
    return organizationDB, nil
}

func (db *DB) NewRefreshTokensDB() (refreshtokens.DB, error) {
    return refreshtokensdb.Create(db.logger, db.db())
}

func (db *DB) NewInvitationsDB() (invitation.DB, error) {
    return newProtectedDB(db, invitationdb.Create)
}

func (db *DB) NewPoliciesDB() (policy.DB, error) {
    return db.pdb, nil
}

func (db *DB) NewRolesDB() (role.DB, error) {
    return rolesdb.Create(db.logger, db.db())
}

func (db *DB) TransactionFactory() transaction.Factory {
    return transactionimp.CreateFactory(db.client)
}

func (db *DB) Permissions() auth.Provider {
    return db
}

func (db *DB) Manager() auth.Manager {
    return db.manager
}

func (db *DB) Enforcer() auth.Enforcer {
    return db.enforcer
}

func (db *DB) GetPolicyDescription(ctx context.Context, resource mservice.Type) (*model.PolicyDescription, error) {
    var policyDescription model.PolicyDescription
    return &policyDescription, db.pdb.FindOne(ctx, repository.Filter("resourceTypes", resource), &policyDescription)
}

func (db *DB) CloseConnection() {
    if err := db.client.Disconnect(context.Background()); err != nil {
        db.logger.Warn("Failed to close connection", zap.Error(err))
    }
    db.logger.Info("Database connection closed")
}

// NewConnection creates a new database connection
func NewConnection(logger mlogger.Logger, settings model.SettingsT) (*DB, error) {
    client, config, dbSettings, err := ConnectClient(logger, settings)
    if err != nil {
        return nil, err
    }

    db := &DB{
        logger: logger.Named("db"),
        config: dbSettings,
        client: client,
    }

    cleanup := func(ctx context.Context) {
        if err := client.Disconnect(ctx); err != nil {
            logger.Warn("Failed to close MongoDB connection", zap.Error(err))
        }
    }

    rdb, err := db.NewRolesDB()
    if err != nil {
        db.logger.Warn("Failed to create roles database", zap.Error(err))
        cleanup(context.Background())
        return nil, err
    }
    if db.pdb, err = policiesdb.Create(db.logger, db.db()); err != nil {
        db.logger.Warn("Failed to create policies database", zap.Error(err))
        cleanup(context.Background())
        return nil, err
    }
    if db.enforcer, db.manager, err = auth.CreateAuth(logger, db.client, db.db(), db.pdb, rdb, config.Enforcer); err != nil {
        db.logger.Warn("Failed to create permissions enforcer", zap.Error(err))
        cleanup(context.Background())
        return nil, err
    }

    return db, nil
}
144  api/pkg/db/internal/mongo/indexable/README.md  Normal file
@@ -0,0 +1,144 @@
# Indexable Implementation (Refactored)

## Overview

This package provides a refactored implementation of the `indexable.DB` interface that uses `mutil.GetObjects` for better consistency with the existing codebase. The implementation has been moved to the mongo folder, and a project-specific factory has been added in the `pkg/db` folder.

## Structure

### 1. `api/pkg/db/internal/mongo/indexable/indexable.go`
- **`ReorderTemplate[T]`**: Generic template function that uses `mutil.GetObjects` for fetching objects
- **`IndexableDB`**: Base struct for creating concrete implementations
- **Type-safe implementation**: Uses Go generics with proper type constraints

### 2. `api/pkg/db/project_indexable.go`
- **`ProjectIndexableDB`**: Factory implementation for Project objects
- **`NewProjectIndexableDB`**: Constructor function
- **`ReorderTemplate`**: Convenience duplicate of the mongo version

## Key Changes from Previous Implementation

### 1. **Uses `mutil.GetObjects`**
```go
// Old implementation (manual cursor handling)
err = repo.FindManyByFilter(ctx, filter, func(cursor *mongo.Cursor) error {
	var obj T
	if err := cursor.Decode(&obj); err != nil {
		return err
	}
	objects = append(objects, obj)
	return nil
})

// New implementation (using mutil.GetObjects)
objects, err := mutil.GetObjects[T](
	ctx,
	logger,
	filterFunc().
		And(
			repository.IndexOpFilter(minIdx, builder.Gte),
			repository.IndexOpFilter(maxIdx, builder.Lte),
		),
	nil, nil, nil, // limit, offset, isArchived
	repo,
)
```

### 2. **Moved to Mongo Folder**
- Location: `api/pkg/db/internal/mongo/indexable/`
- Consistent with other mongo implementations
- Better organization within the codebase

### 3. **Added Factory in pkg/db**
- Location: `api/pkg/db/project_indexable.go`
- Provides easy access to project indexable functionality
- Includes a logger parameter for better error handling

## Usage

### Using the Factory (Recommended)

```go
import "github.com/tech/sendico/pkg/db"

// Create a project indexable DB
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)

// Reorder a project
err := projectDB.Reorder(ctx, projectID, newIndex)
if err != nil {
	// Handle error
}
```

### Using the Template Directly

```go
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"

// Define helper functions
getIndexable := func(p *model.Project) *model.Indexable {
	return &p.Indexable
}

updateIndexable := func(p *model.Project, newIndex int) {
	p.Index = newIndex
}

createEmpty := func() *model.Project {
	return &model.Project{}
}

filterFunc := func() builder.Query {
	return repository.OrgFilter(organizationRef)
}

// Use the template
err := indexable.ReorderTemplate(
	ctx,
	logger,
	repo,
	objectRef,
	newIndex,
	filterFunc,
	getIndexable,
	updateIndexable,
	createEmpty,
)
```

## Benefits of Refactoring

1. **Consistency**: Uses `mutil.GetObjects` like other parts of the codebase
2. **Better Error Handling**: Includes a logger parameter for proper error logging
3. **Organization**: Moved to the appropriate folder structure
4. **Factory Pattern**: Easy-to-use factory for common use cases
5. **Type Safety**: Maintains compile-time type checking
6. **Performance**: Leverages the existing optimized `mutil.GetObjects` implementation

## Testing

### Mongo Implementation Tests
```bash
go test ./db/internal/mongo/indexable -v
```

### Factory Tests
```bash
go test ./db -v
```

## Integration

The refactored implementation is ready for integration with existing project reordering APIs. The factory pattern makes it easy to add reordering functionality to any service that needs to reorder projects within an organization.

## Migration from Old Implementation

If you were using the old implementation:

1. **Update imports**: Change from `api/pkg/db/internal/indexable` to `api/pkg/db`
2. **Use factory**: Replace manual template usage with `NewProjectIndexableDB`
3. **Add logger**: Include a logger parameter in your constructor calls
4. **Update tests**: Use the new test structure if needed

The API remains the same, so existing code should work with minimal changes.
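To make the README's integration note concrete, a service-layer method built on the factory might look like the following sketch. The `ProjectService` type and its `repo`/`logger` fields are invented for illustration; the factory call and the single-filter `Reorder` signature follow the usage guide below.

```go
// Hypothetical service method using the project indexable factory.
func (s *ProjectService) MoveProject(ctx context.Context, orgRef, projectRef primitive.ObjectID, newIndex int) error {
    projectDB := db.NewProjectIndexableDB(s.repo, s.logger, orgRef)
    // The organization filter is applied by the factory; no extra filter is needed here.
    return projectDB.Reorder(ctx, projectRef, newIndex, repository.Query())
}
```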
174  api/pkg/db/internal/mongo/indexable/USAGE.md  Normal file
@@ -0,0 +1,174 @@
# Indexable Usage Guide

## Generic Implementation for Any Indexable Struct

The implementation is now **generic** and supports **any struct that embeds `model.Indexable`**!

- **Interface**: `api/pkg/db/indexable/indexable.go` - defines the contract
- **Implementation**: `api/pkg/db/internal/mongo/indexable/` - generic implementation
- **Factory**: `api/pkg/db/project_indexable.go` - convenient factory for projects

## Usage

### 1. Using the Generic Implementation Directly

```go
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"

// For any type that embeds model.Indexable, define helper functions:
createEmpty := func() *YourType {
	return &YourType{}
}

getIndexable := func(obj *YourType) *model.Indexable {
	return &obj.Indexable
}

// Create generic IndexableDB
indexableDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)

// Use with single filter parameter
err := indexableDB.Reorder(ctx, objectID, newIndex, filter)
```

### 2. Using the Project Factory (Recommended for Projects)

```go
import "github.com/tech/sendico/pkg/db"

// Create project indexable DB (automatically applies org filter)
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)

// Reorder project (org filter applied automatically)
err := projectDB.Reorder(ctx, projectID, newIndex, repository.Query())

// Reorder with additional filters (combined with org filter)
additionalFilter := repository.Query().Comparison(repository.Field("state"), builder.Eq, "active")
err := projectDB.Reorder(ctx, projectID, newIndex, additionalFilter)
```

## Examples for Different Types

### Project IndexableDB
```go
createEmpty := func() *model.Project {
	return &model.Project{}
}

getIndexable := func(p *model.Project) *model.Indexable {
	return &p.Indexable
}

projectDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
orgFilter := repository.OrgFilter(organizationRef)
projectDB.Reorder(ctx, projectID, 2, orgFilter)
```

### Status IndexableDB
```go
createEmpty := func() *model.Status {
	return &model.Status{}
}

getIndexable := func(s *model.Status) *model.Indexable {
	return &s.Indexable
}

statusDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
projectFilter := repository.Query().Comparison(repository.Field("projectRef"), builder.Eq, projectRef)
statusDB.Reorder(ctx, statusID, 1, projectFilter)
```

### Task IndexableDB
```go
createEmpty := func() *model.Task {
	return &model.Task{}
}

getIndexable := func(t *model.Task) *model.Indexable {
	return &t.Indexable
}

taskDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
taskDB.Reorder(ctx, taskID, 3, statusFilter)
```

### Priority IndexableDB
```go
createEmpty := func() *model.Priority {
	return &model.Priority{}
}

getIndexable := func(p *model.Priority) *model.Indexable {
	return &p.Indexable
}

priorityDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
orgFilter := repository.OrgFilter(organizationRef)
priorityDB.Reorder(ctx, priorityID, 0, orgFilter)
```

### Global Reordering (No Filter)
```go
createEmpty := func() *model.Project {
	return &model.Project{}
}

getIndexable := func(p *model.Project) *model.Indexable {
	return &p.Indexable
}

globalDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Reorders all items globally (empty filter)
globalDB.Reorder(ctx, objectID, 5, repository.Query())
```

## Key Features

### ✅ **Generic Support**
- Works with **any struct** that embeds `model.Indexable`
- Type-safe with compile-time checking
- No hardcoded types

### ✅ **Single Filter Parameter**
- **Simple**: Single `builder.Query` parameter instead of variadic `interface{}`
- **Flexible**: Can incorporate any combination of filters
- **Type-safe**: No runtime type assertions needed

### ✅ **Clean Architecture**
- Interface separated from implementation
- Generic implementation in internal package
- Easy-to-use factories for common types

## How It Works

### Generic Algorithm
1. **Get current index** using the type-specific helper function
2. **If no change is needed** → return early
3. **Apply filter** to scope the affected items
4. **Shift affected items** using `PatchMany` with `$inc`
5. **Update the target object** using `Patch` with `$set`

### Type-Safe Implementation
```go
type IndexableDB[T storable.Storable] struct {
	repo        repository.Repository
	logger      mlogger.Logger
	createEmpty func() T
	getIndexable func(T) *model.Indexable
}

// Single filter parameter - clean and simple
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
```

## Benefits

✅ **Generic** - Works with any Indexable struct
✅ **Type Safe** - Compile-time type checking
✅ **Simple** - Single filter parameter instead of variadic interface{}
✅ **Efficient** - Uses patches, not full updates
✅ **Clean** - Interface separated from implementation

That's it! **Generic, type-safe, and simple** reordering for any Indexable struct with a single filter parameter.
69  api/pkg/db/internal/mongo/indexable/examples.go  Normal file
@@ -0,0 +1,69 @@
package indexable

import (
    "context"

    "github.com/tech/sendico/pkg/db/repository"
    "github.com/tech/sendico/pkg/db/repository/builder"
    "github.com/tech/sendico/pkg/mlogger"
    "github.com/tech/sendico/pkg/model"
    "go.mongodb.org/mongo-driver/bson/primitive"
)

// Example usage of the generic IndexableDB with different types

// Example 1: Using with Project
func ExampleProjectIndexableDB(repo repository.Repository, logger mlogger.Logger, organizationRef primitive.ObjectID) {
    // Define helper functions for Project
    createEmpty := func() *model.Project {
        return &model.Project{}
    }

    getIndexable := func(p *model.Project) *model.Indexable {
        return &p.Indexable
    }

    // Create generic IndexableDB for Project
    projectDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)

    // Use with organization filter
    orgFilter := repository.OrgFilter(organizationRef)
    projectDB.Reorder(context.Background(), primitive.NewObjectID(), 2, orgFilter)
}

// Example 3: Using with Task
func ExampleTaskIndexableDB(repo repository.Repository, logger mlogger.Logger, statusRef primitive.ObjectID) {
    // Define helper functions for Task
    createEmpty := func() *model.Task {
        return &model.Task{}
    }

    getIndexable := func(t *model.Task) *model.Indexable {
        return &t.Indexable
    }

    // Create generic IndexableDB for Task
    taskDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)

    // Use with status filter
    statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
    taskDB.Reorder(context.Background(), primitive.NewObjectID(), 3, statusFilter)
}

// Example 5: Using without any filter (global reordering)
func ExampleGlobalIndexableDB(repo repository.Repository, logger mlogger.Logger) {
    // Define helper functions for any Indexable type
    createEmpty := func() *model.Project {
        return &model.Project{}
    }

    getIndexable := func(p *model.Project) *model.Indexable {
        return &p.Indexable
    }

    // Create generic IndexableDB without filters
    globalDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)

    // Use without any filter - reorders all items globally
    globalDB.Reorder(context.Background(), primitive.NewObjectID(), 5, repository.Query())
}
122  api/pkg/db/internal/mongo/indexable/indexable.go  Normal file
@@ -0,0 +1,122 @@
package indexable

import (
    "context"

    "github.com/tech/sendico/pkg/db/repository"
    "github.com/tech/sendico/pkg/db/repository/builder"
    "github.com/tech/sendico/pkg/db/storable"
    "github.com/tech/sendico/pkg/mlogger"
    "github.com/tech/sendico/pkg/model"
    "go.mongodb.org/mongo-driver/bson/primitive"
    "go.uber.org/zap"
)

// IndexableDB implements db.IndexableDB interface with generic support
type IndexableDB[T storable.Storable] struct {
    repo         repository.Repository
    logger       mlogger.Logger
    createEmpty  func() T
    getIndexable func(T) *model.Indexable
}

// NewIndexableDB creates a new IndexableDB instance
func NewIndexableDB[T storable.Storable](
    repo repository.Repository,
    logger mlogger.Logger,
    createEmpty func() T,
    getIndexable func(T) *model.Indexable,
) *IndexableDB[T] {
    return &IndexableDB[T]{
        repo:         repo,
        logger:       logger,
        createEmpty:  createEmpty,
        getIndexable: getIndexable,
    }
}

// Reorder implements the db.IndexableDB interface with a single filter parameter
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
    // Get current object to find its index
    obj := db.createEmpty()
    err := db.repo.Get(ctx, objectRef, obj)
    if err != nil {
        db.logger.Error("Failed to get object for reordering",
            zap.Error(err),
            zap.String("object_ref", objectRef.Hex()),
            zap.Int("new_index", newIndex))
        return err
    }

    // Extract index from the object
    indexable := db.getIndexable(obj)
    currentIndex := indexable.Index
    if currentIndex == newIndex {
        db.logger.Debug("No reordering needed - same index",
            zap.String("object_ref", objectRef.Hex()),
            zap.Int("current_index", currentIndex),
            zap.Int("new_index", newIndex))
        return nil // No change needed
    }

    // Simple reordering logic
    if currentIndex < newIndex {
        // Moving down: items with indices in (currentIndex, newIndex] shift up, i.e. their index is decremented by 1
        patch := repository.Patch().Inc(repository.IndexField(), -1)
        reorderFilter := filter.
            And(repository.IndexOpFilter(currentIndex+1, builder.Gte)).
            And(repository.IndexOpFilter(newIndex, builder.Lte))

        updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
        if err != nil {
            db.logger.Error("Failed to shift objects during reordering (moving down)",
                zap.Error(err),
                zap.String("object_ref", objectRef.Hex()),
                zap.Int("current_index", currentIndex),
                zap.Int("new_index", newIndex),
                zap.Int("updated_count", updatedCount))
            return err
        }
        db.logger.Debug("Successfully shifted objects (moving down)",
            zap.String("object_ref", objectRef.Hex()),
            zap.Int("updated_count", updatedCount))
    } else {
        // Moving up: items with indices in [newIndex, currentIndex) shift down, i.e. their index is incremented by 1
        patch := repository.Patch().Inc(repository.IndexField(), 1)
        reorderFilter := filter.
            And(repository.IndexOpFilter(newIndex, builder.Gte)).
            And(repository.IndexOpFilter(currentIndex-1, builder.Lte))

        updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
        if err != nil {
            db.logger.Error("Failed to shift objects during reordering (moving up)",
                zap.Error(err),
                zap.String("object_ref", objectRef.Hex()),
                zap.Int("current_index", currentIndex),
                zap.Int("new_index", newIndex),
                zap.Int("updated_count", updatedCount))
            return err
        }
        db.logger.Debug("Successfully shifted objects (moving up)",
            zap.String("object_ref", objectRef.Hex()),
            zap.Int("updated_count", updatedCount))
    }

    // Update the target object to the new index
    patch := repository.Patch().Set(repository.IndexField(), newIndex)
    err = db.repo.Patch(ctx, objectRef, patch)
    if err != nil {
        db.logger.Error("Failed to update target object index",
            zap.Error(err),
            zap.String("object_ref", objectRef.Hex()),
            zap.Int("current_index", currentIndex),
            zap.Int("new_index", newIndex))
        return err
    }

    db.logger.Info("Successfully reordered object",
        zap.String("object_ref", objectRef.Hex()),
        zap.Int("old_index", currentIndex),
        zap.Int("new_index", newIndex))
    return nil
}
314  api/pkg/db/internal/mongo/indexable/indexable_test.go  Normal file
@@ -0,0 +1,314 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package indexable
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) (repository.Repository, func()) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
|
||||
db := client.Database("testdb")
|
||||
repo := repository.CreateMongoRepository(db, "projects")
|
||||
|
||||
cleanup := func() {
|
||||
disconnect(ctx, t, client)
|
||||
terminate(ctx, t, mongoContainer)
|
||||
}
|
||||
|
||||
return repo, cleanup
|
||||
}
|
||||
|
||||
func disconnect(ctx context.Context, t *testing.T, client *mongo.Client) {
|
||||
if err := client.Disconnect(ctx); err != nil {
|
||||
t.Logf("failed to disconnect from MongoDB: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func terminate(ctx context.Context, t *testing.T, container testcontainers.Container) {
|
||||
if err := container.Terminate(ctx); err != nil {
|
||||
t.Logf("failed to terminate MongoDB container: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexableDB_Reorder(t *testing.T) {
|
||||
repo, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
organizationRef := primitive.NewObjectID()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Create test projects with different indices
|
||||
projects := []*model.Project{
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project A"},
|
||||
Indexable: model.Indexable{Index: 0},
|
||||
Mnemonic: "A",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project B"},
|
||||
Indexable: model.Indexable{Index: 1},
|
||||
Mnemonic: "B",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project C"},
|
||||
Indexable: model.Indexable{Index: 2},
|
||||
Mnemonic: "C",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project D"},
|
||||
Indexable: model.Indexable{Index: 3},
|
||||
Mnemonic: "D",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Insert projects into database
|
||||
for _, project := range projects {
|
||||
project.ID = primitive.NewObjectID()
|
||||
err := repo.Insert(ctx, project, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create helper functions for Project type
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
t.Run("Reorder_NoChange", func(t *testing.T) {
|
||||
// Test reordering to the same position (should be no-op)
|
||||
err := indexableDB.Reorder(ctx, projects[1].ID, 1, repository.Query())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify indices haven't changed
|
||||
var result model.Project
|
||||
err = repo.Get(ctx, projects[0].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.Index)
|
||||
|
||||
err = repo.Get(ctx, projects[1].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, result.Index)
|
||||
})
|
||||
|
||||
t.Run("Reorder_MoveDown", func(t *testing.T) {
|
||||
// Move Project A (index 0) to index 2
|
||||
		err := indexableDB.Reorder(ctx, projects[0].ID, 2, repository.Query())
		require.NoError(t, err)

		// Verify the reordering:
		// Project A should now be at index 2
		// Project B should be at index 0
		// Project C should be at index 1
		// Project D should remain at index 3

		var result model.Project

		// Check Project A (moved to index 2)
		err = repo.Get(ctx, projects[0].ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 2, result.Index)

		// Check Project B (shifted to index 0)
		err = repo.Get(ctx, projects[1].ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 0, result.Index)

		// Check Project C (shifted to index 1)
		err = repo.Get(ctx, projects[2].ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 1, result.Index)

		// Check Project D (unchanged)
		err = repo.Get(ctx, projects[3].ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 3, result.Index)
	})

	t.Run("Reorder_MoveUp", func(t *testing.T) {
		// Reset indices for this test
		for i, project := range projects {
			project.Index = i
			err := repo.Update(ctx, project)
			require.NoError(t, err)
		}

		// Move Project C (index 2) to index 0
		err := indexableDB.Reorder(ctx, projects[2].ID, 0, repository.Query())
		require.NoError(t, err)

		// Verify the reordering:
		// Project C should now be at index 0
		// Project A should be at index 1
		// Project B should be at index 2
		// Project D should remain at index 3

		var result model.Project

		// Check Project C (moved to index 0)
		err = repo.Get(ctx, projects[2].ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 0, result.Index)

		// Check Project A (shifted to index 1)
		err = repo.Get(ctx, projects[0].ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 1, result.Index)

		// Check Project B (shifted to index 2)
		err = repo.Get(ctx, projects[1].ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 2, result.Index)

		// Check Project D (unchanged)
		err = repo.Get(ctx, projects[3].ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 3, result.Index)
	})

	t.Run("Reorder_WithFilter", func(t *testing.T) {
		// Reset indices for this test
		for i, project := range projects {
			project.Index = i
			err := repo.Update(ctx, project)
			require.NoError(t, err)
		}

		// Test reordering with organization filter
		orgFilter := repository.OrgFilter(organizationRef)
		err := indexableDB.Reorder(ctx, projects[0].ID, 2, orgFilter)
		require.NoError(t, err)

		// Verify the reordering worked with filter
		var result model.Project
		err = repo.Get(ctx, projects[0].ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 2, result.Index)
	})
}

func TestIndexableDB_EdgeCases(t *testing.T) {
	repo, cleanup := setupTestDB(t)
	defer cleanup()

	ctx := context.Background()
	organizationRef := primitive.NewObjectID()
	logger := zap.NewNop()

	// Create a single project for edge case testing
	project := &model.Project{
		ProjectBase: model.ProjectBase{
			PermissionBound: model.PermissionBound{
				OrganizationBoundBase: model.OrganizationBoundBase{
					OrganizationRef: organizationRef,
				},
			},
			Describable: model.Describable{Name: "Test Project"},
			Indexable:   model.Indexable{Index: 0},
			Mnemonic:    "TEST",
			State:       model.ProjectStateActive,
		},
	}
	project.ID = primitive.NewObjectID()
	err := repo.Insert(ctx, project, nil)
	require.NoError(t, err)

	// Create helper functions for Project type
	createEmpty := func() *model.Project {
		return &model.Project{}
	}

	getIndexable := func(p *model.Project) *model.Indexable {
		return &p.Indexable
	}

	indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)

	t.Run("Reorder_SingleItem", func(t *testing.T) {
		// Test reordering a single item (should work but have no effect)
		err := indexableDB.Reorder(ctx, project.ID, 0, repository.Query())
		require.NoError(t, err)

		var result model.Project
		err = repo.Get(ctx, project.ID, &result)
		require.NoError(t, err)
		assert.Equal(t, 0, result.Index)
	})

	t.Run("Reorder_InvalidObjectID", func(t *testing.T) {
		// Test reordering with an invalid object ID
		invalidID := primitive.NewObjectID()
		err := indexableDB.Reorder(ctx, invalidID, 1, repository.Query())
		require.Error(t, err) // Should fail because object doesn't exist
	})
}
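
The assertions above encode the expected shift-on-reorder behaviour: moving an item to a new index shifts every item between the old and new position by one, while items outside that range keep their indices. The standalone sketch below illustrates the same semantics on a plain slice; it is illustrative only and is not the repository's Reorder implementation.

package main

import "fmt"

// reorder moves the element at position 'from' to position 'to' and
// renumbers everything in between, mirroring the behaviour the tests expect.
func reorder(items []string, from, to int) []string {
	moved := items[from]
	items = append(items[:from], items[from+1:]...) // remove the element
	items = append(items[:to], append([]string{moved}, items[to:]...)...)
	return items
}

func main() {
	projects := []string{"A", "B", "C", "D"}
	// Move "A" (index 0) to index 2: B and C shift down, D stays last.
	fmt.Println(reorder(projects, 0, 2)) // [B C A D]
}
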
12
api/pkg/db/internal/mongo/invitationdb/accept.go
Normal file
12
api/pkg/db/internal/mongo/invitationdb/accept.go
Normal file
@@ -0,0 +1,12 @@
package invitationdb

import (
	"context"

	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func (db *InvitationDB) Accept(ctx context.Context, invitationRef primitive.ObjectID) error {
	return db.updateStatus(ctx, invitationRef, model.InvitationAccepted)
}
49
api/pkg/db/internal/mongo/invitationdb/archived.go
Normal file
49
api/pkg/db/internal/mongo/invitationdb/archived.go
Normal file
@@ -0,0 +1,49 @@
package invitationdb

import (
	"context"

	"github.com/tech/sendico/pkg/merrors"
	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mutil/mzap"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.uber.org/zap"
)

// SetArchived sets the archived status of an invitation.
// Invitation supports archiving because PermissionBound embeds ArchivableBase.
func (db *InvitationDB) SetArchived(ctx context.Context, accountRef, organizationRef, invitationRef primitive.ObjectID, archived, cascade bool) error {
	db.DBImp.Logger.Debug("Setting invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))
	res, err := db.Enforcer.Enforce(ctx, db.PermissionRef, accountRef, organizationRef, invitationRef, model.ActionUpdate)
	if err != nil {
		db.DBImp.Logger.Warn("Failed to enforce archival permission", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
		return err
	}
	if !res {
		db.DBImp.Logger.Debug("Permission denied for archival", mzap.ObjRef("invitation_ref", invitationRef))
		return merrors.AccessDenied(db.Collection, string(model.ActionUpdate), invitationRef)
	}

	// Get the invitation first
	var invitation model.Invitation
	if err := db.Get(ctx, accountRef, invitationRef, &invitation); err != nil {
		db.DBImp.Logger.Warn("Error retrieving invitation for archival", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
		return err
	}

	// Update the invitation's archived status
	invitation.SetArchived(archived)
	if err := db.Update(ctx, accountRef, &invitation); err != nil {
		db.DBImp.Logger.Warn("Error updating invitation archived status", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
		return err
	}

	// Note: Currently there are no cascade dependencies for invitations.
	// If cascade is enabled, logic for any future dependencies could be added here.
	if cascade {
		db.DBImp.Logger.Debug("Cascade archiving requested but no dependencies to archive for invitation", mzap.ObjRef("invitation_ref", invitationRef))
	}

	db.DBImp.Logger.Debug("Successfully set invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived))
	return nil
}
24
api/pkg/db/internal/mongo/invitationdb/cascade.go
Normal file
24
api/pkg/db/internal/mongo/invitationdb/cascade.go
Normal file
@@ -0,0 +1,24 @@
package invitationdb

import (
	"context"

	"github.com/tech/sendico/pkg/mutil/mzap"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.uber.org/zap"
)

// DeleteCascade deletes an invitation
// Invitations don't have cascade dependencies, so this is a simple deletion
func (db *InvitationDB) DeleteCascade(ctx context.Context, accountRef, invitationRef primitive.ObjectID) error {
	db.DBImp.Logger.Debug("Starting invitation cascade deletion", mzap.ObjRef("invitation_ref", invitationRef))

	// Delete the invitation itself (no dependencies to cascade delete)
	if err := db.Delete(ctx, accountRef, invitationRef); err != nil {
		db.DBImp.Logger.Error("Error deleting invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
		return err
	}

	db.DBImp.Logger.Debug("Successfully deleted invitation", mzap.ObjRef("invitation_ref", invitationRef))
	return nil
}
53
api/pkg/db/internal/mongo/invitationdb/db.go
Normal file
53
api/pkg/db/internal/mongo/invitationdb/db.go
Normal file
@@ -0,0 +1,53 @@
package invitationdb

import (
	"context"

	"github.com/tech/sendico/pkg/auth"
	"github.com/tech/sendico/pkg/db/policy"
	"github.com/tech/sendico/pkg/db/repository"
	ri "github.com/tech/sendico/pkg/db/repository/index"
	"github.com/tech/sendico/pkg/mlogger"
	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mservice"
	"go.mongodb.org/mongo-driver/mongo"
	"go.uber.org/zap"
)

type InvitationDB struct {
	auth.ProtectedDBImp[*model.Invitation]
}

func Create(
	ctx context.Context,
	logger mlogger.Logger,
	enforcer auth.Enforcer,
	pdb policy.DB,
	db *mongo.Database,
) (*InvitationDB, error) {
	p, err := auth.CreateDBImp[*model.Invitation](ctx, logger, pdb, enforcer, mservice.Invitations, db)
	if err != nil {
		return nil, err
	}

	// unique email per organization
	if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
		Keys:   []ri.Key{{Field: repository.OrgField().Build(), Sort: ri.Asc}, {Field: "description.email", Sort: ri.Asc}},
		Unique: true,
	}); err != nil {
		p.DBImp.Logger.Error("Failed to create unique email index", zap.Error(err))
		return nil, err
	}

	// ttl index
	ttl := int32(0) // a zero TTL means each document expires at the date stored in it when it is inserted
	if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
		Keys: []ri.Key{{Field: "expiresAt", Sort: ri.Asc}},
		TTL:  &ttl,
	}); err != nil {
		p.DBImp.Logger.Warn("Failed to create ttl index in the invitations", zap.Error(err))
		return nil, err
	}

	return &InvitationDB{ProtectedDBImp: *p}, nil
}
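
For context on the TTL index above: with expireAfterSeconds set to 0, MongoDB removes each document at the time stored in its expiresAt field, rather than a fixed interval after insertion. Below is a minimal driver-level sketch of the same pattern; it is not part of the commit, and the collection and field values are illustrative.

package main

import (
	"context"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

func setupInvitationTTL(ctx context.Context, db *mongo.Database) error {
	coll := db.Collection("invitations") // hypothetical collection name

	// expireAfterSeconds = 0: the document expires exactly at the value of "expiresAt".
	_, err := coll.Indexes().CreateOne(ctx, mongo.IndexModel{
		Keys:    bson.D{{Key: "expiresAt", Value: 1}},
		Options: options.Index().SetExpireAfterSeconds(0),
	})
	if err != nil {
		return err
	}

	// Each invitation carries its own expiry; here: 7 days from now.
	_, err = coll.InsertOne(ctx, bson.M{
		"description": bson.M{"email": "invitee@example.com"},
		"status":      "created",
		"expiresAt":   time.Now().Add(7 * 24 * time.Hour),
	})
	return err
}
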
12
api/pkg/db/internal/mongo/invitationdb/decline.go
Normal file
12
api/pkg/db/internal/mongo/invitationdb/decline.go
Normal file
@@ -0,0 +1,12 @@
package invitationdb

import (
	"context"

	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func (db *InvitationDB) Decline(ctx context.Context, invitationRef primitive.ObjectID) error {
	return db.updateStatus(ctx, invitationRef, model.InvitationDeclined)
}
121
api/pkg/db/internal/mongo/invitationdb/getpublic.go
Normal file
121
api/pkg/db/internal/mongo/invitationdb/getpublic.go
Normal file
@@ -0,0 +1,121 @@
package invitationdb

import (
	"context"
	"fmt"

	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/merrors"
	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mservice"
	"github.com/tech/sendico/pkg/mutil/mzap"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
	"go.uber.org/zap"
)

func (db *InvitationDB) GetPublic(ctx context.Context, invitationRef primitive.ObjectID) (*model.PublicInvitation, error) {
	roleField := repository.Field("role")
	orgField := repository.Field("organization")
	accField := repository.Field("account")
	empField := repository.Field("employee")
	regField := repository.Field("registrationAcc")
	descEmailField := repository.Field("description").Dot("email")
	pipeline := repository.Pipeline().
		// 0) Filter to exactly the invitation(s) you want
		Match(repository.IDFilter(invitationRef).And(repository.Filter("status", model.InvitationCreated))).
		// 1) Lookup the role document
		Lookup(
			mservice.Roles,
			repository.Field("roleRef"),
			repository.IDField(),
			roleField,
		).
		Unwind(repository.Ref(roleField)).
		// 2) Lookup the organization document
		Lookup(
			mservice.Organizations,
			repository.Field("organizationRef"),
			repository.IDField(),
			orgField,
		).
		Unwind(repository.Ref(orgField)).
		// 3) Lookup the account document
		Lookup(
			mservice.Accounts,
			repository.Field("inviterRef"),
			repository.IDField(),
			accField,
		).
		Unwind(repository.Ref(accField)).
		/* 4) do we already have an account whose login == invitation.description.email ? */
		Lookup(
			mservice.Accounts,
			descEmailField,            // local field (invitation.description.email)
			repository.Field("login"), // foreign field (account.login)
			regField,                  // array: 0-length or ≥1
		).
		// 5) Projection
		Project(
			repository.SimpleAlias(
				empField.Dot("description"),
				repository.Ref(accField),
			),
			repository.SimpleAlias(
				empField.Dot("avatarUrl"),
				repository.Ref(accField.Dot("avatarUrl")),
			),
			repository.SimpleAlias(
				orgField.Dot("description"),
				repository.Ref(orgField),
			),
			repository.SimpleAlias(
				orgField.Dot("logoUrl"),
				repository.Ref(orgField.Dot("logoUrl")),
			),
			repository.SimpleAlias(
				roleField,
				repository.Ref(roleField),
			),
			repository.SimpleAlias(
				repository.Field("invitation"),                  // ← left-hand side
				repository.Ref(repository.Field("description")), // ← right-hand side (“$description”)
			),
			repository.SimpleAlias(
				repository.Field("storable"), // ← left-hand side
				repository.RootRef(),         // ← right-hand side (“$$ROOT”)
			),
			repository.ProjectionExpr(
				repository.Field("registrationRequired"),
				repository.Eq(
					repository.Size(repository.Value(repository.Ref(regField).Build())),
					repository.Literal(0),
				),
			),
		)

	var res model.PublicInvitation
	haveResult := false
	decoder := func(cur *mongo.Cursor) error {
		if haveResult {
			// should never get here
			db.DBImp.Logger.Warn("Unexpected extra invitation", mzap.ObjRef("invitation_ref", invitationRef))
			return merrors.Internal("Unexpected extra invitation found by reference")
		}
		if e := cur.Decode(&res); e != nil {
			db.DBImp.Logger.Warn("Failed to decode entity", zap.Error(e), zap.Any("data", cur.Current.String()))
			return e
		}
		haveResult = true
		return nil
	}
	if err := db.DBImp.Repository.Aggregate(ctx, pipeline, decoder); err != nil {
		db.DBImp.Logger.Warn("Failed to execute aggregation pipeline", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
		return nil, err
	}
	if !haveResult {
		db.DBImp.Logger.Warn("No results fetched", mzap.ObjRef("invitation_ref", invitationRef))
		return nil, merrors.NoData(fmt.Sprintf("Invitation %s not found", invitationRef.Hex()))
	}
	return &res, nil
}
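
For readers less familiar with the fluent pipeline builder, the last two steps above (the account-by-email lookup and the registrationRequired projection) roughly correspond to the raw aggregation stages below. This is a hedged approximation assuming the builder emits standard $lookup/$project operators and an "accounts" collection name; the exact generated document may differ.

package example

import (
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
)

// publicInvitationTail sketches stages 4 and 5 of the pipeline in raw form.
func publicInvitationTail() mongo.Pipeline {
	return mongo.Pipeline{
		// 4) find accounts whose login matches the invited email
		{{Key: "$lookup", Value: bson.D{
			{Key: "from", Value: "accounts"},
			{Key: "localField", Value: "description.email"},
			{Key: "foreignField", Value: "login"},
			{Key: "as", Value: "registrationAcc"},
		}}},
		// 5) registration is required when the lookup found no matching account
		{{Key: "$project", Value: bson.D{
			{Key: "registrationRequired", Value: bson.D{
				{Key: "$eq", Value: bson.A{
					bson.D{{Key: "$size", Value: "$registrationAcc"}},
					0,
				}},
			}},
		}}},
	}
}
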
28
api/pkg/db/internal/mongo/invitationdb/list.go
Normal file
28
api/pkg/db/internal/mongo/invitationdb/list.go
Normal file
@@ -0,0 +1,28 @@
package invitationdb

import (
	"context"
	"errors"

	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/merrors"
	"github.com/tech/sendico/pkg/model"
	mauth "github.com/tech/sendico/pkg/mutil/db/auth"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func (db *InvitationDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Invitation, error) {
	res, err := mauth.GetProtectedObjects[model.Invitation](
		ctx,
		db.DBImp.Logger,
		accountRef, organizationRef, model.ActionRead,
		repository.OrgFilter(organizationRef),
		cursor,
		db.Enforcer,
		db.DBImp.Repository,
	)
	if errors.Is(err, merrors.ErrNoData) {
		return []model.Invitation{}, nil
	}
	return res, err
}
26
api/pkg/db/internal/mongo/invitationdb/updatestatus.go
Normal file
26
api/pkg/db/internal/mongo/invitationdb/updatestatus.go
Normal file
@@ -0,0 +1,26 @@
package invitationdb

import (
	"context"

	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mutil/mzap"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.uber.org/zap"
)

func (db *InvitationDB) updateStatus(ctx context.Context, invitationRef primitive.ObjectID, newStatus model.InvitationStatus) error {
	var inv model.Invitation
	if err := db.DBImp.FindOne(ctx, repository.IDFilter(invitationRef), &inv); err != nil {
		db.DBImp.Logger.Warn("Failed to fetch invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
		return err
	}
	inv.Status = newStatus
	if err := db.DBImp.Update(ctx, &inv); err != nil {
		db.DBImp.Logger.Warn("Failed to update invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
		return err
	}
	return nil
}
22
api/pkg/db/internal/mongo/mongo.go
Normal file
22
api/pkg/db/internal/mongo/mongo.go
Normal file
@@ -0,0 +1,22 @@
package mongo

import (
	"net/url"
)

func buildURI(s *DBSettings) string {
	u := &url.URL{
		Scheme: "mongodb",
		Host:   s.Host,
		Path:   "/" + url.PathEscape(s.Database), // /my%20db
	}

	q := url.Values{}
	if s.ReplicaSet != "" {
		q.Set("replicaSet", s.ReplicaSet)
	}

	u.RawQuery = q.Encode()

	return u.String()
}
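
A quick illustration of the URI shape buildURI produces (a standalone sketch; the host, database, and replica-set values are hypothetical, not taken from the commit):

package main

import (
	"fmt"
	"net/url"
)

func main() {
	// Mirrors buildURI for sample settings: Host, Database, ReplicaSet.
	u := &url.URL{
		Scheme: "mongodb",
		Host:   "mongo-0.mongo:27017",
		Path:   "/" + url.PathEscape("sendico"),
	}
	q := url.Values{}
	q.Set("replicaSet", "rs0")
	u.RawQuery = q.Encode()

	fmt.Println(u.String()) // mongodb://mongo-0.mongo:27017/sendico?replicaSet=rs0
}
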
32
api/pkg/db/internal/mongo/organizationdb/archived.go
Normal file
32
api/pkg/db/internal/mongo/organizationdb/archived.go
Normal file
@@ -0,0 +1,32 @@
package organizationdb

import (
	"context"

	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mutil/mzap"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.uber.org/zap"
)

// SetArchived sets the archived status of an organization and optionally cascades to projects, tasks, comments, and reactions
func (db *OrganizationDB) SetArchived(ctx context.Context, accountRef, organizationRef primitive.ObjectID, archived, cascade bool) error {
	db.DBImp.Logger.Debug("Setting organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))

	// Get the organization first
	var organization model.Organization
	if err := db.Get(ctx, accountRef, organizationRef, &organization); err != nil {
		db.DBImp.Logger.Warn("Error retrieving organization for archival", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
		return err
	}

	// Update the organization's archived status
	organization.SetArchived(archived)
	if err := db.Update(ctx, accountRef, &organization); err != nil {
		db.DBImp.Logger.Warn("Error updating organization archived status", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
		return err
	}

	db.DBImp.Logger.Debug("Successfully set organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived))
	return nil
}
23
api/pkg/db/internal/mongo/organizationdb/cascade.go
Normal file
23
api/pkg/db/internal/mongo/organizationdb/cascade.go
Normal file
@@ -0,0 +1,23 @@
package organizationdb

import (
	"context"

	"github.com/tech/sendico/pkg/mutil/mzap"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.uber.org/zap"
)

// DeleteCascade deletes an organization and all its related data (projects, tasks, comments, reactions, statuses)
func (db *OrganizationDB) DeleteCascade(ctx context.Context, organizationRef primitive.ObjectID) error {
	db.DBImp.Logger.Debug("Starting organization deletion with projects", mzap.ObjRef("organization_ref", organizationRef))

	// Delete the organization itself
	if err := db.Unprotected().Delete(ctx, organizationRef); err != nil {
		db.DBImp.Logger.Warn("Error deleting organization", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
		return err
	}

	db.DBImp.Logger.Debug("Successfully deleted organization with projects", mzap.ObjRef("organization_ref", organizationRef))
	return nil
}
19
api/pkg/db/internal/mongo/organizationdb/create.go
Normal file
19
api/pkg/db/internal/mongo/organizationdb/create.go
Normal file
@@ -0,0 +1,19 @@
package organizationdb

import (
	"context"

	"github.com/tech/sendico/pkg/merrors"
	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func (db *OrganizationDB) Create(ctx context.Context, _, _ primitive.ObjectID, org *model.Organization) error {
	if org == nil {
		return merrors.InvalidArgument("Organization object is nil")
	}
	org.SetID(primitive.NewObjectID())
	// The organization reference must be set to the organization's own ID
	org.SetOrganizationRef(*org.GetID())
	return db.DBImp.Create(ctx, org)
}
34
api/pkg/db/internal/mongo/organizationdb/db.go
Normal file
34
api/pkg/db/internal/mongo/organizationdb/db.go
Normal file
@@ -0,0 +1,34 @@
package organizationdb

import (
	"context"

	"github.com/tech/sendico/pkg/auth"
	"github.com/tech/sendico/pkg/db/policy"
	"github.com/tech/sendico/pkg/mlogger"
	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mservice"
	"go.mongodb.org/mongo-driver/mongo"
)

type OrganizationDB struct {
	auth.ProtectedDBImp[*model.Organization]
}

func Create(ctx context.Context,
	logger mlogger.Logger,
	enforcer auth.Enforcer,
	pdb policy.DB,
	db *mongo.Database,
) (*OrganizationDB, error) {
	p, err := auth.CreateDBImp[*model.Organization](ctx, logger, pdb, enforcer, mservice.Organizations, db)
	if err != nil {
		return nil, err
	}

	res := &OrganizationDB{
		ProtectedDBImp: *p,
	}
	p.DBImp.SetDeleter(res.DeleteCascade)
	return res, nil
}
12
api/pkg/db/internal/mongo/organizationdb/get.go
Normal file
12
api/pkg/db/internal/mongo/organizationdb/get.go
Normal file
@@ -0,0 +1,12 @@
package organizationdb

import (
	"context"

	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func (db *OrganizationDB) GetByRef(ctx context.Context, organizationRef primitive.ObjectID, org *model.Organization) error {
	return db.Unprotected().Get(ctx, organizationRef, org)
}
16
api/pkg/db/internal/mongo/organizationdb/list.go
Normal file
16
api/pkg/db/internal/mongo/organizationdb/list.go
Normal file
@@ -0,0 +1,16 @@
package organizationdb

import (
	"context"

	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/db/repository/builder"
	"github.com/tech/sendico/pkg/model"
	mutil "github.com/tech/sendico/pkg/mutil/db"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func (db *OrganizationDB) List(ctx context.Context, accountRef primitive.ObjectID, cursor *model.ViewCursor) ([]model.Organization, error) {
	// Organizations whose members field contains the given account
	filter := repository.Query().Comparison(repository.Field("members"), builder.Eq, accountRef)
	return mutil.GetObjects[model.Organization](ctx, db.DBImp.Logger, filter, cursor, db.DBImp.Repository)
}
14
api/pkg/db/internal/mongo/organizationdb/owned.go
Normal file
14
api/pkg/db/internal/mongo/organizationdb/owned.go
Normal file
@@ -0,0 +1,14 @@
package organizationdb

import (
	"context"

	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/model"
	mutil "github.com/tech/sendico/pkg/mutil/db"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func (db *OrganizationDB) ListOwned(ctx context.Context, accountRef primitive.ObjectID) ([]model.Organization, error) {
	return mutil.GetObjects[model.Organization](ctx, db.DBImp.Logger, repository.Filter("ownerRef", accountRef), nil, db.DBImp.Repository)
}
562
api/pkg/db/internal/mongo/organizationdb/setarchived_test.go
Normal file
562
api/pkg/db/internal/mongo/organizationdb/setarchived_test.go
Normal file
@@ -0,0 +1,562 @@
//go:build integration
// +build integration

package organizationdb

import (
	"context"
	"errors"
	"testing"

	"github.com/tech/sendico/pkg/db/internal/mongo/commentdb"
	"github.com/tech/sendico/pkg/db/internal/mongo/projectdb"
	"github.com/tech/sendico/pkg/db/internal/mongo/reactiondb"
	"github.com/tech/sendico/pkg/db/internal/mongo/statusdb"
	"github.com/tech/sendico/pkg/db/internal/mongo/taskdb"
	"github.com/tech/sendico/pkg/db/repository/builder"
	"github.com/tech/sendico/pkg/db/template"
	"github.com/tech/sendico/pkg/merrors"
	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mservice"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/testcontainers/testcontainers-go/modules/mongodb"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.uber.org/zap"
)

func setupSetArchivedTestDB(t *testing.T) (*OrganizationDB, *projectDBAdapter, *taskdb.TaskDB, *commentdb.CommentDB, *reactiondb.ReactionDB, func()) {
	ctx := context.Background()

	// Start MongoDB container
	mongodbContainer, err := mongodb.Run(ctx, "mongo:latest")
	require.NoError(t, err)

	// Get connection string
	endpoint, err := mongodbContainer.Endpoint(ctx, "")
	require.NoError(t, err)

	// Connect to MongoDB
	client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://"+endpoint))
	require.NoError(t, err)

	db := client.Database("test_organization_setarchived")
	logger := zap.NewNop()

	// Create mock enforcer and policy DB
	mockEnforcer := &mockSetArchivedEnforcer{}
	mockPolicyDB := &mockSetArchivedPolicyDB{}
	mockPGroupDB := &mockSetArchivedPGroupDB{}

	// Create databases
	// We need to create a projectDB first, but we'll create a temporary one for organizationDB creation
	// Create temporary taskDB and statusDB for the temporary projectDB
	// Create temporary reactionDB and commentDB for the temporary taskDB
	tempReactionDB, err := reactiondb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
	require.NoError(t, err)

	tempCommentDB, err := commentdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, tempReactionDB)
	require.NoError(t, err)

	tempTaskDB, err := taskdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, tempCommentDB, tempReactionDB)
	require.NoError(t, err)

	tempStatusDB, err := statusdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
	require.NoError(t, err)

	tempProjectDB, err := projectdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, tempTaskDB, tempStatusDB, db)
	require.NoError(t, err)

	// Create adapter for organizationDB creation
	tempProjectDBAdapter := &projectDBAdapter{
		ProjectDB:  tempProjectDB,
		taskDB:     tempTaskDB,
		commentDB:  tempCommentDB,
		reactionDB: tempReactionDB,
		statusDB:   tempStatusDB,
	}

	organizationDB, err := Create(ctx, logger, mockEnforcer, mockPolicyDB, tempProjectDBAdapter, mockPGroupDB, db)
	require.NoError(t, err)

	var projectDB *projectdb.ProjectDB
	var taskDB *taskdb.TaskDB
	var commentDB *commentdb.CommentDB
	var reactionDB *reactiondb.ReactionDB
	var statusDB *statusdb.StatusDB

	// Create databases in dependency order
	reactionDB, err = reactiondb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
	require.NoError(t, err)

	commentDB, err = commentdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, reactionDB)
	require.NoError(t, err)

	taskDB, err = taskdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, commentDB, reactionDB)
	require.NoError(t, err)

	statusDB, err = statusdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
	require.NoError(t, err)

	projectDB, err = projectdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, taskDB, statusDB, db)
	require.NoError(t, err)

	// Create adapter for the actual projectDB
	projectDBAdapter := &projectDBAdapter{
		ProjectDB:  projectDB,
		taskDB:     taskDB,
		commentDB:  commentDB,
		reactionDB: reactionDB,
		statusDB:   statusDB,
	}

	cleanup := func() {
		client.Disconnect(context.Background())
		mongodbContainer.Terminate(ctx)
	}

	return organizationDB, projectDBAdapter, taskDB, commentDB, reactionDB, cleanup
}

// projectDBAdapter adapts projectdb.ProjectDB to project.DB interface for testing
type projectDBAdapter struct {
	*projectdb.ProjectDB
	taskDB     *taskdb.TaskDB
	commentDB  *commentdb.CommentDB
	reactionDB *reactiondb.ReactionDB
	statusDB   *statusdb.StatusDB
}

// DeleteCascade implements the project.DB interface
func (a *projectDBAdapter) DeleteCascade(ctx context.Context, projectRef primitive.ObjectID) error {
	// Call the concrete implementation
	return a.ProjectDB.DeleteCascade(ctx, projectRef)
}

// SetArchived implements the project.DB interface
func (a *projectDBAdapter) SetArchived(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, archived, cascade bool) error {
	// Use the stored dependencies for the concrete implementation
	return a.ProjectDB.SetArchived(ctx, accountRef, organizationRef, projectRef, archived, cascade)
}

// List implements the project.DB interface
func (a *projectDBAdapter) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Project, error) {
	return a.ProjectDB.List(ctx, accountRef, organizationRef, primitive.NilObjectID, cursor)
}

// Previews implements the project.DB interface
func (a *projectDBAdapter) Previews(ctx context.Context, accountRef, organizationRef primitive.ObjectID, projectRefs []primitive.ObjectID, cursor *model.ViewCursor, assigneeRefs, reporterRefs []primitive.ObjectID) ([]model.ProjectPreview, error) {
	return a.ProjectDB.Previews(ctx, accountRef, organizationRef, projectRefs, cursor, assigneeRefs, reporterRefs)
}

// DeleteProject implements the project.DB interface
func (a *projectDBAdapter) DeleteProject(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, migrateToRef *primitive.ObjectID) error {
	// Call the concrete implementation with the organizationRef
	return a.ProjectDB.DeleteProject(ctx, accountRef, organizationRef, projectRef, migrateToRef)
}

// RemoveTagFromProjects implements the project.DB interface
func (a *projectDBAdapter) RemoveTagFromProjects(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error {
	// Call the concrete implementation
	return a.ProjectDB.RemoveTagFromProjects(ctx, accountRef, organizationRef, tagRef)
}

// Mock implementations for SetArchived testing
type mockSetArchivedEnforcer struct{}

func (m *mockSetArchivedEnforcer) Enforce(ctx context.Context, permissionRef, accountRef, orgRef, objectRef primitive.ObjectID, action model.Action) (bool, error) {
	return true, nil
}

func (m *mockSetArchivedEnforcer) EnforceBatch(ctx context.Context, objectRefs []model.PermissionBoundStorable, accountRef primitive.ObjectID, action model.Action) (map[primitive.ObjectID]bool, error) {
	// Allow all objects for testing
	result := make(map[primitive.ObjectID]bool)
	for _, obj := range objectRefs {
		result[*obj.GetID()] = true
	}
	return result, nil
}

func (m *mockSetArchivedEnforcer) GetRoles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, error) {
	return nil, nil
}

func (m *mockSetArchivedEnforcer) GetPermissions(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, []model.Permission, error) {
	return nil, nil, nil
}

type mockSetArchivedPolicyDB struct{}

func (m *mockSetArchivedPolicyDB) Create(ctx context.Context, policy *model.PolicyDescription) error {
	return nil
}

func (m *mockSetArchivedPolicyDB) Get(ctx context.Context, policyRef primitive.ObjectID, result *model.PolicyDescription) error {
	return merrors.ErrNoData
}

func (m *mockSetArchivedPolicyDB) InsertMany(ctx context.Context, objects []*model.PolicyDescription) error { return nil }

func (m *mockSetArchivedPolicyDB) Update(ctx context.Context, policy *model.PolicyDescription) error {
	return nil
}

func (m *mockSetArchivedPolicyDB) Patch(ctx context.Context, objectRef primitive.ObjectID, patch builder.Patch) error {
	return nil
}

func (m *mockSetArchivedPolicyDB) Delete(ctx context.Context, policyRef primitive.ObjectID) error {
	return nil
}

func (m *mockSetArchivedPolicyDB) DeleteMany(ctx context.Context, filter builder.Query) error {
	return nil
}

func (m *mockSetArchivedPolicyDB) FindOne(ctx context.Context, filter builder.Query, result *model.PolicyDescription) error {
	return merrors.ErrNoData
}

func (m *mockSetArchivedPolicyDB) ListIDs(ctx context.Context, query builder.Query) ([]primitive.ObjectID, error) {
	return nil, nil
}

func (m *mockSetArchivedPolicyDB) ListPermissionBound(ctx context.Context, query builder.Query) ([]model.PermissionBoundStorable, error) {
	return nil, nil
}
func (m *mockSetArchivedPolicyDB) Collection() string { return "" }
func (m *mockSetArchivedPolicyDB) All(ctx context.Context, organizationRef primitive.ObjectID) ([]model.PolicyDescription, error) {
	return nil, nil
}

func (m *mockSetArchivedPolicyDB) Policies(ctx context.Context, refs []primitive.ObjectID) ([]model.PolicyDescription, error) {
	return nil, nil
}

func (m *mockSetArchivedPolicyDB) GetBuiltInPolicy(ctx context.Context, resourceType mservice.Type, policy *model.PolicyDescription) error {
	return nil
}

func (m *mockSetArchivedPolicyDB) DeleteCascade(ctx context.Context, policyRef primitive.ObjectID) error {
	return nil
}

type mockSetArchivedPGroupDB struct{}

func (m *mockSetArchivedPGroupDB) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, pgroup *model.PriorityGroup) error {
	return nil
}
func (m *mockSetArchivedPGroupDB) InsertMany(ctx context.Context, accountRef, organizationRef primitive.ObjectID, objects []*model.PriorityGroup) error { return nil }

func (m *mockSetArchivedPGroupDB) Get(ctx context.Context, accountRef, pgroupRef primitive.ObjectID, result *model.PriorityGroup) error {
	return merrors.ErrNoData
}

func (m *mockSetArchivedPGroupDB) Update(ctx context.Context, accountRef primitive.ObjectID, pgroup *model.PriorityGroup) error {
	return nil
}

func (m *mockSetArchivedPGroupDB) Delete(ctx context.Context, accountRef, pgroupRef primitive.ObjectID) error {
	return nil
}

func (m *mockSetArchivedPGroupDB) DeleteCascadeAuth(ctx context.Context, accountRef, pgroupRef primitive.ObjectID) error {
	return nil
}

func (m *mockSetArchivedPGroupDB) Patch(ctx context.Context, accountRef, pgroupRef primitive.ObjectID, patch builder.Patch) error {
	return nil
}

func (m *mockSetArchivedPGroupDB) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) {
	return 0, nil
}

func (m *mockSetArchivedPGroupDB) Unprotected() template.DB[*model.PriorityGroup] {
	return nil
}

func (m *mockSetArchivedPGroupDB) ListIDs(ctx context.Context, action model.Action, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error) {
	return nil, nil
}

func (m *mockSetArchivedPGroupDB) All(ctx context.Context, organizationRef primitive.ObjectID, limit, offset *int64) ([]model.PriorityGroup, error) {
	return nil, nil
}

func (m *mockSetArchivedPGroupDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.PriorityGroup, error) {
	return nil, nil
}

func (m *mockSetArchivedPGroupDB) DeleteCascade(ctx context.Context, statusRef primitive.ObjectID) error {
	return nil
}

func (m *mockSetArchivedPGroupDB) SetArchived(ctx context.Context, accountRef, organizationRef, statusRef primitive.ObjectID, archived, cascade bool) error {
	return nil
}

func (m *mockSetArchivedPGroupDB) Reorder(ctx context.Context, accountRef, priorityGroupRef primitive.ObjectID, oldIndex, newIndex int) error {
	return nil
}

// Mock project DB for statusdb creation
type mockSetArchivedProjectDB struct{}

func (m *mockSetArchivedProjectDB) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, project *model.Project) error {
	return nil
}
func (m *mockSetArchivedProjectDB) Get(ctx context.Context, accountRef, projectRef primitive.ObjectID, result *model.Project) error {
	return merrors.ErrNoData
}
func (m *mockSetArchivedProjectDB) Update(ctx context.Context, accountRef primitive.ObjectID, project *model.Project) error {
	return nil
}
func (m *mockSetArchivedProjectDB) Delete(ctx context.Context, accountRef, projectRef primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) DeleteCascadeAuth(ctx context.Context, accountRef, projectRef primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error {
	return nil
}
func (m *mockSetArchivedProjectDB) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) {
	return 0, nil
}
func (m *mockSetArchivedProjectDB) Unprotected() template.DB[*model.Project] { return nil }
func (m *mockSetArchivedProjectDB) ListIDs(ctx context.Context, action model.Action, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error) {
	return nil, nil
}
func (m *mockSetArchivedProjectDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Project, error) {
	return nil, nil
}
func (m *mockSetArchivedProjectDB) Previews(ctx context.Context, accountRef, organizationRef primitive.ObjectID, projectRefs []primitive.ObjectID, cursor *model.ViewCursor, assigneeRefs, reporterRefs []primitive.ObjectID) ([]model.ProjectPreview, error) {
	return nil, nil
}
func (m *mockSetArchivedProjectDB) DeleteProject(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, migrateToRef *primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) DeleteCascade(ctx context.Context, projectRef primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) SetArchived(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, archived, cascade bool) error {
	return nil
}
func (m *mockSetArchivedProjectDB) All(ctx context.Context, organizationRef primitive.ObjectID, limit, offset *int64) ([]model.Project, error) {
	return nil, nil
}
func (m *mockSetArchivedProjectDB) Reorder(ctx context.Context, accountRef, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
	return nil
}
func (m *mockSetArchivedProjectDB) AddTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) RemoveTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) RemoveTags(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) AddTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) SetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) RemoveAllTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
	return nil
}
func (m *mockSetArchivedProjectDB) GetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) ([]primitive.ObjectID, error) {
	return nil, nil
}
func (m *mockSetArchivedProjectDB) HasTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) (bool, error) {
	return false, nil
}
func (m *mockSetArchivedProjectDB) FindByTag(ctx context.Context, accountRef, tagRef primitive.ObjectID) ([]*model.Project, error) {
	return nil, nil
}
func (m *mockSetArchivedProjectDB) FindByTags(ctx context.Context, accountRef primitive.ObjectID, tagRefs []primitive.ObjectID) ([]*model.Project, error) {
	return nil, nil
}

func TestOrganizationDB_SetArchived(t *testing.T) {
	organizationDB, projectDBAdapter, taskDB, commentDB, reactionDB, cleanup := setupSetArchivedTestDB(t)
	defer cleanup()

	ctx := context.Background()
	accountRef := primitive.NewObjectID()

	t.Run("SetArchived_OrganizationWithProjectsTasksCommentsAndReactions_Cascade", func(t *testing.T) {
		// Create an organization using unprotected DB
		organization := &model.Organization{
			OrganizationBase: model.OrganizationBase{
				Describable: model.Describable{Name: "Test Organization for Archive"},
				TimeZone:    "UTC",
			},
		}
		organization.ID = primitive.NewObjectID()

		err := organizationDB.Create(ctx, accountRef, organization.ID, organization)
		require.NoError(t, err)

		// Create a project for the organization using unprotected DB
		project := &model.Project{
			ProjectBase: model.ProjectBase{
				PermissionBound: model.PermissionBound{
					OrganizationBoundBase: model.OrganizationBoundBase{
						OrganizationRef: organization.ID,
					},
				},
				Describable: model.Describable{Name: "Test Project"},
				Indexable:   model.Indexable{Index: 0},
				Mnemonic:    "TEST",
				State:       model.ProjectStateActive,
			},
		}
		project.ID = primitive.NewObjectID()

		err = projectDBAdapter.Unprotected().Create(ctx, project)
		require.NoError(t, err)

		// Create a task for the project using unprotected DB
		task := &model.Task{
			PermissionBound: model.PermissionBound{
				OrganizationBoundBase: model.OrganizationBoundBase{
					OrganizationRef: organization.ID,
				},
			},
			Describable: model.Describable{Name: "Test Task for Archive"},
			ProjectRef:  project.ID,
		}
		task.ID = primitive.NewObjectID()

		err = taskDB.Unprotected().Create(ctx, task)
		require.NoError(t, err)

		// Create comments for the task using unprotected DB
		comment := &model.Comment{
			CommentBase: model.CommentBase{
				PermissionBound: model.PermissionBound{
					OrganizationBoundBase: model.OrganizationBoundBase{
						OrganizationRef: organization.ID,
					},
				},
				AuthorRef: accountRef,
				TaskRef:   task.ID,
				Content:   "Test Comment for Archive",
			},
		}
		comment.ID = primitive.NewObjectID()

		err = commentDB.Unprotected().Create(ctx, comment)
		require.NoError(t, err)

		// Create reaction for the comment using unprotected DB
		reaction := &model.Reaction{
			PermissionBound: model.PermissionBound{
				OrganizationBoundBase: model.OrganizationBoundBase{
					OrganizationRef: organization.ID,
				},
			},
			Type:       "like",
			AuthorRef:  accountRef,
			CommentRef: comment.ID,
		}
		reaction.ID = primitive.NewObjectID()

		err = reactionDB.Unprotected().Create(ctx, reaction)
		require.NoError(t, err)

		// Verify all entities are not archived initially
		var retrievedOrganization model.Organization
		err = organizationDB.Get(ctx, accountRef, organization.ID, &retrievedOrganization)
		require.NoError(t, err)
		assert.False(t, retrievedOrganization.IsArchived())

		var retrievedProject model.Project
		err = projectDBAdapter.Unprotected().Get(ctx, project.ID, &retrievedProject)
		require.NoError(t, err)
		assert.False(t, retrievedProject.IsArchived())

		var retrievedTask model.Task
		err = taskDB.Unprotected().Get(ctx, task.ID, &retrievedTask)
		require.NoError(t, err)
		assert.False(t, retrievedTask.IsArchived())

		var retrievedComment model.Comment
		err = commentDB.Unprotected().Get(ctx, comment.ID, &retrievedComment)
		require.NoError(t, err)
		assert.False(t, retrievedComment.IsArchived())

		// Archive organization with cascade
		err = organizationDB.SetArchived(ctx, accountRef, organization.ID, true, true)
		require.NoError(t, err)

		// Verify all entities are archived due to cascade
		err = organizationDB.Get(ctx, accountRef, organization.ID, &retrievedOrganization)
		require.NoError(t, err)
		assert.True(t, retrievedOrganization.IsArchived())

		err = projectDBAdapter.Unprotected().Get(ctx, project.ID, &retrievedProject)
		require.NoError(t, err)
		assert.True(t, retrievedProject.IsArchived())

		err = taskDB.Unprotected().Get(ctx, task.ID, &retrievedTask)
		require.NoError(t, err)
		assert.True(t, retrievedTask.IsArchived())

		err = commentDB.Unprotected().Get(ctx, comment.ID, &retrievedComment)
		require.NoError(t, err)
		assert.True(t, retrievedComment.IsArchived())

		// Verify reaction still exists (reactions don't support archiving)
		var retrievedReaction model.Reaction
		err = reactionDB.Unprotected().Get(ctx, reaction.ID, &retrievedReaction)
		require.NoError(t, err)

		// Unarchive organization with cascade
		err = organizationDB.SetArchived(ctx, accountRef, organization.ID, false, true)
		require.NoError(t, err)

		// Verify all entities are unarchived
		err = organizationDB.Get(ctx, accountRef, organization.ID, &retrievedOrganization)
		require.NoError(t, err)
		assert.False(t, retrievedOrganization.IsArchived())

		err = projectDBAdapter.Unprotected().Get(ctx, project.ID, &retrievedProject)
		require.NoError(t, err)
		assert.False(t, retrievedProject.IsArchived())

		err = taskDB.Unprotected().Get(ctx, task.ID, &retrievedTask)
		require.NoError(t, err)
		assert.False(t, retrievedTask.IsArchived())

		err = commentDB.Unprotected().Get(ctx, comment.ID, &retrievedComment)
		require.NoError(t, err)
		assert.False(t, retrievedComment.IsArchived())

		// Clean up
		err = reactionDB.Unprotected().Delete(ctx, reaction.ID)
		require.NoError(t, err)
		err = commentDB.Unprotected().Delete(ctx, comment.ID)
		require.NoError(t, err)
		err = taskDB.Unprotected().Delete(ctx, task.ID)
		require.NoError(t, err)
		err = projectDBAdapter.Unprotected().Delete(ctx, project.ID)
		require.NoError(t, err)
		err = organizationDB.Delete(ctx, accountRef, organization.ID)
		require.NoError(t, err)
	})

	t.Run("SetArchived_NonExistentOrganization", func(t *testing.T) {
		// Try to archive non-existent organization
		nonExistentID := primitive.NewObjectID()
		err := organizationDB.SetArchived(ctx, accountRef, nonExistentID, true, true)
		assert.Error(t, err)
		// Could be either no data or access denied error depending on the permission system
		assert.True(t, errors.Is(err, merrors.ErrNoData) || errors.Is(err, merrors.ErrAccessDenied))
	})
}
20
api/pkg/db/internal/mongo/policiesdb/all.go
Normal file
20
api/pkg/db/internal/mongo/policiesdb/all.go
Normal file
@@ -0,0 +1,20 @@
package policiesdb

import (
	"context"

	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/db/storable"
	"github.com/tech/sendico/pkg/model"
	mutil "github.com/tech/sendico/pkg/mutil/db"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func (db *PoliciesDB) All(ctx context.Context, organizationRef primitive.ObjectID) ([]model.PolicyDescription, error) {
	// built-in policies (no organization reference) plus the organization's own
	filter := repository.Query().Or(
		repository.Filter(storable.OrganizationRefField, nil),
		repository.OrgFilter(organizationRef),
	)
	return mutil.GetObjects[model.PolicyDescription](ctx, db.Logger, filter, nil, db.Repository)
}
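
The Or filter above selects both built-in policies (no organization reference) and the requested organization's own policies, which is also what the test at the end of this commit asserts. In raw MongoDB terms it is roughly equivalent to the filter sketched below; this assumes the builder maps storable.OrganizationRefField to an organizationRef field, so the real field name may differ.

package example

import (
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

// allPoliciesFilter matches built-in policies (organizationRef == null)
// plus the ones scoped to the given organization.
func allPoliciesFilter(organizationRef primitive.ObjectID) bson.M {
	return bson.M{
		"$or": bson.A{
			bson.M{"organizationRef": nil},
			bson.M{"organizationRef": organizationRef},
		},
	}
}
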
13
api/pkg/db/internal/mongo/policiesdb/builtin.go
Normal file
13
api/pkg/db/internal/mongo/policiesdb/builtin.go
Normal file
@@ -0,0 +1,13 @@
package policiesdb

import (
	"context"

	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mservice"
)

func (db *PoliciesDB) GetBuiltInPolicy(ctx context.Context, resourceType mservice.Type, policy *model.PolicyDescription) error {
	return db.FindOne(ctx, repository.Filter("resourceTypes", resourceType), policy)
}
21
api/pkg/db/internal/mongo/policiesdb/db.go
Normal file
21
api/pkg/db/internal/mongo/policiesdb/db.go
Normal file
@@ -0,0 +1,21 @@
package policiesdb

import (
	"github.com/tech/sendico/pkg/db/template"
	"github.com/tech/sendico/pkg/mlogger"
	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mservice"
	"go.mongodb.org/mongo-driver/mongo"
)

type PoliciesDB struct {
	template.DBImp[*model.PolicyDescription]
}

func Create(logger mlogger.Logger, db *mongo.Database) (*PoliciesDB, error) {
	p := &PoliciesDB{
		DBImp: *template.Create[*model.PolicyDescription](logger, mservice.Policies, db),
	}

	return p, nil
}
353
api/pkg/db/internal/mongo/policiesdb/db_test.go
Normal file
353
api/pkg/db/internal/mongo/policiesdb/db_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package policiesdb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
// Your internal packages
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/policiesdb"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
// Model package (contains PolicyDescription + Describable)
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
// Testcontainers
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Helper to terminate container
|
||||
func terminate(t *testing.T, ctx context.Context, container *mongodb.MongoDBContainer) {
|
||||
err := container.Terminate(ctx)
|
||||
require.NoError(t, err, "failed to terminate MongoDB container")
|
||||
}
|
||||
|
||||
// Helper to disconnect client
|
||||
func disconnect(t *testing.T, ctx context.Context, client *mongo.Client) {
|
||||
err := client.Disconnect(context.Background())
|
||||
require.NoError(t, err, "failed to disconnect from MongoDB")
|
||||
}
|
||||
|
||||
// Helper to drop the Policies collection
|
||||
func cleanupCollection(t *testing.T, ctx context.Context, db *mongo.Database) {
|
||||
// The actual collection name is typically the value returned by
|
||||
// (&model.PolicyDescription{}).Collection(), or something similar.
|
||||
// Make sure it matches what your code uses (often "policies" or "policyDescription").
|
||||
err := db.Collection((&model.PolicyDescription{}).Collection()).Drop(ctx)
|
||||
require.NoError(t, err, "failed to drop collection between sub-tests")
|
||||
}
|
||||
|
||||
func TestPoliciesDB(t *testing.T) {
|
||||
// Create context with reasonable timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Start MongoDB test container
|
||||
mongoC, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(t, ctx, mongoC)
|
||||
|
||||
// Get connection URI
|
||||
mongoURI, err := mongoC.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get connection string")
|
||||
|
||||
// Connect client
|
||||
clientOpts := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOpts)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(t, ctx, client)
|
||||
|
||||
// Create test DB
|
||||
db := client.Database("testdb")
|
||||
|
||||
// Use a no-op logger (or real logger if you prefer)
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Create an instance of PoliciesDB
|
||||
pdb, err := policiesdb.Create(logger, db)
|
||||
require.NoError(t, err, "unexpected error creating PoliciesDB")
|
||||
|
||||
// ---------------------------------------------------------
|
||||
// Each sub-test below starts by dropping the collection.
|
||||
// ---------------------------------------------------------
|
||||
|
||||
t.Run("CreateAndGet", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db) // ensure no leftover data
|
||||
|
||||
desc := "Test policy description"
|
||||
policy := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "TestPolicy",
|
||||
Description: &desc,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policy))
|
||||
|
||||
result := &model.PolicyDescription{}
|
||||
err := pdb.Get(ctx, policy.ID, result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, policy.ID, result.ID)
|
||||
assert.Equal(t, "TestPolicy", result.Name)
|
||||
assert.NotNil(t, result.Description)
|
||||
assert.Equal(t, "Test policy description", *result.Description)
|
||||
})
|
||||
|
||||
t.Run("Get_NotFound", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
// Attempt to get a non-existent ID
|
||||
nonExistentID := primitive.NewObjectID()
|
||||
result := &model.PolicyDescription{}
|
||||
err := pdb.Get(ctx, nonExistentID, result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Update", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
originalDesc := "Original description"
|
||||
policy := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "OriginalName",
|
||||
Description: &originalDesc,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policy))
|
||||
|
||||
newDesc := "Updated description"
|
||||
policy.Name = "UpdatedName"
|
||||
policy.Description = &newDesc
|
||||
|
||||
err := pdb.Update(ctx, policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
updated := &model.PolicyDescription{}
|
||||
err = pdb.Get(ctx, policy.ID, updated)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "UpdatedName", updated.Name)
|
||||
assert.NotNil(t, updated.Description)
|
||||
assert.Equal(t, "Updated description", *updated.Description)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
desc := "To be deleted"
|
||||
policy := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "WillDelete",
|
||||
Description: &desc,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policy))
|
||||
|
||||
err := pdb.Delete(ctx, policy.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted := &model.PolicyDescription{}
|
||||
err = pdb.Get(ctx, policy.ID, deleted)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("DeleteMany", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
desc1 := "Will be deleted 1"
|
||||
desc2 := "Will be deleted 2"
|
||||
pol1 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "BatchDelete1",
|
||||
Description: &desc1,
|
||||
},
|
||||
}
|
||||
pol2 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "BatchDelete2",
|
||||
Description: &desc2,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, pol1))
|
||||
require.NoError(t, pdb.Create(ctx, pol2))
|
||||
|
||||
q := repository.Query().RegEx(repository.Field("description"), "^Will be deleted", "")
|
||||
err := pdb.DeleteMany(ctx, q)
|
||||
require.NoError(t, err)
|
||||
|
||||
res1 := &model.PolicyDescription{}
|
||||
err1 := pdb.Get(ctx, pol1.ID, res1)
|
||||
assert.Error(t, err1)
|
||||
assert.True(t, errors.Is(err1, merrors.ErrNoData))
|
||||
|
||||
res2 := &model.PolicyDescription{}
|
||||
err2 := pdb.Get(ctx, pol2.ID, res2)
|
||||
assert.Error(t, err2)
|
||||
assert.True(t, errors.Is(err2, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("FindOne", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
desc := "Unique find test"
|
||||
policy := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "FindOneTest",
|
||||
Description: &desc,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policy))
|
||||
|
||||
// Match by name == "FindOneTest"
|
||||
q := repository.Query().Comparison(repository.Field("name"), builder.Eq, "FindOneTest")
|
||||
|
||||
found := &model.PolicyDescription{}
|
||||
err := pdb.FindOne(ctx, q, found)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, policy.ID, found.ID)
|
||||
assert.Equal(t, "FindOneTest", found.Name)
|
||||
assert.NotNil(t, found.Description)
|
||||
assert.Equal(t, "Unique find test", *found.Description)
|
||||
})
|
||||
|
||||
t.Run("All", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
// Insert some policies (orgA, orgB, nil org)
|
||||
orgA := primitive.NewObjectID()
|
||||
orgB := primitive.NewObjectID()
|
||||
|
||||
descA := "Org A policy"
|
||||
policyA := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyA",
|
||||
Description: &descA,
|
||||
},
|
||||
OrganizationRef: &orgA, // belongs to orgA
|
||||
}
|
||||
descB := "Org B policy"
|
||||
policyB := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyB",
|
||||
Description: &descB,
|
||||
},
|
||||
OrganizationRef: &orgB, // belongs to orgB
|
||||
}
|
||||
descNil := "No org policy"
|
||||
policyNil := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyNil",
|
||||
Description: &descNil,
|
||||
},
|
||||
// nil => built-in
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policyA))
|
||||
require.NoError(t, pdb.Create(ctx, policyB))
|
||||
require.NoError(t, pdb.Create(ctx, policyNil))
|
||||
|
||||
// "All" is expected to return
|
||||
// - policies for the requested org
|
||||
// - plus built-in (nil) ones
|
||||
resultsA, err := pdb.All(ctx, orgA)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resultsA, 2) // orgA + built-in
|
||||
|
||||
var idsA []primitive.ObjectID
|
||||
for _, r := range resultsA {
|
||||
idsA = append(idsA, r.ID)
|
||||
}
|
||||
assert.Contains(t, idsA, policyA.ID)
|
||||
assert.Contains(t, idsA, policyNil.ID)
|
||||
assert.NotContains(t, idsA, policyB.ID)
|
||||
|
||||
resultsB, err := pdb.All(ctx, orgB)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resultsB, 2) // orgB + built-in
|
||||
|
||||
var idsB []primitive.ObjectID
|
||||
for _, r := range resultsB {
|
||||
idsB = append(idsB, r.ID)
|
||||
}
|
||||
assert.Contains(t, idsB, policyB.ID)
|
||||
assert.Contains(t, idsB, policyNil.ID)
|
||||
assert.NotContains(t, idsB, policyA.ID)
|
||||
})
|
||||
|
||||
t.Run("Policies", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
desc1 := "PolicyOne"
|
||||
pol1 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyOne",
|
||||
Description: &desc1,
|
||||
},
|
||||
}
|
||||
desc2 := "PolicyTwo"
|
||||
pol2 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyTwo",
|
||||
Description: &desc2,
|
||||
},
|
||||
}
|
||||
desc3 := "PolicyThree"
|
||||
pol3 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyThree",
|
||||
Description: &desc3,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, pol1))
|
||||
require.NoError(t, pdb.Create(ctx, pol2))
|
||||
require.NoError(t, pdb.Create(ctx, pol3))
|
||||
|
||||
// 1) Request pol1, pol2
|
||||
results12, err := pdb.Policies(ctx, []primitive.ObjectID{pol1.ID, pol2.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results12, 2)
|
||||
// IDs might be out of order, so we do a set-like check
|
||||
var set12 []primitive.ObjectID
|
||||
for _, r := range results12 {
|
||||
set12 = append(set12, r.ID)
|
||||
}
|
||||
assert.Contains(t, set12, pol1.ID)
|
||||
assert.Contains(t, set12, pol2.ID)
|
||||
|
||||
// 2) Request pol1, pol3, plus a random ID
|
||||
fakeID := primitive.NewObjectID()
|
||||
results13Fake, err := pdb.Policies(ctx, []primitive.ObjectID{pol1.ID, pol3.ID, fakeID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results13Fake, 2) // pol1 + pol3 only
|
||||
var set13Fake []primitive.ObjectID
|
||||
for _, r := range results13Fake {
|
||||
set13Fake = append(set13Fake, r.ID)
|
||||
}
|
||||
assert.Contains(t, set13Fake, pol1.ID)
|
||||
assert.Contains(t, set13Fake, pol3.ID)
|
||||
|
||||
// 3) Request with empty slice => expect no results
|
||||
resultsEmpty, err := pdb.Policies(ctx, []primitive.ObjectID{})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resultsEmpty, 0)
|
||||
})
|
||||
}
|
||||
18
api/pkg/db/internal/mongo/policiesdb/policies.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package policiesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *PoliciesDB) Policies(ctx context.Context, refs []primitive.ObjectID) ([]model.PolicyDescription, error) {
|
||||
if len(refs) == 0 {
|
||||
return []model.PolicyDescription{}, nil
|
||||
}
|
||||
filter := repository.Query().In(repository.IDField(), refs)
|
||||
return mutil.GetObjects[model.PolicyDescription](ctx, db.Logger, filter, nil, db.Repository)
|
||||
}
|
||||
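Policies drops any refs that do not resolve, as the batched-lookup test above exercises, and an empty input returns an empty slice. A minimal caller-side sketch, same package, illustrative only:

// Hypothetical helper: resolve the descriptions behind a set of policy refs.
func describePolicies(ctx context.Context, pdb *PoliciesDB, refs []primitive.ObjectID) ([]model.PolicyDescription, error) {
	// Unknown refs are simply absent from the result; an empty input yields an empty slice.
	return pdb.Policies(ctx, refs)
}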
12
api/pkg/db/internal/mongo/refreshtokensdb/client.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
func (db *RefreshTokenDB) GetClient(ctx context.Context, clientID string) (*model.Client, error) {
|
||||
var client model.Client
|
||||
return &client, db.clients.FindOneByFilter(ctx, filterByClientId(clientID), &client)
|
||||
}
|
||||
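GetClient is the lookup a token endpoint would use to validate the presented client identifier before touching refresh tokens. A hedged sketch of that call site, same package:

// Hypothetical pre-flight check during token issuance.
func validateClient(ctx context.Context, db *RefreshTokenDB, clientID string) (*model.Client, error) {
	// FindOneByFilter surfaces merrors.ErrNoData when the clientId is unknown.
	return db.GetClient(ctx, clientID)
}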
122
api/pkg/db/internal/mongo/refreshtokensdb/crud.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (db *RefreshTokenDB) Create(ctx context.Context, rt *model.RefreshToken) error {
|
||||
// First, try to find an existing token for this account/client/device combination
|
||||
var existing model.RefreshToken
|
||||
if rt.AccountRef == nil {
|
||||
return merrors.InvalidArgument("Account reference must have a vaild value")
|
||||
}
|
||||
if err := db.FindOne(ctx, filterByAccount(*rt.AccountRef, &rt.SessionIdentifier), &existing); err != nil {
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
// No existing token, create a new one
|
||||
db.Logger.Info("Registering refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
return db.DBImp.Create(ctx, rt)
|
||||
}
|
||||
db.Logger.Warn("Something went wrong when checking existing sessions", zap.Error(err),
|
||||
zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
return err
|
||||
}
|
||||
|
||||
// Token already exists, update it with new values
|
||||
db.Logger.Info("Updating existing refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(TokenField), rt.RefreshToken).
|
||||
Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
|
||||
Set(repository.Field(UserAgentField), rt.UserAgent).
|
||||
Set(repository.Field(IPAddressField), rt.IPAddress).
|
||||
Set(repository.Field(LastUsedAtField), rt.LastUsedAt).
|
||||
Set(repository.Field(IsRevokedField), rt.IsRevoked)
|
||||
|
||||
if err := db.Patch(ctx, *existing.GetID(), patch); err != nil {
|
||||
db.Logger.Warn("Failed to patch refresh token", zap.Error(err), zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the ID of the input token to match the existing one
|
||||
rt.SetID(*existing.GetID())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) Update(ctx context.Context, rt *model.RefreshToken) error {
|
||||
rt.LastUsedAt = time.Now()
|
||||
|
||||
// Use Patch instead of Update to avoid race conditions
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(TokenField), rt.RefreshToken).
|
||||
Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
|
||||
Set(repository.Field(UserAgentField), rt.UserAgent).
|
||||
Set(repository.Field(IPAddressField), rt.IPAddress).
|
||||
Set(repository.Field(LastUsedAtField), rt.LastUsedAt).
|
||||
Set(repository.Field(IsRevokedField), rt.IsRevoked)
|
||||
|
||||
return db.Patch(ctx, *rt.GetID(), patch)
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) Delete(ctx context.Context, tokenRef primitive.ObjectID) error {
|
||||
db.Logger.Info("Deleting refresh token", mzap.ObjRef("refresh_token_ref", tokenRef))
|
||||
return db.DBImp.Delete(ctx, tokenRef)
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) Revoke(ctx context.Context, accountRef primitive.ObjectID, session *model.SessionIdentifier) error {
|
||||
var rt model.RefreshToken
|
||||
f := filterByAccount(accountRef, session)
|
||||
if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
db.Logger.Warn("Failed to find refresh token", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), zap.String("client_id", session.ClientID), zap.String("device_id", session.DeviceID))
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Use Patch to update the revocation status
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(IsRevokedField), true).
|
||||
Set(repository.Field(LastUsedAtField), time.Now())
|
||||
|
||||
return db.Patch(ctx, *rt.GetID(), patch)
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) GetByCRT(ctx context.Context, t *model.ClientRefreshToken) (*model.RefreshToken, error) {
|
||||
var rt model.RefreshToken
|
||||
f := filter(&t.SessionIdentifier).And(repository.Query().Filter(repository.Field("token"), t.RefreshToken))
|
||||
if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
|
||||
if !errors.Is(err, merrors.ErrNoData) {
|
||||
db.Logger.Warn("Failed to fetch refresh token", zap.Error(err),
|
||||
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if rt.ExpiresAt.Before(time.Now()) {
|
||||
db.Logger.Warn("Refresh token expired", mzap.StorableRef(&rt),
|
||||
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID),
|
||||
zap.Time("expires_at", rt.ExpiresAt))
|
||||
return nil, merrors.AccessDenied(mservice.RefreshTokens, string(model.ActionRead), *rt.GetID())
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
if rt.IsRevoked {
|
||||
db.Logger.Warn("Refresh token is revoked", mzap.StorableRef(&rt),
|
||||
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
|
||||
return nil, merrors.ErrNoData
|
||||
}
|
||||
|
||||
return &rt, nil
|
||||
}
|
||||
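Create, GetByCRT and Update together support the usual rotate-on-use flow: validate the presented token, swap in a new value, persist it. A service-layer sketch under those assumptions; generating the new token value is outside this package:

// Hypothetical rotation helper built on the methods above.
func rotateToken(ctx context.Context, db *RefreshTokenDB, presented *model.ClientRefreshToken, newToken string) (*model.RefreshToken, error) {
	rt, err := db.GetByCRT(ctx, presented) // rejects expired and revoked tokens
	if err != nil {
		return nil, err
	}
	rt.RefreshToken = newToken
	if err := db.Update(ctx, rt); err != nil { // patch-based update, also bumps LastUsedAt
		return nil, err
	}
	return rt, nil
}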
62
api/pkg/db/internal/mongo/refreshtokensdb/db.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type RefreshTokenDB struct {
|
||||
template.DBImp[*model.RefreshToken]
|
||||
clients repository.Repository
|
||||
}
|
||||
|
||||
func Create(logger mlogger.Logger, db *mongo.Database) (*RefreshTokenDB, error) {
|
||||
p := &RefreshTokenDB{
|
||||
DBImp: *template.Create[*model.RefreshToken](logger, mservice.RefreshTokens, db),
|
||||
clients: repository.CreateMongoRepository(db, mservice.Clients),
|
||||
}
|
||||
|
||||
if err := p.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "token", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique token index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add unique constraint on account/client/device combination
|
||||
if err := p.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{
|
||||
{Field: "accountRef", Sort: ri.Asc},
|
||||
{Field: "clientId", Sort: ri.Asc},
|
||||
{Field: "deviceId", Sort: ri.Asc},
|
||||
},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique account/client/device index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: IsRevokedField, Sort: ri.Asc}},
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique token revokation status index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.clients.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "clientId", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique client identifier index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
10
api/pkg/db/internal/mongo/refreshtokensdb/fields.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package refreshtokensdb
|
||||
|
||||
const (
|
||||
ExpiresAtField = "expiresAt"
|
||||
IsRevokedField = "isRevoked"
|
||||
TokenField = "token"
|
||||
UserAgentField = "userAgent"
|
||||
IPAddressField = "ipAddress"
|
||||
LastUsedAtField = "lastUsedAt"
|
||||
)
|
||||
25
api/pkg/db/internal/mongo/refreshtokensdb/filters.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func filterByClientId(clientID string) builder.Query {
|
||||
return repository.Query().Comparison(repository.Field("clientId"), builder.Eq, clientID)
|
||||
}
|
||||
|
||||
func filter(session *model.SessionIdentifier) builder.Query {
|
||||
filter := filterByClientId(session.ClientID)
|
||||
filter.And(
|
||||
repository.Query().Comparison(repository.Field("deviceId"), builder.Eq, session.DeviceID),
|
||||
repository.Query().Comparison(repository.Field(IsRevokedField), builder.Eq, false),
|
||||
)
|
||||
return filter
|
||||
}
|
||||
|
||||
func filterByAccount(accountRef primitive.ObjectID, session *model.SessionIdentifier) builder.Query {
|
||||
return filter(session).And(repository.Query().Comparison(repository.AccountField(), builder.Eq, accountRef))
|
||||
}
|
||||
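For reference, filterByAccount layers an account match on top of the session filter. With the field names used by the indexes in db.go it should translate to roughly the following Mongo filter; the builder may render the And clauses as an explicit $and rather than a flat document:

// Approximate shape of filterByAccount(accountRef, session):
//
//	{
//	  "clientId":   session.ClientID,
//	  "deviceId":   session.DeviceID,
//	  "isRevoked":  false,
//	  "accountRef": accountRef,
//	}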
@@ -0,0 +1,639 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package refreshtokensdb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
factory "github.com/tech/sendico/pkg/mlogger/factory"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) (*refreshtokensdb.RefreshTokenDB, func()) {
|
||||
// mark as helper for better test failure reporting
|
||||
t.Helper()
|
||||
|
||||
startCtx, startCancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer startCancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(startCtx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(startCtx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(startCtx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
|
||||
database := client.Database("test_refresh_tokens_" + t.Name())
|
||||
logger := factory.NewLogger(true)
|
||||
|
||||
db, err := refreshtokensdb.Create(logger, database)
|
||||
require.NoError(t, err, "failed to create refresh tokens db")
|
||||
|
||||
cleanup := func() {
|
||||
_ = database.Drop(context.Background())
|
||||
_ = client.Disconnect(context.Background())
|
||||
termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
_ = mongoContainer.Terminate(termCtx)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
func createTestRefreshToken(accountRef primitive.ObjectID, clientID, deviceID, token string) *model.RefreshToken {
|
||||
return &model.RefreshToken{
|
||||
ClientRefreshToken: model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
},
|
||||
AccountBoundBase: model.AccountBoundBase{
|
||||
AccountRef: &accountRef,
|
||||
},
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
IsRevoked: false,
|
||||
UserAgent: "TestUserAgent/1.0",
|
||||
IPAddress: "192.168.1.1",
|
||||
LastUsedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_AuthenticationFlow(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Complete_User_Authentication_Flow", func(t *testing.T) {
|
||||
// Setup: Create user and client
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-desktop-chrome"
|
||||
token := "refresh_token_12345"
|
||||
|
||||
// Step 1: User logs in - create initial refresh token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 2: User uses refresh token to get new access token
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
retrievedToken, err := db.GetByCRT(ctx, crt)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrievedToken.AccountRef)
|
||||
assert.Equal(t, userID, *retrievedToken.AccountRef)
|
||||
assert.Equal(t, token, retrievedToken.RefreshToken)
|
||||
assert.False(t, retrievedToken.IsRevoked)
|
||||
|
||||
// Step 3: User logs out - revoke the token
|
||||
session := &model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
err = db.Revoke(ctx, userID, session)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 4: Try to use revoked token - should fail
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Manual_Token_Revocation_Workaround", func(t *testing.T) {
|
||||
// Test manual revocation by directly updating the token
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-desktop-chrome"
|
||||
token := "manual_revoke_token_123"
|
||||
|
||||
// Step 1: Create token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 2: Manually revoke token by updating it directly
|
||||
refreshToken.IsRevoked = true
|
||||
err = db.Update(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 3: Try to use revoked token - should fail
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_MultiDeviceManagement(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("User_With_Multiple_Devices", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "mobile-app"
|
||||
|
||||
// User logs in from phone
|
||||
phoneToken := createTestRefreshToken(userID, clientID, "phone-ios", "phone_token_123")
|
||||
err := db.Create(ctx, phoneToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User logs in from tablet
|
||||
tabletToken := createTestRefreshToken(userID, clientID, "tablet-android", "tablet_token_456")
|
||||
err = db.Create(ctx, tabletToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User logs in from desktop
|
||||
desktopToken := createTestRefreshToken(userID, clientID, "desktop-windows", "desktop_token_789")
|
||||
err = db.Create(ctx, desktopToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User wants to logout from all devices except current (phone)
|
||||
err = db.RevokeAll(ctx, userID, "phone-ios")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Phone should still work
|
||||
phoneCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: "phone-ios",
|
||||
},
|
||||
RefreshToken: "phone_token_123",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, phoneCRT)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Tablet and desktop should be revoked
|
||||
tabletCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: "tablet-android",
|
||||
},
|
||||
RefreshToken: "tablet_token_456",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, tabletCRT)
|
||||
assert.Error(t, err)
|
||||
|
||||
desktopCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: "desktop-windows",
|
||||
},
|
||||
RefreshToken: "desktop_token_789",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, desktopCRT)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_TokenRotation(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Token_Rotation_On_Use", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-browser"
|
||||
initialToken := "initial_token_123"
|
||||
|
||||
// Create initial token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, initialToken)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate small delay
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Use token - should update LastUsedAt
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: initialToken,
|
||||
}
|
||||
|
||||
retrievedToken, err := db.GetByCRT(ctx, crt)
|
||||
require.NoError(t, err)
|
||||
// LastUsedAt is not updated by GetByCRT; validate token data instead
|
||||
assert.Equal(t, initialToken, retrievedToken.RefreshToken)
|
||||
|
||||
// Create new token with rotated value (simulating token rotation)
|
||||
newToken := "rotated_token_456"
|
||||
retrievedToken.RefreshToken = newToken
|
||||
err = db.Update(ctx, retrievedToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Old token should no longer work
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
|
||||
// New token should work
|
||||
newCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: newToken,
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, newCRT)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_SessionReplacement(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("User_Login_From_Same_Device_Twice", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-laptop"
|
||||
|
||||
// First login
|
||||
firstToken := createTestRefreshToken(userID, clientID, deviceID, "first_token_123")
|
||||
err := db.Create(ctx, firstToken)
|
||||
require.NoError(t, err)
|
||||
firstTokenID := *firstToken.GetID()
|
||||
|
||||
// Second login from same device - should replace existing token
|
||||
secondToken := createTestRefreshToken(userID, clientID, deviceID, "second_token_456")
|
||||
err = db.Create(ctx, secondToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should reuse the same database record
|
||||
assert.Equal(t, firstTokenID, *secondToken.GetID())
|
||||
|
||||
// First token should no longer work
|
||||
firstCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: "first_token_123",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, firstCRT)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Second token should work
|
||||
secondCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: "second_token_456",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, secondCRT)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_ClientManagement(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Client_CRUD_Operations", func(t *testing.T) {
|
||||
// Note: Client management is handled by a separate client database
|
||||
// This test verifies that refresh tokens work with different client IDs
|
||||
|
||||
userID := primitive.NewObjectID()
|
||||
|
||||
// Create refresh tokens for different clients
|
||||
webToken := createTestRefreshToken(userID, "web-app", "device1", "token1")
|
||||
err := db.Create(ctx, webToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
mobileToken := createTestRefreshToken(userID, "mobile-app", "device2", "token2")
|
||||
err = db.Create(ctx, mobileToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify tokens can be retrieved by client ID
|
||||
webCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: "web-app",
|
||||
DeviceID: "device1",
|
||||
},
|
||||
RefreshToken: "token1",
|
||||
}
|
||||
|
||||
retrievedToken, err := db.GetByCRT(ctx, webCRT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "web-app", retrievedToken.ClientID)
|
||||
assert.Equal(t, "device1", retrievedToken.DeviceID)
|
||||
|
||||
mobileCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: "mobile-app",
|
||||
DeviceID: "device2",
|
||||
},
|
||||
RefreshToken: "token2",
|
||||
}
|
||||
|
||||
retrievedToken, err = db.GetByCRT(ctx, mobileCRT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "mobile-app", retrievedToken.ClientID)
|
||||
assert.Equal(t, "device2", retrievedToken.DeviceID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_SecurityScenarios(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Token_Hijacking_Prevention", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-browser"
|
||||
token := "hijacked_token_123"
|
||||
|
||||
// Create legitimate token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate security concern - revoke token
|
||||
session := &model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
err = db.Revoke(ctx, userID, session)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attacker tries to use hijacked token
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Invalid_Token_Attempts", func(t *testing.T) {
|
||||
// Try to use completely invalid token
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: "invalid-client",
|
||||
DeviceID: "invalid-device",
|
||||
},
|
||||
RefreshToken: "invalid_token_123",
|
||||
}
|
||||
|
||||
_, err := db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_ExpiredTokenHandling(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Expired_Token_Cleanup", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-device"
|
||||
token := "expired_token_123"
|
||||
|
||||
// Create token that expires in the past
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
refreshToken.ExpiresAt = time.Now().Add(-1 * time.Hour) // Expired 1 hour ago
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The token exists in database but is expired
|
||||
var storedToken model.RefreshToken
|
||||
err = db.Get(ctx, *refreshToken.GetID(), &storedToken)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storedToken.ExpiresAt.Before(time.Now()))
|
||||
|
||||
// Application should reject expired tokens
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrAccessDenied))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_ConcurrentAccess(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Concurrent_Token_Usage", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-device"
|
||||
token := "concurrent_token_123"
|
||||
|
||||
// Create token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
// Simulate concurrent access
|
||||
done := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := db.GetByCRT(ctx, crt)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := db.GetByCRT(ctx, crt)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
// Both operations should succeed
|
||||
for i := 0; i < 2; i++ {
|
||||
err := <-done
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_EdgeCases(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Delete_Token_By_ID", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
refreshToken := createTestRefreshToken(userID, "web-app", "device-1", "token_123")
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := *refreshToken.GetID()
|
||||
|
||||
// Delete token
|
||||
err = db.Delete(ctx, tokenID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should no longer exist
|
||||
var result model.RefreshToken
|
||||
err = db.Get(ctx, tokenID, &result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Revoke_Non_Existent_Token", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
session := &model.SessionIdentifier{
|
||||
ClientID: "non-existent-client",
|
||||
DeviceID: "non-existent-device",
|
||||
}
|
||||
|
||||
err := db.Revoke(ctx, userID, session)
|
||||
// Revoking a non-existent token should be handled gracefully
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("RevokeAll_No_Other_Devices", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "only-device"
|
||||
|
||||
// Create single token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, "token_123")
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// RevokeAll should not affect current device
|
||||
err = db.RevokeAll(ctx, userID, deviceID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should still work
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: "token_123",
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_DatabaseIndexes(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Unique_Token_Constraint", func(t *testing.T) {
|
||||
userID1 := primitive.NewObjectID()
|
||||
userID2 := primitive.NewObjectID()
|
||||
token := "duplicate_token_123"
|
||||
|
||||
// Create first token
|
||||
refreshToken1 := createTestRefreshToken(userID1, "client1", "device1", token)
|
||||
err := db.Create(ctx, refreshToken1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to create second token with same token value - should fail due to unique index
|
||||
refreshToken2 := createTestRefreshToken(userID2, "client2", "device2", token)
|
||||
err = db.Create(ctx, refreshToken2)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "duplicate")
|
||||
})
|
||||
|
||||
t.Run("Query_Performance_By_Revocation_Status", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
|
||||
// Create multiple tokens
|
||||
for i := 0; i < 10; i++ {
|
||||
token := createTestRefreshToken(userID, clientID,
|
||||
fmt.Sprintf("device_%d", i), fmt.Sprintf("token_%d", i))
|
||||
if i%2 == 0 {
|
||||
token.IsRevoked = true
|
||||
}
|
||||
err := db.Create(ctx, token)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Query should efficiently filter by revocation status
|
||||
query := repository.Query().
|
||||
Filter(repository.AccountField(), userID).
|
||||
And(repository.Query().Comparison(repository.Field(refreshtokensdb.IsRevokedField), builder.Eq, false))
|
||||
|
||||
ids, err := db.ListIDs(ctx, query)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ids, 5) // Should find 5 non-revoked tokens
|
||||
})
|
||||
}
|
||||
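These tests sit behind the integration build tag and need a Docker daemon for testcontainers. Something like the following should run them locally; the module path is assumed from the import paths above:

	go test -tags integration ./pkg/db/internal/mongo/refreshtokensdb/...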
24
api/pkg/db/internal/mongo/refreshtokensdb/revoke.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *RefreshTokenDB) RevokeAll(ctx context.Context, accountRef primitive.ObjectID, deviceID string) error {
|
||||
query := repository.Query().
|
||||
Filter(repository.AccountField(), accountRef).
|
||||
And(repository.Query().Comparison(repository.Field("deviceId"), builder.Ne, deviceID)).
|
||||
And(repository.Query().Comparison(repository.Field(IsRevokedField), builder.Eq, false))
|
||||
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(ExpiresAtField), time.Now()).
|
||||
Set(repository.Field(IsRevokedField), true)
|
||||
|
||||
_, err := db.Repository.PatchMany(ctx, query, patch)
|
||||
return err
|
||||
}
|
||||
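RevokeAll leaves the supplied device untouched and revokes every other non-revoked token for the account, which is exactly the shape of a "log out of other sessions" action. A hedged handler-level sketch, same package:

// Hypothetical "log out everywhere else" action.
func logoutOtherDevices(ctx context.Context, db *RefreshTokenDB, accountRef primitive.ObjectID, currentDeviceID string) error {
	// PatchMany marks the other sessions revoked and expires them immediately.
	return db.RevokeAll(ctx, accountRef, currentDeviceID)
}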
@@ -0,0 +1,90 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type literalAccumulatorImp struct {
|
||||
op builder.MongoOperation
|
||||
value any
|
||||
}
|
||||
|
||||
func (a *literalAccumulatorImp) Build() bson.D {
|
||||
return bson.D{{Key: string(a.op), Value: a.value}}
|
||||
}
|
||||
|
||||
func NewAccumulator(op builder.MongoOperation, value any) builder.Accumulator {
|
||||
return &literalAccumulatorImp{op: op, value: value}
|
||||
}
|
||||
|
||||
func AddToSet(value builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.AddToSet, value)
|
||||
}
|
||||
|
||||
func Size(value builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Size, value)
|
||||
}
|
||||
|
||||
func Ne(left, right builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.Ne, left, right)
|
||||
}
|
||||
|
||||
func Sum(value any) builder.Accumulator {
|
||||
return NewAccumulator(builder.Sum, value)
|
||||
}
|
||||
|
||||
func Avg(value any) builder.Accumulator {
|
||||
return NewAccumulator(builder.Avg, value)
|
||||
}
|
||||
|
||||
func Min(value any) builder.Accumulator {
|
||||
return NewAccumulator(builder.Min, value)
|
||||
}
|
||||
|
||||
func Max(value any) builder.Accumulator {
|
||||
return NewAccumulator(builder.Max, value)
|
||||
}
|
||||
|
||||
func Eq(left, right builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.Eq, left, right)
|
||||
}
|
||||
|
||||
func Gt(left, right builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.Gt, left, right)
|
||||
}
|
||||
|
||||
func Add(left, right builder.Accumulator) builder.Accumulator {
|
||||
return newBinaryAccumulator(builder.Add, left, right)
|
||||
}
|
||||
|
||||
func Subtract(left, right builder.Accumulator) builder.Accumulator {
|
||||
return newBinaryAccumulator(builder.Subtract, left, right)
|
||||
}
|
||||
|
||||
func Multiply(left, right builder.Accumulator) builder.Accumulator {
|
||||
return newBinaryAccumulator(builder.Multiply, left, right)
|
||||
}
|
||||
|
||||
func Divide(left, right builder.Accumulator) builder.Accumulator {
|
||||
return newBinaryAccumulator(builder.Divide, left, right)
|
||||
}
|
||||
|
||||
type binaryAccumulator struct {
|
||||
op builder.MongoOperation
|
||||
left builder.Accumulator
|
||||
right builder.Accumulator
|
||||
}
|
||||
|
||||
func newBinaryAccumulator(op builder.MongoOperation, left, right builder.Accumulator) builder.Accumulator {
|
||||
return &binaryAccumulator{
|
||||
op: op,
|
||||
left: left,
|
||||
right: right,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *binaryAccumulator) Build() bson.D {
|
||||
args := []any{b.left.Build(), b.right.Build()}
|
||||
return bson.D{{Key: string(b.op), Value: args}}
|
||||
}
|
||||
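The accumulator constructors nest: a literal accumulator wraps one operator and value, a binary accumulator combines two under an arithmetic operator. A sketch in the same package, with illustrative field names; NewGroupAccumulator is defined further down in this package:

// Sketch: { "$avg": "$amount" } keyed as avgAmount, and
// { "$subtract": [ { "$sum": "$credit" }, { "$sum": "$debit" } ] }.
func exampleAccumulators() (builder.GroupAccumulator, builder.Accumulator) {
	avgAmount := NewGroupAccumulator(
		NewFieldImp("avgAmount"),
		Avg(NewRefFieldImp(NewFieldImp("amount")).Build()), // value renders as "$amount"
	)
	net := Subtract(
		Sum(NewRefFieldImp(NewFieldImp("credit")).Build()),
		Sum(NewRefFieldImp(NewFieldImp("debit")).Build()),
	)
	return avgAmount, net
}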
102
api/pkg/db/internal/mongo/repositoryimp/builderimp/alias.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type aliasImp struct {
|
||||
lhs builder.Field
|
||||
rhs any
|
||||
}
|
||||
|
||||
func (a *aliasImp) Field() builder.Field {
|
||||
return a.lhs
|
||||
}
|
||||
|
||||
func (a *aliasImp) Build() bson.D {
|
||||
return bson.D{{Key: a.lhs.Build(), Value: a.rhs}}
|
||||
}
|
||||
|
||||
// 1. Null alias (_id: null)
|
||||
func NewNullAlias(lhs builder.Field) builder.Alias {
|
||||
return &aliasImp{lhs: lhs, rhs: nil}
|
||||
}
|
||||
|
||||
func NewAlias(lhs builder.Field, rhs any) builder.Alias {
|
||||
return &aliasImp{lhs: lhs, rhs: rhs}
|
||||
}
|
||||
|
||||
// 2. Simple alias (_id: "$taskRef")
|
||||
func NewSimpleAlias(lhs, rhs builder.Field) builder.Alias {
|
||||
return &aliasImp{lhs: lhs, rhs: rhs.Build()}
|
||||
}
|
||||
|
||||
// 3. Complex alias (_id: { aliasName: "$originalField", ... })
|
||||
type ComplexAlias struct {
|
||||
lhs builder.Field
|
||||
rhs []builder.Alias // component aliases merged into a single document
|
||||
}
|
||||
|
||||
func (a *ComplexAlias) Field() builder.Field {
|
||||
return a.lhs
|
||||
}
|
||||
|
||||
func (a *ComplexAlias) Build() bson.D {
|
||||
fieldMap := bson.M{}
|
||||
|
||||
for _, alias := range a.rhs {
|
||||
// Each alias.Build() still returns a bson.D
|
||||
aliasDoc := alias.Build()
|
||||
|
||||
// 1. Marshal the ordered D into raw BSON bytes
|
||||
raw, err := bson.Marshal(aliasDoc)
|
||||
if err != nil {
|
||||
panic("Failed to marshal alias document: " + err.Error())
|
||||
}
|
||||
|
||||
// 2. Unmarshal those bytes into an unordered M
|
||||
var docM bson.M
|
||||
if err := bson.Unmarshal(raw, &docM); err != nil {
|
||||
panic("Failed to unmarshal alias document: " + err.Error())
|
||||
}
|
||||
|
||||
// Merge into our accumulator
|
||||
for k, v := range docM {
|
||||
fieldMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return bson.D{{Key: a.lhs.Build(), Value: fieldMap}}
|
||||
}
|
||||
|
||||
func NewComplexAlias(lhs builder.Field, rhs []builder.Alias) builder.Alias {
|
||||
return &ComplexAlias{lhs: lhs, rhs: rhs}
|
||||
}
|
||||
|
||||
type aliasesImp struct {
|
||||
aliases []builder.Alias
|
||||
}
|
||||
|
||||
func (a *aliasesImp) Field() builder.Field {
|
||||
if len(a.aliases) > 0 {
|
||||
return a.aliases[0].Field()
|
||||
}
|
||||
return NewFieldImp("")
|
||||
}
|
||||
|
||||
func (a *aliasesImp) Build() bson.D {
|
||||
results := make([]bson.D, 0)
|
||||
for _, alias := range a.aliases {
|
||||
results = append(results, alias.Build())
|
||||
}
|
||||
aliases := bson.D{}
|
||||
for _, r := range results {
|
||||
aliases = append(aliases, r...)
|
||||
}
|
||||
return aliases
|
||||
}
|
||||
|
||||
func NewAliases(aliases ...builder.Alias) builder.Alias {
|
||||
return &aliasesImp{aliases: aliases}
|
||||
}
|
||||
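The three alias shapes above map onto the usual $group keys. A sketch, same package, with illustrative field names:

// Sketch: _id variants for a $group stage.
func exampleAliases() (builder.Alias, builder.Alias, builder.Alias) {
	id := NewFieldImp("_id")
	nullID := NewNullAlias(id)                                           // { "_id": null }
	byTask := NewSimpleAlias(id, NewRefFieldImp(NewFieldImp("taskRef"))) // { "_id": "$taskRef" }
	composite := NewComplexAlias(id, []builder.Alias{                    // { "_id": { "task": "$taskRef", "day": "$day" } }
		NewAlias(NewFieldImp("task"), NewRefFieldImp(NewFieldImp("taskRef")).Build()),
		NewAlias(NewFieldImp("day"), NewRefFieldImp(NewFieldImp("day")).Build()),
	})
	return nullID, byTask, composite
}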
27
api/pkg/db/internal/mongo/repositoryimp/builderimp/array.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type arrayImp struct {
|
||||
elements []builder.Expression
|
||||
}
|
||||
|
||||
// Build renders the literal array:
|
||||
//
|
||||
// [ <expr1>, <expr2>, … ]
|
||||
func (b *arrayImp) Build() bson.A {
|
||||
arr := make(bson.A, len(b.elements))
|
||||
for i, expr := range b.elements {
|
||||
// each expr.Build() returns the raw value or sub-expression
|
||||
arr[i] = expr.Build()
|
||||
}
|
||||
return arr
|
||||
}
|
||||
|
||||
// NewArray constructs a new array expression from the given sub-expressions.
|
||||
func NewArray(exprs ...builder.Expression) *arrayImp {
|
||||
return &arrayImp{elements: exprs}
|
||||
}
|
||||
108
api/pkg/db/internal/mongo/repositoryimp/builderimp/expression.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type literalExpression struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func NewLiteralExpression(value any) builder.Expression {
|
||||
return &literalExpression{value: value}
|
||||
}
|
||||
|
||||
func (e *literalExpression) Build() any {
|
||||
return bson.D{{Key: string(builder.Literal), Value: e.value}}
|
||||
}
|
||||
|
||||
type variadicExpression struct {
|
||||
op builder.MongoOperation
|
||||
parts []builder.Expression
|
||||
}
|
||||
|
||||
func (e *variadicExpression) Build() any {
|
||||
args := make([]any, 0, len(e.parts))
|
||||
for _, p := range e.parts {
|
||||
args = append(args, p.Build())
|
||||
}
|
||||
return bson.D{{Key: string(e.op), Value: args}}
|
||||
}
|
||||
|
||||
func newVariadicExpression(op builder.MongoOperation, exprs ...builder.Expression) builder.Expression {
|
||||
return &variadicExpression{
|
||||
op: op,
|
||||
parts: exprs,
|
||||
}
|
||||
}
|
||||
|
||||
func newBinaryExpression(op builder.MongoOperation, left, right builder.Expression) builder.Expression {
|
||||
return &variadicExpression{
|
||||
op: op,
|
||||
parts: []builder.Expression{left, right},
|
||||
}
|
||||
}
|
||||
|
||||
type unaryExpression struct {
|
||||
op builder.MongoOperation
|
||||
rhs builder.Expression
|
||||
}
|
||||
|
||||
func (e *unaryExpression) Build() any {
|
||||
return bson.D{{Key: string(e.op), Value: e.rhs.Build()}}
|
||||
}
|
||||
|
||||
func newUnaryExpression(op builder.MongoOperation, right builder.Expression) builder.Expression {
|
||||
return &unaryExpression{
|
||||
op: op,
|
||||
rhs: right,
|
||||
}
|
||||
}
|
||||
|
||||
type matchExpression struct {
|
||||
op builder.MongoOperation
|
||||
rhs builder.Expression
|
||||
}
|
||||
|
||||
func (e *matchExpression) Build() any {
|
||||
return bson.E{Key: string(e.op), Value: e.rhs.Build()}
|
||||
}
|
||||
|
||||
func newMatchExpression(op builder.MongoOperation, right builder.Expression) builder.Expression {
|
||||
return &matchExpression{
|
||||
op: op,
|
||||
rhs: right,
|
||||
}
|
||||
}
|
||||
|
||||
func InRef(value builder.Field) builder.Expression {
|
||||
return newMatchExpression(builder.In, NewValue(NewRefFieldImp(value).Build()))
|
||||
}
|
||||
|
||||
type inImpl struct {
|
||||
values []any
|
||||
}
|
||||
|
||||
func (e *inImpl) Build() any {
|
||||
return bson.D{{Key: string(builder.In), Value: e.values}}
|
||||
}
|
||||
|
||||
func In(values ...any) builder.Expression {
|
||||
var flattenedValues []any
|
||||
|
||||
for _, v := range values {
|
||||
switch reflect.TypeOf(v).Kind() {
|
||||
case reflect.Slice:
|
||||
slice := reflect.ValueOf(v)
|
||||
for i := range slice.Len() {
|
||||
flattenedValues = append(flattenedValues, slice.Index(i).Interface())
|
||||
}
|
||||
default:
|
||||
flattenedValues = append(flattenedValues, v)
|
||||
}
|
||||
}
|
||||
return &inImpl{values: flattenedValues}
|
||||
}
|
||||
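In flattens slice arguments, so a slice plus extra scalars collapse into a single $in list. For example, same package:

// Both calls build { "$in": [ "admin", "editor", "viewer" ] }.
func exampleIn() (builder.Expression, builder.Expression) {
	fromSlice := In([]string{"admin", "editor"}, "viewer")
	fromScalars := In("admin", "editor", "viewer")
	return fromSlice, fromScalars
}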
71
api/pkg/db/internal/mongo/repositoryimp/builderimp/field.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
)
|
||||
|
||||
type FieldImp struct {
|
||||
fields []string
|
||||
}
|
||||
|
||||
func (b *FieldImp) Dot(field string) builder.Field {
|
||||
newFields := make([]string, len(b.fields), len(b.fields)+1)
|
||||
copy(newFields, b.fields)
|
||||
newFields = append(newFields, field)
|
||||
return &FieldImp{fields: newFields}
|
||||
}
|
||||
|
||||
func (b *FieldImp) CopyWith(field string) builder.Field {
|
||||
copiedFields := make([]string, 0, len(b.fields)+1)
|
||||
copiedFields = append(copiedFields, b.fields...)
|
||||
copiedFields = append(copiedFields, field)
|
||||
return &FieldImp{
|
||||
fields: copiedFields,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *FieldImp) Build() string {
|
||||
return strings.Join(b.fields, ".")
|
||||
}
|
||||
|
||||
func NewFieldImp(baseName string) builder.Field {
|
||||
return &FieldImp{
|
||||
fields: []string{baseName},
|
||||
}
|
||||
}
|
||||
|
||||
type RefField struct {
|
||||
imp builder.Field
|
||||
}
|
||||
|
||||
func (b *RefField) Build() string {
|
||||
return "$" + b.imp.Build()
|
||||
}
|
||||
|
||||
func (b *RefField) CopyWith(field string) builder.Field {
|
||||
return &RefField{
|
||||
imp: b.imp.CopyWith(field),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *RefField) Dot(field string) builder.Field {
|
||||
return &RefField{
|
||||
imp: b.imp.Dot(field),
|
||||
}
|
||||
}
|
||||
|
||||
func NewRefFieldImp(field builder.Field) builder.Field {
|
||||
return &RefField{
|
||||
imp: field,
|
||||
}
|
||||
}
|
||||
|
||||
func NewRootRef() builder.Field {
|
||||
return NewFieldImp("$$ROOT")
|
||||
}
|
||||
|
||||
func NewRemoveRef() builder.Field {
|
||||
return NewFieldImp("$$REMOVE")
|
||||
}
|
||||
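Dot and CopyWith return fresh copies, so a base field can be reused to derive several paths, and NewRefFieldImp prefixes the rendered path with "$" for use inside expressions. A small sketch, same package:

// Sketch: dotted paths and "$"-prefixed refs derived from one base field.
func exampleFields() (string, string, string) {
	meta := NewFieldImp("metadata")
	owner := meta.Dot("owner")                // "metadata.owner"; meta itself is unchanged
	created := meta.Dot("createdAt")          // "metadata.createdAt"
	ownerRef := NewRefFieldImp(owner).Build() // "$metadata.owner"
	return owner.Build(), created.Build(), ownerRef
}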
137
api/pkg/db/internal/mongo/repositoryimp/builderimp/func.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type condImp struct {
|
||||
condition builder.Expression
|
||||
ifTrue any
|
||||
ifFalse any
|
||||
}
|
||||
|
||||
func (c *condImp) Build() any {
|
||||
return bson.D{
|
||||
{Key: string(builder.Cond), Value: bson.D{
|
||||
{Key: "if", Value: c.condition.Build()},
|
||||
{Key: "then", Value: c.ifTrue},
|
||||
{Key: "else", Value: c.ifFalse},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func NewCond(condition builder.Expression, ifTrue, ifFalse any) builder.Expression {
|
||||
return &condImp{
|
||||
condition: condition,
|
||||
ifTrue: ifTrue,
|
||||
ifFalse: ifFalse,
|
||||
}
|
||||
}
|
||||
|
||||
// setUnionImp implements builder.Expression but takes only builder.Array inputs.
|
||||
type setUnionImp struct {
|
||||
inputs []builder.Expression
|
||||
}
|
||||
|
||||
// Build renders the $setUnion stage:
|
||||
//
|
||||
// { $setUnion: [ <array1>, <array2>, … ] }
|
||||
func (s *setUnionImp) Build() any {
|
||||
arr := make(bson.A, len(s.inputs))
|
||||
for i, arrayExpr := range s.inputs {
|
||||
arr[i] = arrayExpr.Build()
|
||||
}
|
||||
return bson.D{
|
||||
{Key: string(builder.SetUnion), Value: arr},
|
||||
}
|
||||
}
|
||||
|
||||
// NewSetUnion constructs a new $setUnion expression from the given Arrays.
|
||||
func NewSetUnion(arrays ...builder.Expression) builder.Expression {
|
||||
return &setUnionImp{inputs: arrays}
|
||||
}
|
||||
|
||||
type assignmentImp struct {
|
||||
field builder.Field
|
||||
expression builder.Expression
|
||||
}
|
||||
|
||||
func (a *assignmentImp) Build() bson.D {
|
||||
// Assign it to the given field name
|
||||
return bson.D{
|
||||
{Key: a.field.Build(), Value: a.expression.Build()},
|
||||
}
|
||||
}
|
||||
|
||||
// NewAssignment creates a projection assignment of the form:
|
||||
//
|
||||
// <field>: <expression>
|
||||
func NewAssignment(field builder.Field, expression builder.Expression) builder.Projection {
|
||||
return &assignmentImp{
|
||||
field: field,
|
||||
expression: expression,
|
||||
}
|
||||
}
|
||||
|
||||
type computeImp struct {
|
||||
field builder.Field
|
||||
expression builder.Expression
|
||||
}
|
||||
|
||||
func (a *computeImp) Build() any {
|
||||
return bson.D{
|
||||
{Key: string(a.field.Build()), Value: a.expression.Build()},
|
||||
}
|
||||
}
|
||||
|
||||
func NewCompute(field builder.Field, expression builder.Expression) builder.Expression {
|
||||
return &computeImp{
|
||||
field: field,
|
||||
expression: expression,
|
||||
}
|
||||
}
|
||||
|
||||
func NewIfNull(expression, replacement builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.IfNull, expression, replacement)
|
||||
}
|
||||
|
||||
func NewPush(expression builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Push, expression)
|
||||
}
|
||||
|
||||
func NewAnd(exprs ...builder.Expression) builder.Expression {
|
||||
return newVariadicExpression(builder.And, exprs...)
|
||||
}
|
||||
|
||||
func NewOr(exprs ...builder.Expression) builder.Expression {
|
||||
return newVariadicExpression(builder.Or, exprs...)
|
||||
}
|
||||
|
||||
func NewEach(exprs ...builder.Expression) builder.Expression {
|
||||
return newVariadicExpression(builder.Each, exprs...)
|
||||
}
|
||||
|
||||
func NewLt(left, right builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.Lt, left, right)
|
||||
}
|
||||
|
||||
func NewNot(expression builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Not, expression)
|
||||
}
|
||||
|
||||
func NewSum(expression builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Sum, expression)
|
||||
}
|
||||
|
||||
func NewMin(expression builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Min, expression)
|
||||
}
|
||||
|
||||
func First(expr builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.First, expr)
|
||||
}
|
||||
|
||||
func NewType(expr builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Type, expr)
|
||||
}
|
||||
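NewCond wires the if/then/else branches of $cond, and NewCompute attaches the result to a named field, which is enough for a computed projection. A sketch, same package; NewValue is used to wrap raw operands the way InRef does above, so its exact behaviour is an assumption here:

// Sketch: roughly { "status": { "$cond": { "if": { "$gt": [ "$balance", 0 ] },
//                                           "then": "active", "else": "dormant" } } }
func exampleCond() builder.Expression {
	isPositive := Gt(
		NewValue(NewRefFieldImp(NewFieldImp("balance")).Build()),
		NewValue(0),
	)
	return NewCompute(NewFieldImp("status"), NewCond(isPositive, "active", "dormant"))
}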
@@ -0,0 +1,35 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type groupAccumulatorImp struct {
|
||||
field builder.Field
|
||||
acc builder.Accumulator
|
||||
}
|
||||
|
||||
// NewGroupAccumulator creates a new GroupAccumulator for the given field using the specified operator and value.
|
||||
func NewGroupAccumulator(field builder.Field, acc builder.Accumulator) builder.GroupAccumulator {
|
||||
return &groupAccumulatorImp{
|
||||
field: field,
|
||||
acc: acc,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *groupAccumulatorImp) Field() builder.Field {
|
||||
return g.field
|
||||
}
|
||||
|
||||
func (g *groupAccumulatorImp) Accumulator() builder.Accumulator {
|
||||
return g.acc
|
||||
}
|
||||
|
||||
// Build returns a bson.E element for this group accumulator.
|
||||
func (g *groupAccumulatorImp) Build() bson.D {
|
||||
return bson.D{{
|
||||
Key: g.field.Build(),
|
||||
Value: g.acc.Build(),
|
||||
}}
|
||||
}
|
||||
60
api/pkg/db/internal/mongo/repositoryimp/builderimp/patch.go
Normal file
60
api/pkg/db/internal/mongo/repositoryimp/builderimp/patch.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type patchBuilder struct {
|
||||
updates bson.D
|
||||
}
|
||||
|
||||
func set(field builder.Field, value any) bson.E {
|
||||
return bson.E{Key: string(builder.Set), Value: bson.D{{Key: field.Build(), Value: value}}}
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Set(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, set(field, value))
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Inc(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Inc), Value: bson.D{{Key: field.Build(), Value: value}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Unset(field builder.Field) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Unset), Value: bson.D{{Key: field.Build(), Value: ""}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Rename(field builder.Field, newName string) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Rename), Value: bson.D{{Key: field.Build(), Value: newName}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Push(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Push), Value: bson.D{{Key: field.Build(), Value: value}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Pull(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Pull), Value: bson.D{{Key: field.Build(), Value: value}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) AddToSet(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.AddToSet), Value: bson.D{{Key: field.Build(), Value: value}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Build() bson.D {
|
||||
return append(u.updates, set(NewFieldImp(storable.UpdatedAtField), time.Now()))
|
||||
}
|
||||
|
||||
func NewPatchImp() builder.Patch {
|
||||
return &patchBuilder{updates: bson.D{}}
|
||||
}
|
||||
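A minimal usage sketch for the patch builder above (not part of the commit; field names are illustrative and NewFieldImp is assumed from this package). Build() always appends an extra $set on the stored updated-at field, as shown in the implementation.

package builderimp

import "github.com/tech/sendico/pkg/db/repository/builder"

// examplePatch chains several update operators; Build() materialises them
// into the update document passed to UpdateByID / UpdateMany.
func examplePatch() builder.Patch {
	return NewPatchImp().
		Set(NewFieldImp("name"), "renamed").
		Inc(NewFieldImp("version"), 1).
		Unset(NewFieldImp("legacyFlag"))
}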
131
api/pkg/db/internal/mongo/repositoryimp/builderimp/pipeline.go
Normal file
131
api/pkg/db/internal/mongo/repositoryimp/builderimp/pipeline.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type unwindOpts = builder.UnwindOpts
|
||||
|
||||
// UnwindOption is an alias for the option type defined in the builder package.
|
||||
type UnwindOption = builder.UnwindOption
|
||||
|
||||
// NewUnwindOpts applies all UnwindOptions to a fresh unwindOpts.
|
||||
func NewUnwindOpts(opts ...UnwindOption) *unwindOpts {
|
||||
cfg := &unwindOpts{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
type PipelineImp struct {
|
||||
pipeline mongo.Pipeline
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Match(filter builder.Query) builder.Pipeline {
|
||||
b.pipeline = append(b.pipeline, filter.BuildPipeline())
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Lookup(from mservice.Type, localField, foreignField, as builder.Field) builder.Pipeline {
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: from},
|
||||
{Key: string(builder.MKLocalField), Value: localField.Build()},
|
||||
{Key: string(builder.MKForeignField), Value: foreignField.Build()},
|
||||
{Key: string(builder.MKAs), Value: as.Build()},
|
||||
}}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) LookupWithPipeline(
|
||||
from mservice.Type,
|
||||
nested builder.Pipeline,
|
||||
as builder.Field,
|
||||
let *map[string]builder.Field,
|
||||
) builder.Pipeline {
|
||||
lookupStage := bson.D{
|
||||
{Key: string(builder.MKFrom), Value: from},
|
||||
{Key: string(builder.MKPipeline), Value: nested.Build()},
|
||||
{Key: string(builder.MKAs), Value: as.Build()},
|
||||
}
|
||||
|
||||
// Only add "let" when it is provided and non-empty.
|
||||
if let != nil && len(*let) > 0 {
|
||||
letDoc := bson.D{}
|
||||
for varName, fld := range *let {
|
||||
letDoc = append(letDoc, bson.E{Key: varName, Value: fld.Build()})
|
||||
}
|
||||
lookupStage = append(lookupStage, bson.E{Key: string(builder.MKLet), Value: letDoc})
|
||||
}
|
||||
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Lookup), Value: lookupStage}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Unwind(path builder.Field, opts ...UnwindOption) builder.Pipeline {
|
||||
cfg := NewUnwindOpts(opts...)
|
||||
|
||||
var stageValue any
|
||||
// If no options are set, use the shorthand string form.
|
||||
if !cfg.PreserveNullAndEmptyArrays && cfg.IncludeArrayIndex == "" {
|
||||
stageValue = path.Build()
|
||||
} else {
|
||||
d := bson.D{{Key: string(builder.MKPath), Value: path.Build()}}
|
||||
if cfg.PreserveNullAndEmptyArrays {
|
||||
d = append(d, bson.E{Key: string(builder.MKPreserveNullAndEmptyArrays), Value: true})
|
||||
}
|
||||
if cfg.IncludeArrayIndex != "" {
|
||||
d = append(d, bson.E{Key: string(builder.MKIncludeArrayIndex), Value: cfg.IncludeArrayIndex})
|
||||
}
|
||||
stageValue = d
|
||||
}
|
||||
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Unwind), Value: stageValue}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Count(field builder.Field) builder.Pipeline {
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Count), Value: field.Build()}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Group(groupBy builder.Alias, accumulators ...builder.GroupAccumulator) builder.Pipeline {
|
||||
groupDoc := groupBy.Build()
|
||||
for _, acc := range accumulators {
|
||||
groupDoc = append(groupDoc, acc.Build()...)
|
||||
}
|
||||
|
||||
b.pipeline = append(b.pipeline, bson.D{
|
||||
{Key: string(builder.Group), Value: groupDoc},
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Project(projections ...builder.Projection) builder.Pipeline {
|
||||
projDoc := bson.D{}
|
||||
for _, pr := range projections {
|
||||
projDoc = append(projDoc, pr.Build()...)
|
||||
}
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Project), Value: projDoc}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) ReplaceRoot(newRoot builder.Expression) builder.Pipeline {
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
|
||||
{Key: string(builder.MKNewRoot), Value: newRoot.Build()},
|
||||
}}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Build() mongo.Pipeline {
|
||||
return b.pipeline
|
||||
}
|
||||
|
||||
func NewPipelineImp() builder.Pipeline {
|
||||
return &PipelineImp{
|
||||
pipeline: mongo.Pipeline{},
|
||||
}
|
||||
}
|
||||
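The pipeline builder is meant to be chained. A hedged sketch follows (not part of the commit; it assumes NewFieldImp from this package, and the collection and field names are illustrative):

package builderimp

import (
	"github.com/tech/sendico/pkg/db/repository/builder"
	"github.com/tech/sendico/pkg/mservice"
	"go.mongodb.org/mongo-driver/mongo"
)

// examplePipeline joins a related collection, flattens the joined array and
// projects a single field. The stage arguments are illustrative only.
func examplePipeline(from mservice.Type) mongo.Pipeline {
	preserve := func(o *builder.UnwindOpts) { o.PreserveNullAndEmptyArrays = true }

	return NewPipelineImp().
		Match(NewQueryImp().Filter(NewFieldImp("status"), "active")).
		Lookup(from, NewFieldImp("userId"), NewFieldImp("_id"), NewFieldImp("user")).
		Unwind(NewFieldImp("$user"), preserve).
		Project(IncludeField(NewFieldImp("user.name"))).
		Build()
}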
@@ -0,0 +1,563 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestNewPipelineImp(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
|
||||
assert.NotNil(t, pipeline)
|
||||
assert.IsType(t, &PipelineImp{}, pipeline)
|
||||
|
||||
// Build should return empty pipeline initially
|
||||
built := pipeline.Build()
|
||||
assert.NotNil(t, built)
|
||||
assert.Len(t, built, 0)
|
||||
}
|
||||
|
||||
func TestPipelineImp_Match(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockQuery := &MockQuery{
|
||||
buildPipeline: bson.D{{Key: "$match", Value: bson.D{{Key: "field", Value: "value"}}}},
|
||||
}
|
||||
|
||||
result := pipeline.Match(mockQuery)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
assert.Equal(t, bson.D{{Key: "$match", Value: bson.D{{Key: "field", Value: "value"}}}}, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Lookup(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockLocalField := &MockField{build: "localField"}
|
||||
mockForeignField := &MockField{build: "foreignField"}
|
||||
mockAsField := &MockField{build: "asField"}
|
||||
|
||||
result := pipeline.Lookup(mservice.Projects, mockLocalField, mockForeignField, mockAsField)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Projects},
|
||||
{Key: string(builder.MKLocalField), Value: "localField"},
|
||||
{Key: string(builder.MKForeignField), Value: "foreignField"},
|
||||
{Key: string(builder.MKAs), Value: "asField"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_LookupWithPipeline_WithoutLet(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockNestedPipeline := &MockPipeline{
|
||||
build: mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "nested", Value: true}}}}},
|
||||
}
|
||||
mockAsField := &MockField{build: "asField"}
|
||||
|
||||
result := pipeline.LookupWithPipeline(mservice.Tasks, mockNestedPipeline, mockAsField, nil)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Tasks},
|
||||
{Key: string(builder.MKPipeline), Value: mockNestedPipeline.build},
|
||||
{Key: string(builder.MKAs), Value: "asField"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_LookupWithPipeline_WithLet(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockNestedPipeline := &MockPipeline{
|
||||
build: mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "nested", Value: true}}}}},
|
||||
}
|
||||
mockAsField := &MockField{build: "asField"}
|
||||
mockLetField := &MockField{build: "$_id"}
|
||||
|
||||
letVars := map[string]builder.Field{
|
||||
"projRef": mockLetField,
|
||||
}
|
||||
|
||||
result := pipeline.LookupWithPipeline(mservice.Tasks, mockNestedPipeline, mockAsField, &letVars)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Tasks},
|
||||
{Key: string(builder.MKPipeline), Value: mockNestedPipeline.build},
|
||||
{Key: string(builder.MKAs), Value: "asField"},
|
||||
{Key: string(builder.MKLet), Value: bson.D{{Key: "projRef", Value: "$_id"}}},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_LookupWithPipeline_WithEmptyLet(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockNestedPipeline := &MockPipeline{
|
||||
build: mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "nested", Value: true}}}}},
|
||||
}
|
||||
mockAsField := &MockField{build: "asField"}
|
||||
|
||||
emptyLetVars := map[string]builder.Field{}
|
||||
|
||||
pipeline.LookupWithPipeline(mservice.Tasks, mockNestedPipeline, mockAsField, &emptyLetVars)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
// Should not include let field when empty
|
||||
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Tasks},
|
||||
{Key: string(builder.MKPipeline), Value: mockNestedPipeline.build},
|
||||
{Key: string(builder.MKAs), Value: "asField"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Unwind_Simple(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "$array"}
|
||||
|
||||
result := pipeline.Unwind(mockField)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Unwind), Value: "$array"}}
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Unwind_WithPreserveNullAndEmptyArrays(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "$array"}
|
||||
|
||||
// Mock the UnwindOption function
|
||||
preserveOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.PreserveNullAndEmptyArrays = true
|
||||
}
|
||||
|
||||
pipeline.Unwind(mockField, preserveOpt)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Unwind), Value: bson.D{
|
||||
{Key: string(builder.MKPath), Value: "$array"},
|
||||
{Key: string(builder.MKPreserveNullAndEmptyArrays), Value: true},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Unwind_WithIncludeArrayIndex(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "$array"}
|
||||
|
||||
// Mock the UnwindOption function
|
||||
indexOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.IncludeArrayIndex = "arrayIndex"
|
||||
}
|
||||
|
||||
pipeline.Unwind(mockField, indexOpt)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Unwind), Value: bson.D{
|
||||
{Key: string(builder.MKPath), Value: "$array"},
|
||||
{Key: string(builder.MKIncludeArrayIndex), Value: "arrayIndex"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Unwind_WithBothOptions(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "$array"}
|
||||
|
||||
// Mock the UnwindOption functions
|
||||
preserveOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.PreserveNullAndEmptyArrays = true
|
||||
}
|
||||
indexOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.IncludeArrayIndex = "arrayIndex"
|
||||
}
|
||||
|
||||
pipeline.Unwind(mockField, preserveOpt, indexOpt)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Unwind), Value: bson.D{
|
||||
{Key: string(builder.MKPath), Value: "$array"},
|
||||
{Key: string(builder.MKPreserveNullAndEmptyArrays), Value: true},
|
||||
{Key: string(builder.MKIncludeArrayIndex), Value: "arrayIndex"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Count(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "totalCount"}
|
||||
|
||||
result := pipeline.Count(mockField)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Count), Value: "totalCount"}}
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Group(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockAlias := &MockAlias{
|
||||
build: bson.D{{Key: "_id", Value: "$field"}},
|
||||
field: &MockField{build: "_id"},
|
||||
}
|
||||
mockAccumulator := &MockGroupAccumulator{
|
||||
build: bson.D{{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}}},
|
||||
}
|
||||
|
||||
result := pipeline.Group(mockAlias, mockAccumulator)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Group), Value: bson.D{
|
||||
{Key: "_id", Value: "$field"},
|
||||
{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Group_MultipleAccumulators(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockAlias := &MockAlias{
|
||||
build: bson.D{{Key: "_id", Value: "$field"}},
|
||||
field: &MockField{build: "_id"},
|
||||
}
|
||||
mockAccumulator1 := &MockGroupAccumulator{
|
||||
build: bson.D{{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}}},
|
||||
}
|
||||
mockAccumulator2 := &MockGroupAccumulator{
|
||||
build: bson.D{{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}}},
|
||||
}
|
||||
|
||||
pipeline.Group(mockAlias, mockAccumulator1, mockAccumulator2)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Group), Value: bson.D{
|
||||
{Key: "_id", Value: "$field"},
|
||||
{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}},
|
||||
{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Project(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockProjection := &MockProjection{
|
||||
build: bson.D{{Key: "field1", Value: 1}},
|
||||
}
|
||||
|
||||
result := pipeline.Project(mockProjection)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Project), Value: bson.D{
|
||||
{Key: "field1", Value: 1},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Project_MultipleProjections(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockProjection1 := &MockProjection{
|
||||
build: bson.D{{Key: "field1", Value: 1}},
|
||||
}
|
||||
mockProjection2 := &MockProjection{
|
||||
build: bson.D{{Key: "field2", Value: 0}},
|
||||
}
|
||||
|
||||
pipeline.Project(mockProjection1, mockProjection2)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Project), Value: bson.D{
|
||||
{Key: "field1", Value: 1},
|
||||
{Key: "field2", Value: 0},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_ChainedOperations(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
|
||||
// Create mocks
|
||||
mockQuery := &MockQuery{
|
||||
buildPipeline: bson.D{{Key: "$match", Value: bson.D{{Key: "status", Value: "active"}}}},
|
||||
}
|
||||
mockLocalField := &MockField{build: "userId"}
|
||||
mockForeignField := &MockField{build: "_id"}
|
||||
mockAsField := &MockField{build: "user"}
|
||||
mockUnwindField := &MockField{build: "$user"}
|
||||
mockProjection := &MockProjection{
|
||||
build: bson.D{{Key: "name", Value: "$user.name"}},
|
||||
}
|
||||
|
||||
// Chain operations
|
||||
result := pipeline.
|
||||
Match(mockQuery).
|
||||
Lookup(mservice.Accounts, mockLocalField, mockForeignField, mockAsField).
|
||||
Unwind(mockUnwindField).
|
||||
Project(mockProjection)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 4)
|
||||
|
||||
// Verify each stage
|
||||
assert.Equal(t, bson.D{{Key: "$match", Value: bson.D{{Key: "status", Value: "active"}}}}, built[0])
|
||||
|
||||
expectedLookup := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Accounts},
|
||||
{Key: string(builder.MKLocalField), Value: "userId"},
|
||||
{Key: string(builder.MKForeignField), Value: "_id"},
|
||||
{Key: string(builder.MKAs), Value: "user"},
|
||||
}}}
|
||||
assert.Equal(t, expectedLookup, built[1])
|
||||
|
||||
assert.Equal(t, bson.D{{Key: string(builder.Unwind), Value: "$user"}}, built[2])
|
||||
|
||||
expectedProject := bson.D{{Key: string(builder.Project), Value: bson.D{
|
||||
{Key: "name", Value: "$user.name"},
|
||||
}}}
|
||||
assert.Equal(t, expectedProject, built[3])
|
||||
}
|
||||
|
||||
func TestNewUnwindOpts(t *testing.T) {
|
||||
t.Run("NoOptions", func(t *testing.T) {
|
||||
opts := NewUnwindOpts()
|
||||
|
||||
assert.NotNil(t, opts)
|
||||
assert.False(t, opts.PreserveNullAndEmptyArrays)
|
||||
assert.Empty(t, opts.IncludeArrayIndex)
|
||||
})
|
||||
|
||||
t.Run("WithPreserveOption", func(t *testing.T) {
|
||||
preserveOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.PreserveNullAndEmptyArrays = true
|
||||
}
|
||||
|
||||
opts := NewUnwindOpts(preserveOpt)
|
||||
|
||||
assert.True(t, opts.PreserveNullAndEmptyArrays)
|
||||
assert.Empty(t, opts.IncludeArrayIndex)
|
||||
})
|
||||
|
||||
t.Run("WithIndexOption", func(t *testing.T) {
|
||||
indexOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.IncludeArrayIndex = "index"
|
||||
}
|
||||
|
||||
opts := NewUnwindOpts(indexOpt)
|
||||
|
||||
assert.False(t, opts.PreserveNullAndEmptyArrays)
|
||||
assert.Equal(t, "index", opts.IncludeArrayIndex)
|
||||
})
|
||||
|
||||
t.Run("WithBothOptions", func(t *testing.T) {
|
||||
preserveOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.PreserveNullAndEmptyArrays = true
|
||||
}
|
||||
indexOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.IncludeArrayIndex = "index"
|
||||
}
|
||||
|
||||
opts := NewUnwindOpts(preserveOpt, indexOpt)
|
||||
|
||||
assert.True(t, opts.PreserveNullAndEmptyArrays)
|
||||
assert.Equal(t, "index", opts.IncludeArrayIndex)
|
||||
})
|
||||
}
|
||||
|
||||
// Mock implementations for testing
|
||||
|
||||
type MockQuery struct {
|
||||
buildPipeline bson.D
|
||||
}
|
||||
|
||||
func (m *MockQuery) And(filters ...builder.Query) builder.Query { return m }
|
||||
func (m *MockQuery) Or(filters ...builder.Query) builder.Query { return m }
|
||||
func (m *MockQuery) Filter(field builder.Field, value any) builder.Query { return m }
|
||||
func (m *MockQuery) Expression(value builder.Expression) builder.Query { return m }
|
||||
func (m *MockQuery) Comparison(field builder.Field, operator builder.MongoOperation, value any) builder.Query {
|
||||
return m
|
||||
}
|
||||
func (m *MockQuery) RegEx(field builder.Field, pattern, options string) builder.Query { return m }
|
||||
func (m *MockQuery) In(field builder.Field, values ...any) builder.Query { return m }
|
||||
func (m *MockQuery) NotIn(field builder.Field, values ...any) builder.Query { return m }
|
||||
func (m *MockQuery) Sort(field builder.Field, ascending bool) builder.Query { return m }
|
||||
func (m *MockQuery) Limit(limit *int64) builder.Query { return m }
|
||||
func (m *MockQuery) Offset(offset *int64) builder.Query { return m }
|
||||
func (m *MockQuery) Archived(isArchived *bool) builder.Query { return m }
|
||||
func (m *MockQuery) BuildPipeline() bson.D { return m.buildPipeline }
|
||||
func (m *MockQuery) BuildQuery() bson.D { return bson.D{} }
|
||||
func (m *MockQuery) BuildOptions() *options.FindOptions { return &options.FindOptions{} }
|
||||
|
||||
type MockField struct {
|
||||
build string
|
||||
}
|
||||
|
||||
func (m *MockField) Dot(field string) builder.Field { return &MockField{build: m.build + "." + field} }
|
||||
func (m *MockField) CopyWith(field string) builder.Field { return &MockField{build: field} }
|
||||
func (m *MockField) Build() string { return m.build }
|
||||
|
||||
type MockPipeline struct {
|
||||
build mongo.Pipeline
|
||||
}
|
||||
|
||||
func (m *MockPipeline) Match(filter builder.Query) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) Lookup(from mservice.Type, localField, foreignField, as builder.Field) builder.Pipeline {
|
||||
return m
|
||||
}
|
||||
func (m *MockPipeline) LookupWithPipeline(from mservice.Type, pipeline builder.Pipeline, as builder.Field, let *map[string]builder.Field) builder.Pipeline {
|
||||
return m
|
||||
}
|
||||
func (m *MockPipeline) Unwind(path builder.Field, opts ...UnwindOption) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) Count(field builder.Field) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) Group(groupBy builder.Alias, accumulators ...builder.GroupAccumulator) builder.Pipeline {
|
||||
return m
|
||||
}
|
||||
func (m *MockPipeline) Project(projections ...builder.Projection) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) ReplaceRoot(newRoot builder.Expression) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) Build() mongo.Pipeline { return m.build }
|
||||
|
||||
type MockAlias struct {
|
||||
build bson.D
|
||||
field builder.Field
|
||||
}
|
||||
|
||||
func (m *MockAlias) Field() builder.Field { return m.field }
|
||||
func (m *MockAlias) Build() bson.D { return m.build }
|
||||
|
||||
type MockGroupAccumulator struct {
|
||||
build bson.D
|
||||
}
|
||||
|
||||
func (m *MockGroupAccumulator) Build() bson.D { return m.build }
|
||||
|
||||
type MockProjection struct {
|
||||
build bson.D
|
||||
}
|
||||
|
||||
func (m *MockProjection) Build() bson.D { return m.build }
|
||||
|
||||
func TestPipelineImp_ReplaceRoot(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockExpr := &MockExpression{build: "$newRoot"}
|
||||
|
||||
result := pipeline.ReplaceRoot(mockExpr)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
|
||||
{Key: string(builder.MKNewRoot), Value: "$newRoot"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_ReplaceRoot_WithNestedField(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockExpr := &MockExpression{build: "$document.data"}
|
||||
|
||||
pipeline.ReplaceRoot(mockExpr)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
|
||||
{Key: string(builder.MKNewRoot), Value: "$document.data"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_ReplaceRoot_WithExpression(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
// Mock a complex expression like { $mergeObjects: [...] }
|
||||
mockExpr := &MockExpression{build: bson.D{{Key: "$mergeObjects", Value: bson.A{"$field1", "$field2"}}}}
|
||||
|
||||
pipeline.ReplaceRoot(mockExpr)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
|
||||
{Key: string(builder.MKNewRoot), Value: bson.D{{Key: "$mergeObjects", Value: bson.A{"$field1", "$field2"}}}},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
type MockExpression struct {
|
||||
build any
|
||||
}
|
||||
|
||||
func (m *MockExpression) Build() any { return m.build }
|
||||
@@ -0,0 +1,97 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
// projectionExprImp is a concrete implementation of builder.Projection
|
||||
// that projects a field using a custom expression.
|
||||
type projectionExprImp struct {
|
||||
expr builder.Expression // The expression for this projection.
|
||||
field builder.Field // The field name for the projected field.
|
||||
}
|
||||
|
||||
// Field returns the field being projected.
|
||||
func (p *projectionExprImp) Field() builder.Field {
|
||||
return p.field
|
||||
}
|
||||
|
||||
// Expression returns the expression for the projection.
|
||||
func (p *projectionExprImp) Expression() builder.Expression {
|
||||
return p.expr
|
||||
}
|
||||
|
||||
// Build returns the built expression. If no expression is provided, returns 1.
|
||||
func (p *projectionExprImp) Build() bson.D {
|
||||
if p.expr == nil {
|
||||
return bson.D{{Key: p.field.Build(), Value: 1}}
|
||||
}
|
||||
return bson.D{{Key: p.field.Build(), Value: p.expr.Build()}}
|
||||
}
|
||||
|
||||
// NewProjectionExpr creates a new Projection for a given field and expression.
|
||||
func NewProjectionExpr(field builder.Field, expr builder.Expression) builder.Projection {
|
||||
return &projectionExprImp{field: field, expr: expr}
|
||||
}
|
||||
|
||||
// aliasProjectionImp is a concrete implementation of builder.Projection
|
||||
// that projects an alias (renaming a field or expression).
|
||||
type aliasProjectionImp struct {
|
||||
alias builder.Alias // The alias for this projection.
|
||||
}
|
||||
|
||||
// Field returns the field being projected (via the alias).
|
||||
func (p *aliasProjectionImp) Field() builder.Field {
|
||||
return p.alias.Field()
|
||||
}
|
||||
|
||||
// Expression returns no additional expression for an alias projection.
|
||||
func (p *aliasProjectionImp) Expression() builder.Expression {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build returns the built alias expression.
|
||||
func (p *aliasProjectionImp) Build() bson.D {
|
||||
return p.alias.Build()
|
||||
}
|
||||
|
||||
// NewAliasProjection creates a new Projection that renames or wraps an existing field or expression.
|
||||
func NewAliasProjection(alias builder.Alias) builder.Projection {
|
||||
return &aliasProjectionImp{alias: alias}
|
||||
}
|
||||
|
||||
// sinkProjectionImp is a simple include/exclude projection (0 or 1).
|
||||
type sinkProjectionImp struct {
|
||||
field builder.Field // The field name for the projected field.
|
||||
val int // 1 to include, 0 to exclude.
|
||||
}
|
||||
|
||||
// Expression returns no expression for a sink projection.
|
||||
func (p *sinkProjectionImp) Expression() builder.Expression {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build returns the include/exclude projection.
|
||||
func (p *sinkProjectionImp) Build() bson.D {
|
||||
return bson.D{{Key: p.field.Build(), Value: p.val}}
|
||||
}
|
||||
|
||||
// NewSinkProjection creates a new Projection that includes (true) or excludes (false) a field.
|
||||
func NewSinkProjection(field builder.Field, include bool) builder.Projection {
|
||||
val := 0
|
||||
if include {
|
||||
val = 1
|
||||
}
|
||||
return &sinkProjectionImp{field: field, val: val}
|
||||
}
|
||||
|
||||
// IncludeField returns a projection including the given field.
|
||||
func IncludeField(field builder.Field) builder.Projection {
|
||||
return NewSinkProjection(field, true)
|
||||
}
|
||||
|
||||
// ExcludeField returns a projection excluding the given field.
|
||||
func ExcludeField(field builder.Field) builder.Projection {
|
||||
return NewSinkProjection(field, false)
|
||||
}
|
||||
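A short sketch of how the projection constructors above can be combined when feeding Pipeline.Project (not part of the commit; field names and the expr argument are illustrative, and the alias-based projection is omitted because it needs a builder.Alias built elsewhere):

package builderimp

import "github.com/tech/sendico/pkg/db/repository/builder"

// exampleProjections mixes include/exclude projections with an
// expression-backed one.
func exampleProjections(expr builder.Expression) []builder.Projection {
	return []builder.Projection{
		IncludeField(NewFieldImp("name")),
		ExcludeField(NewFieldImp("secret")),
		NewProjectionExpr(NewFieldImp("total"), expr),
	}
}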
156
api/pkg/db/internal/mongo/repositoryimp/builderimp/query.go
Normal file
156
api/pkg/db/internal/mongo/repositoryimp/builderimp/query.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type QueryImp struct {
|
||||
filter bson.D
|
||||
sort bson.D
|
||||
limit *int64
|
||||
offset *int64
|
||||
}
|
||||
|
||||
func (b *QueryImp) Filter(field builder.Field, value any) builder.Query {
|
||||
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: value})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) And(filters ...builder.Query) builder.Query {
|
||||
andFilters := bson.A{}
|
||||
for _, f := range filters {
|
||||
andFilters = append(andFilters, f.BuildQuery())
|
||||
}
|
||||
b.filter = append(b.filter, bson.E{Key: string(builder.And), Value: andFilters})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) Or(filters ...builder.Query) builder.Query {
|
||||
orFilters := bson.A{}
|
||||
for _, f := range filters {
|
||||
orFilters = append(orFilters, f.BuildQuery())
|
||||
}
|
||||
b.filter = append(b.filter, bson.E{Key: string(builder.Or), Value: orFilters})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) Comparison(field builder.Field, operator builder.MongoOperation, value any) builder.Query {
|
||||
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: bson.M{string(operator): value}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) Expression(value builder.Expression) builder.Query {
|
||||
b.filter = append(b.filter, bson.E{Key: string(builder.Expr), Value: value.Build()})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) RegEx(field builder.Field, pattern, options string) builder.Query {
|
||||
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: primitive.Regex{Pattern: pattern, Options: options}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) opIn(field builder.Field, op builder.MongoOperation, values ...any) builder.Query {
|
||||
var flattenedValues []any
|
||||
|
||||
for _, v := range values {
|
||||
switch reflect.TypeOf(v).Kind() {
|
||||
case reflect.Slice:
|
||||
slice := reflect.ValueOf(v)
|
||||
for i := range slice.Len() {
|
||||
flattenedValues = append(flattenedValues, slice.Index(i).Interface())
|
||||
}
|
||||
default:
|
||||
flattenedValues = append(flattenedValues, v)
|
||||
}
|
||||
}
|
||||
|
||||
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: bson.M{string(op): flattenedValues}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) NotIn(field builder.Field, values ...any) builder.Query {
|
||||
return b.opIn(field, builder.NotIn, values...)
|
||||
}
|
||||
|
||||
func (b *QueryImp) In(field builder.Field, values ...any) builder.Query {
|
||||
return b.opIn(field, builder.In, values...)
|
||||
}
|
||||
|
||||
func (b *QueryImp) Archived(isArchived *bool) builder.Query {
|
||||
if isArchived == nil {
|
||||
return b
|
||||
}
|
||||
return b.And(NewQueryImp().Filter(NewFieldImp(storable.IsArchivedField), *isArchived))
|
||||
}
|
||||
|
||||
func (b *QueryImp) Sort(field builder.Field, ascending bool) builder.Query {
|
||||
order := 1
|
||||
if !ascending {
|
||||
order = -1
|
||||
}
|
||||
b.sort = append(b.sort, bson.E{Key: field.Build(), Value: order})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) BuildPipeline() bson.D {
|
||||
query := bson.D{}
|
||||
|
||||
if len(b.filter) > 0 {
|
||||
query = append(query, bson.E{Key: string(builder.Match), Value: b.filter})
|
||||
}
|
||||
|
||||
if len(b.sort) > 0 {
|
||||
query = append(query, bson.E{Key: string(builder.Sort), Value: b.sort})
|
||||
}
|
||||
|
||||
if b.limit != nil {
|
||||
query = append(query, bson.E{Key: string(builder.Limit), Value: *b.limit})
|
||||
}
|
||||
|
||||
if b.offset != nil {
|
||||
query = append(query, bson.E{Key: string(builder.Skip), Value: *b.offset})
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (b *QueryImp) BuildQuery() bson.D {
|
||||
return b.filter
|
||||
}
|
||||
|
||||
func (b *QueryImp) Limit(limit *int64) builder.Query {
|
||||
b.limit = limit
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) Offset(offset *int64) builder.Query {
|
||||
b.offset = offset
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) BuildOptions() *options.FindOptions {
|
||||
opts := options.Find()
|
||||
if b.limit != nil {
|
||||
opts.SetLimit(*b.limit)
|
||||
}
|
||||
if b.offset != nil {
|
||||
opts.SetSkip(*b.offset)
|
||||
}
|
||||
if len(b.sort) > 0 {
|
||||
opts.SetSort(b.sort)
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func NewQueryImp() builder.Query {
|
||||
return &QueryImp{
|
||||
filter: bson.D{},
|
||||
sort: bson.D{},
|
||||
}
|
||||
}
|
||||
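The query builder above is also chainable. A hedged sketch of a filtered, sorted, paged query (not part of the commit; NewFieldImp is assumed from this package and the field names are illustrative):

package builderimp

import (
	"github.com/tech/sendico/pkg/db/repository/builder"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

// exampleQuery filters on an organisation reference, restricts _id to a set
// of references (the slice is flattened by opIn), sorts descending by
// createdAt and pages the result.
func exampleQuery(orgRef primitive.ObjectID, refs []primitive.ObjectID) builder.Query {
	limit := int64(20)
	offset := int64(0)
	return NewQueryImp().
		Filter(NewFieldImp("organizationRef"), orgRef).
		In(NewFieldImp("_id"), refs).
		Sort(NewFieldImp("createdAt"), false).
		Limit(&limit).
		Offset(&offset)
}

BuildQuery and BuildOptions feed the Find-based repository methods, while BuildPipeline emits the equivalent $match/$sort/$limit/$skip stages.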
17
api/pkg/db/internal/mongo/repositoryimp/builderimp/value.go
Normal file
17
api/pkg/db/internal/mongo/repositoryimp/builderimp/value.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
)
|
||||
|
||||
type valueImp struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func (v *valueImp) Build() any {
|
||||
return v.value
|
||||
}
|
||||
|
||||
func NewValue(value any) builder.Value {
|
||||
return &valueImp{value: value}
|
||||
}
|
||||
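NewValue is the escape hatch for passing literals into the builder. A trivial sketch follows; whether builder.Value is accepted where builder.Expression is expected depends on the builder interfaces, which are not shown in this commit, so treat that as an assumption.

package builderimp

import "github.com/tech/sendico/pkg/db/repository/builder"

// literalOne wraps the constant 1, e.g. as the operand of a $sum-style
// accumulator or comparison.
func literalOne() builder.Value {
	return NewValue(1)
}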
50
api/pkg/db/internal/mongo/repositoryimp/index.go
Normal file
50
api/pkg/db/internal/mongo/repositoryimp/index.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package repositoryimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func (r *MongoRepository) CreateIndex(def *ri.Definition) error {
|
||||
if r.collection == nil {
|
||||
return merrors.NoData("data collection is not set")
|
||||
}
|
||||
if len(def.Keys) == 0 {
|
||||
return merrors.InvalidArgument("Index definition has no keys")
|
||||
}
|
||||
|
||||
// ----- build BSON keys --------------------------------------------------
|
||||
keys := bson.D{}
|
||||
for _, k := range def.Keys {
|
||||
var value any
|
||||
switch {
|
||||
case k.Type != "":
|
||||
value = k.Type // text, 2dsphere, …
|
||||
case k.Sort == ri.Desc:
|
||||
value = int8(-1)
|
||||
default:
|
||||
value = int8(1) // default to Asc
|
||||
}
|
||||
keys = append(keys, bson.E{Key: k.Field, Value: value})
|
||||
}
|
||||
|
||||
opts := options.Index().
|
||||
SetUnique(def.Unique)
|
||||
if def.TTL != nil {
|
||||
opts.SetExpireAfterSeconds(*def.TTL)
|
||||
}
|
||||
if def.Name != "" {
|
||||
opts.SetName(def.Name)
|
||||
}
|
||||
|
||||
_, err := r.collection.Indexes().CreateOne(
|
||||
context.Background(),
|
||||
mongo.IndexModel{Keys: keys, Options: opts},
|
||||
)
|
||||
return err
|
||||
}
|
||||
250
api/pkg/db/internal/mongo/repositoryimp/repository.go
Normal file
250
api/pkg/db/internal/mongo/repositoryimp/repository.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package repositoryimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
rd "github.com/tech/sendico/pkg/db/repository/decoder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type MongoRepository struct {
|
||||
collectionName string
|
||||
collection *mongo.Collection
|
||||
}
|
||||
|
||||
func idFilter(id primitive.ObjectID) bson.D {
|
||||
return bson.D{
|
||||
{Key: storable.IDField, Value: id},
|
||||
}
|
||||
}
|
||||
|
||||
func NewMongoRepository(db *mongo.Database, collection string) *MongoRepository {
|
||||
return &MongoRepository{
|
||||
collectionName: collection,
|
||||
collection: db.Collection(collection),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Collection() string {
|
||||
return r.collectionName
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Insert(ctx context.Context, obj storable.Storable, getFilter builder.Query) error {
|
||||
if (obj.GetID() == nil) || (obj.GetID().IsZero()) {
|
||||
obj.SetID(primitive.NewObjectID())
|
||||
}
|
||||
obj.Update()
|
||||
_, err := r.collection.InsertOne(ctx, obj)
|
||||
if mongo.IsDuplicateKeyError(err) {
|
||||
if getFilter != nil {
|
||||
if err = r.FindOneByFilter(ctx, getFilter, obj); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return merrors.DataConflict("duplicate_key")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) InsertMany(ctx context.Context, objects []storable.Storable) error {
|
||||
if len(objects) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
docs := make([]any, len(objects))
|
||||
for i, obj := range objects {
|
||||
if (obj.GetID() == nil) || (obj.GetID().IsZero()) {
|
||||
obj.SetID(primitive.NewObjectID())
|
||||
}
|
||||
obj.Update()
|
||||
docs[i] = obj
|
||||
}
|
||||
|
||||
_, err := r.collection.InsertMany(ctx, docs)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) findOneByFilterImp(ctx context.Context, filter bson.D, errMessage string, result storable.Storable) error {
|
||||
err := r.collection.FindOne(ctx, filter).Decode(result)
|
||||
if errors.Is(err, mongo.ErrNoDocuments) {
|
||||
return merrors.NoData(errMessage)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Get(ctx context.Context, id primitive.ObjectID, result storable.Storable) error {
|
||||
if id.IsZero() {
|
||||
return merrors.InvalidArgument("zero id provided while fetching " + result.Collection())
|
||||
}
|
||||
return r.findOneByFilterImp(ctx, idFilter(id), fmt.Sprintf("%s with ID = %s not found", result.Collection(), id.Hex()), result)
|
||||
}
|
||||
|
||||
type QueryFunc func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error)
|
||||
|
||||
func (r *MongoRepository) executeQuery(ctx context.Context, queryFunc QueryFunc, decoder rd.DecodingFunc) error {
|
||||
cursor, err := queryFunc(ctx, r.collection)
|
||||
if errors.Is(err, mongo.ErrNoDocuments) {
|
||||
return merrors.NoData("no_items_in_array")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
if err = decoder(cursor); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Aggregate(ctx context.Context, pipeline builder.Pipeline, decoder rd.DecodingFunc) error {
|
||||
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
|
||||
return collection.Aggregate(ctx, pipeline.Build())
|
||||
}
|
||||
return r.executeQuery(ctx, queryFunc, decoder)
|
||||
}
|
||||
|
||||
func (r *MongoRepository) FindManyByFilter(ctx context.Context, query builder.Query, decoder rd.DecodingFunc) error {
|
||||
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
|
||||
return collection.Find(ctx, query.BuildQuery(), query.BuildOptions())
|
||||
}
|
||||
return r.executeQuery(ctx, queryFunc, decoder)
|
||||
}
|
||||
|
||||
func (r *MongoRepository) FindOneByFilter(ctx context.Context, query builder.Query, result storable.Storable) error {
|
||||
return r.findOneByFilterImp(ctx, query.BuildQuery(), result.Collection()+" not found by filter", result)
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Update(ctx context.Context, obj storable.Storable) error {
|
||||
obj.Update()
|
||||
return r.collection.FindOneAndReplace(ctx, idFilter(*obj.GetID()), obj).Err()
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Patch(ctx context.Context, id primitive.ObjectID, patch builder.Patch) error {
|
||||
if id.IsZero() {
|
||||
return merrors.InvalidArgument("zero id provided while patching")
|
||||
}
|
||||
_, err := r.collection.UpdateByID(ctx, id, patch.Build())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) PatchMany(ctx context.Context, query builder.Query, patch builder.Patch) (int, error) {
|
||||
result, err := r.collection.UpdateMany(ctx, query.BuildQuery(), patch.Build())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(result.ModifiedCount), nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) ListIDs(ctx context.Context, query builder.Query) ([]primitive.ObjectID, error) {
|
||||
filter := query.BuildQuery()
|
||||
findOptions := options.Find().SetProjection(bson.M{storable.IDField: 1})
|
||||
|
||||
cursor, err := r.collection.Find(ctx, filter, findOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var ids []primitive.ObjectID
|
||||
for cursor.Next(ctx) {
|
||||
var doc struct {
|
||||
ID primitive.ObjectID `bson:"_id"`
|
||||
}
|
||||
if err := cursor.Decode(&doc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, doc.ID)
|
||||
}
|
||||
if err := cursor.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) ListPermissionBound(ctx context.Context, query builder.Query) ([]model.PermissionBoundStorable, error) {
|
||||
filter := query.BuildQuery()
|
||||
findOptions := options.Find().SetProjection(bson.M{
|
||||
storable.IDField: 1,
|
||||
storable.PermissionRefField: 1,
|
||||
storable.OrganizationRefField: 1,
|
||||
})
|
||||
|
||||
cursor, err := r.collection.Find(ctx, filter, findOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
result := make([]model.PermissionBoundStorable, 0)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
var doc model.PermissionBound
|
||||
if err := cursor.Decode(&doc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, &doc)
|
||||
}
|
||||
if err := cursor.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) ListAccountBound(ctx context.Context, query builder.Query) ([]model.AccountBoundStorable, error) {
|
||||
filter := query.BuildQuery()
|
||||
findOptions := options.Find().SetProjection(bson.M{
|
||||
storable.IDField: 1,
|
||||
model.AccountRefField: 1,
|
||||
model.OrganizationRefField: 1,
|
||||
})
|
||||
|
||||
cursor, err := r.collection.Find(ctx, filter, findOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
result := make([]model.AccountBoundStorable, 0)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
var doc model.AccountBoundBase
|
||||
if err := cursor.Decode(&doc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, &doc)
|
||||
}
|
||||
if err := cursor.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Delete(ctx context.Context, id primitive.ObjectID) error {
|
||||
_, err := r.collection.DeleteOne(ctx, idFilter(id))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) DeleteMany(ctx context.Context, query builder.Query) error {
|
||||
_, err := r.collection.DeleteMany(ctx, query.BuildQuery())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Name() string {
|
||||
return r.collection.Name()
|
||||
}
|
||||
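Aggregate and FindManyByFilter above both hand a cursor to a caller-supplied decoding callback. The sketch below (not part of the commit) shows that pattern with a generic bson.M target; the callback signature mirrors the one used in the integration tests, and everything else is illustrative.

package repositoryimp

import (
	"context"

	"github.com/tech/sendico/pkg/db/repository/builder"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
)

// exampleAggregate runs a pre-built pipeline and collects every document of
// the result set into a slice of bson.M via the decoder callback.
func exampleAggregate(ctx context.Context, r *MongoRepository, pipeline builder.Pipeline) ([]bson.M, error) {
	var out []bson.M
	decoder := func(cursor *mongo.Cursor) error {
		var doc bson.M
		if err := cursor.Decode(&doc); err != nil {
			return err
		}
		out = append(out, doc)
		return nil
	}
	if err := r.Aggregate(ctx, pipeline, decoder); err != nil {
		return nil, err
	}
	return out, nil
}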
@@ -0,0 +1,577 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package repositoryimp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp/builderimp"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestMongoRepository_Insert(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Insert_WithoutID", func(t *testing.T) {
|
||||
testObj := &TestObject{Name: "testInsert"}
|
||||
// ID should be nil/zero initially
|
||||
assert.True(t, testObj.GetID().IsZero())
|
||||
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ID should be assigned after insert
|
||||
assert.False(t, testObj.GetID().IsZero())
|
||||
assert.NotEmpty(t, testObj.CreatedAt)
|
||||
assert.NotEmpty(t, testObj.UpdatedAt)
|
||||
})
|
||||
|
||||
t.Run("Insert_WithExistingID", func(t *testing.T) {
|
||||
existingID := primitive.NewObjectID()
|
||||
testObj := &TestObject{Name: "testInsertWithID"}
|
||||
testObj.SetID(existingID)
|
||||
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ID should remain the same
|
||||
assert.Equal(t, existingID, *testObj.GetID())
|
||||
})
|
||||
|
||||
t.Run("Insert_DuplicateKey", func(t *testing.T) {
|
||||
// Insert first object
|
||||
testObj1 := &TestObject{Name: "duplicate"}
|
||||
err := repository.Insert(ctx, testObj1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to insert object with same ID
|
||||
testObj2 := &TestObject{Name: "duplicate2"}
|
||||
testObj2.SetID(*testObj1.GetID())
|
||||
|
||||
err = repository.Insert(ctx, testObj2, nil)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrDataConflict))
|
||||
})
|
||||
|
||||
t.Run("Insert_DuplicateKeyWithGetFilter", func(t *testing.T) {
|
||||
// Insert first object
|
||||
testObj1 := &TestObject{Name: "duplicateWithFilter"}
|
||||
err := repository.Insert(ctx, testObj1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to insert object with same ID, but with getFilter
|
||||
testObj2 := &TestObject{Name: "duplicateWithFilter2"}
|
||||
testObj2.SetID(*testObj1.GetID())
|
||||
|
||||
getFilter := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("_id"), builder.Eq, *testObj1.GetID())
|
||||
|
||||
err = repository.Insert(ctx, testObj2, getFilter)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrDataConflict))
|
||||
|
||||
// But testObj2 should be populated with the existing object data
|
||||
assert.Equal(t, testObj1.Name, testObj2.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_Update(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Update_ExistingObject", func(t *testing.T) {
|
||||
// Insert object first
|
||||
testObj := &TestObject{Name: "originalName"}
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
originalUpdatedAt := testObj.UpdatedAt
|
||||
|
||||
// Update the object
|
||||
testObj.Name = "updatedName"
|
||||
time.Sleep(10 * time.Millisecond) // Ensure time difference
|
||||
|
||||
err = repository.Update(ctx, testObj)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the object was updated
|
||||
result := &TestObject{}
|
||||
err = repository.Get(ctx, *testObj.GetID(), result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updatedName", result.Name)
|
||||
assert.True(t, result.UpdatedAt.After(originalUpdatedAt))
|
||||
})
|
||||
|
||||
t.Run("Update_NonExistentObject", func(t *testing.T) {
|
||||
nonExistentID := primitive.NewObjectID()
|
||||
testObj := &TestObject{Name: "nonExistent"}
|
||||
testObj.SetID(nonExistentID)
|
||||
|
||||
err := repository.Update(ctx, testObj)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, mongo.ErrNoDocuments))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_Delete(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Delete_ExistingObject", func(t *testing.T) {
|
||||
// Insert object first
|
||||
testObj := &TestObject{Name: "toDelete"}
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the object
|
||||
err = repository.Delete(ctx, *testObj.GetID())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the object was deleted
|
||||
result := &TestObject{}
|
||||
err = repository.Get(ctx, *testObj.GetID(), result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Delete_NonExistentObject", func(t *testing.T) {
|
||||
nonExistentID := primitive.NewObjectID()
|
||||
|
||||
err := repository.Delete(ctx, nonExistentID)
|
||||
// Delete should not return error even if object doesn't exist
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_FindOneByFilter(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("FindOneByFilter_MatchingFilter", func(t *testing.T) {
|
||||
// Insert test objects
|
||||
testObjs := []*TestObject{
|
||||
{Name: "findMe"},
|
||||
{Name: "dontFindMe"},
|
||||
{Name: "findMeToo"},
|
||||
}
|
||||
|
||||
for _, obj := range testObjs {
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Find by filter
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "findMe")
|
||||
result := &TestObject{}
|
||||
|
||||
err := repository.FindOneByFilter(ctx, query, result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "findMe", result.Name)
|
||||
})
|
||||
|
||||
t.Run("FindOneByFilter_NoMatch", func(t *testing.T) {
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "nonExistentName")
|
||||
result := &TestObject{}
|
||||
|
||||
err := repository.FindOneByFilter(ctx, query, result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_FindManyByFilter(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("FindManyByFilter_MultipleResults", func(t *testing.T) {
|
||||
// Insert test objects
|
||||
testObjs := []*TestObject{
|
||||
{Name: "findMany1"},
|
||||
{Name: "findMany2"},
|
||||
{Name: "dontFind"},
|
||||
}
|
||||
|
||||
for _, obj := range testObjs {
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Find objects with names starting with "findMany"
|
||||
query := builderimp.NewQueryImp().RegEx(builderimp.NewFieldImp("name"), "^findMany", "")
|
||||
|
||||
var results []*TestObject
|
||||
decoder := func(cursor *mongo.Cursor) error {
|
||||
var obj TestObject
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, &obj)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := repository.FindManyByFilter(ctx, query, decoder)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
names := make([]string, len(results))
|
||||
for i, obj := range results {
|
||||
names[i] = obj.Name
|
||||
}
|
||||
assert.Contains(t, names, "findMany1")
|
||||
assert.Contains(t, names, "findMany2")
|
||||
})
|
||||
|
||||
t.Run("FindManyByFilter_NoResults", func(t *testing.T) {
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "nonExistentPattern")
|
||||
|
||||
var results []*TestObject
|
||||
decoder := func(cursor *mongo.Cursor) error {
|
||||
var obj TestObject
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, &obj)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := repository.FindManyByFilter(ctx, query, decoder)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, results)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_DeleteMany(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("DeleteMany_MultipleDocuments", func(t *testing.T) {
|
||||
// Insert test objects
|
||||
testObjs := []*TestObject{
|
||||
{Name: "deleteMany1"},
|
||||
{Name: "deleteMany2"},
|
||||
{Name: "keepMe"},
|
||||
}
|
||||
|
||||
for _, obj := range testObjs {
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Delete objects with names starting with "deleteMany"
|
||||
query := builderimp.NewQueryImp().RegEx(builderimp.NewFieldImp("name"), "^deleteMany", "")
|
||||
|
||||
err := repository.DeleteMany(ctx, query)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletions
|
||||
queryAll := builderimp.NewQueryImp()
|
||||
var results []*TestObject
|
||||
decoder := func(cursor *mongo.Cursor) error {
|
||||
var obj TestObject
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, &obj)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = repository.FindManyByFilter(ctx, queryAll, decoder)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "keepMe", results[0].Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_Name(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "mycollection")
|
||||
|
||||
t.Run("Name_ReturnsCollectionName", func(t *testing.T) {
|
||||
name := repository.Name()
|
||||
assert.Equal(t, "mycollection", name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_ListPermissionBound(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("ListPermissionBound_WithData", func(t *testing.T) {
|
||||
// Insert test objects with permission bound data
|
||||
orgID := primitive.NewObjectID()
|
||||
|
||||
// Insert documents directly with permission bound fields
|
||||
_, err := db.Collection("testcollection").InsertMany(ctx, []interface{}{
|
||||
bson.M{
|
||||
"_id": primitive.NewObjectID(),
|
||||
"organizationRef": orgID,
|
||||
"permissionRef": primitive.NewObjectID(),
|
||||
},
|
||||
bson.M{
|
||||
"_id": primitive.NewObjectID(),
|
||||
"organizationRef": orgID,
|
||||
"permissionRef": primitive.NewObjectID(),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query for permission bound objects
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("organizationRef"), builder.Eq, orgID)
|
||||
|
||||
results, err := repository.ListPermissionBound(ctx, query)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
for _, result := range results {
|
||||
assert.Equal(t, orgID, result.GetOrganizationRef())
|
||||
assert.NotNil(t, result.GetPermissionRef())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ListPermissionBound_EmptyResult", func(t *testing.T) {
|
||||
nonExistentOrgID := primitive.NewObjectID()
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("organizationRef"), builder.Eq, nonExistentOrgID)
|
||||
|
||||
results, err := repository.ListPermissionBound(ctx, query)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, results)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_UpdateTimestamp(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Update_Should_Update_Timestamp", func(t *testing.T) {
|
||||
// Create test object
|
||||
obj := &TestObject{
|
||||
Name: "Test Object",
|
||||
}
|
||||
|
||||
// Set ID and initial timestamps
|
||||
obj.SetID(primitive.NewObjectID())
|
||||
originalCreatedAt := obj.CreatedAt
|
||||
originalUpdatedAt := obj.UpdatedAt
|
||||
|
||||
// Insert the object
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait a moment to ensure timestamp difference
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Update the object
|
||||
obj.Name = "Updated Object"
|
||||
err = repository.Update(ctx, obj)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify timestamps
|
||||
assert.Equal(t, originalCreatedAt, obj.CreatedAt, "CreatedAt should not change")
|
||||
assert.True(t, obj.UpdatedAt.After(originalUpdatedAt), "UpdatedAt should be updated")
|
||||
|
||||
// Verify the object was actually updated in the database
|
||||
var retrieved TestObject
|
||||
err = repository.Get(ctx, *obj.GetID(), &retrieved)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Updated Object", retrieved.Name, "Name should be updated")
|
||||
assert.WithinDuration(t, originalCreatedAt, retrieved.CreatedAt, time.Second, "CreatedAt should not change in DB")
|
||||
assert.True(t, retrieved.UpdatedAt.After(originalUpdatedAt), "UpdatedAt should be updated in DB")
|
||||
assert.WithinDuration(t, obj.UpdatedAt, retrieved.UpdatedAt, time.Second, "UpdatedAt should match between object and DB")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package repositoryimp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestMongoRepository_InsertMany(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("InsertMany_Success", func(t *testing.T) {
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Name: "test1"},
|
||||
&TestObject{Name: "test2"},
|
||||
&TestObject{Name: "test3"},
|
||||
}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all objects were inserted and have IDs
|
||||
for _, obj := range objects {
|
||||
assert.NotNil(t, obj.GetID())
|
||||
assert.False(t, obj.GetID().IsZero())
|
||||
|
||||
// Verify we can retrieve each object
|
||||
result := &TestObject{}
|
||||
err := repository.Get(ctx, *obj.GetID(), result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, obj.(*TestObject).Name, result.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertMany_EmptySlice", func(t *testing.T) {
|
||||
objects := []storable.Storable{}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("InsertMany_WithExistingIDs", func(t *testing.T) {
|
||||
id1 := primitive.NewObjectID()
|
||||
id2 := primitive.NewObjectID()
|
||||
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Base: storable.Base{ID: id1}, Name: "preassigned1"},
|
||||
&TestObject{Base: storable.Base{ID: id2}, Name: "preassigned2"},
|
||||
}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify objects were inserted with pre-assigned IDs
|
||||
result1 := &TestObject{}
|
||||
err = repository.Get(ctx, id1, result1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "preassigned1", result1.Name)
|
||||
|
||||
result2 := &TestObject{}
|
||||
err = repository.Get(ctx, id2, result2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "preassigned2", result2.Name)
|
||||
})
|
||||
|
||||
t.Run("InsertMany_MixedTypes", func(t *testing.T) {
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Name: "test1"},
|
||||
&AnotherObject{Description: "desc1"},
|
||||
&TestObject{Name: "test2"},
|
||||
}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all objects were inserted
|
||||
for _, obj := range objects {
|
||||
assert.NotNil(t, obj.GetID())
|
||||
assert.False(t, obj.GetID().IsZero())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertMany_DuplicateKey", func(t *testing.T) {
|
||||
id := primitive.NewObjectID()
|
||||
|
||||
// Insert first object
|
||||
obj1 := &TestObject{Base: storable.Base{ID: id}, Name: "original"}
|
||||
err := repository.Insert(ctx, obj1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to insert multiple objects including one with duplicate ID
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Name: "test1"},
|
||||
&TestObject{Base: storable.Base{ID: id}, Name: "duplicate"},
|
||||
}
|
||||
|
||||
err = repository.InsertMany(ctx, objects)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, mongo.IsDuplicateKeyError(err))
|
||||
})
|
||||
|
||||
t.Run("InsertMany_UpdateTimestamps", func(t *testing.T) {
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Name: "test1"},
|
||||
&TestObject{Name: "test2"},
|
||||
}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify timestamps were set
|
||||
for _, obj := range objects {
|
||||
testObj := obj.(*TestObject)
|
||||
assert.NotZero(t, testObj.CreatedAt)
|
||||
assert.NotZero(t, testObj.UpdatedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
233
api/pkg/db/internal/mongo/repositoryimp/repository_patch_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package repositoryimp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestMongoRepository_PatchOperations(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repo := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Patch_SingleDocument", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "old"}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
original := obj.UpdatedAt
|
||||
|
||||
patch := repository.Patch().Set(repository.Field("name"), "new")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "new", result.Name)
|
||||
assert.True(t, result.UpdatedAt.After(original))
|
||||
})
|
||||
|
||||
t.Run("PatchMany_MultipleDocuments", func(t *testing.T) {
|
||||
objs := []*TestObject{{Name: "match"}, {Name: "match"}, {Name: "other"}}
|
||||
for _, o := range objs {
|
||||
err := repo.Insert(ctx, o, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
query := repository.Query().Comparison(repository.Field("name"), builder.Eq, "match")
|
||||
patch := repository.Patch().Set(repository.Field("name"), "patched")
|
||||
modified, err := repo.PatchMany(ctx, query, patch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, modified)
|
||||
|
||||
verify := repository.Query().Comparison(repository.Field("name"), builder.Eq, "patched")
|
||||
var results []TestObject
|
||||
decoder := func(cursor *mongo.Cursor) error {
|
||||
var obj TestObject
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, obj)
|
||||
return nil
|
||||
}
|
||||
err = repo.FindManyByFilter(ctx, verify, decoder)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
})
|
||||
|
||||
t.Run("Patch_PushArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Push(repository.Field("tags"), "tag2")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_PullArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1", "tag2", "tag3"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Pull(repository.Field("tags"), "tag2")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag3"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_AddToSetArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add new tag
|
||||
patch := repository.Patch().AddToSet(repository.Field("tags"), "tag2")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
|
||||
|
||||
// Try to add duplicate tag - should not add
|
||||
patch = repository.Patch().AddToSet(repository.Field("tags"), "tag1")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_PushToEmptyArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Push(repository.Field("tags"), "tag1")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_PullFromEmptyArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Pull(repository.Field("tags"), "nonexistent")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_PullNonExistentElement", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1", "tag2"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Pull(repository.Field("tags"), "nonexistent")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_ChainedArrayOperations", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Note: MongoDB doesn't allow multiple operations on the same array field in a single update
|
||||
// This test demonstrates that chained array operations on the same field will fail
|
||||
patch := repository.Patch().
|
||||
Push(repository.Field("tags"), "tag2").
|
||||
AddToSet(repository.Field("tags"), "tag3").
|
||||
Pull(repository.Field("tags"), "tag1")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.Error(t, err) // This should fail due to MongoDB's limitation
|
||||
assert.Contains(t, err.Error(), "conflict")
|
||||
})
|
||||
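// A minimal sketch of how the conflicting update above can be applied instead:
// issue the array operations as separate Patch calls so each update document
// touches "tags" only once. This is illustrative, and assumes repository.Patch()
// returns a builder.Patch as it is used elsewhere in these tests.
//
//	for _, p := range []builder.Patch{
//	    repository.Patch().Push(repository.Field("tags"), "tag2"),
//	    repository.Patch().AddToSet(repository.Field("tags"), "tag3"),
//	    repository.Patch().Pull(repository.Field("tags"), "tag1"),
//	} {
//	    if err := repo.Patch(ctx, *obj.GetID(), p); err != nil {
//	        t.Fatal(err)
//	    }
//	}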
|
||||
t.Run("PatchMany_ArrayOperations", func(t *testing.T) {
|
||||
objs := []*TestObject{
|
||||
{Name: "obj1", Tags: []string{"tag1"}},
|
||||
{Name: "obj2", Tags: []string{"tag2"}},
|
||||
{Name: "obj3", Tags: []string{"tag3"}},
|
||||
}
|
||||
for _, o := range objs {
|
||||
err := repo.Insert(ctx, o, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
query := repository.Query().Comparison(repository.Field("name"), builder.In, []string{"obj1", "obj2"})
|
||||
patch := repository.Patch().Push(repository.Field("tags"), "common")
|
||||
modified, err := repo.PatchMany(ctx, query, patch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, modified)
|
||||
|
||||
// Verify the changes
|
||||
for _, name := range []string{"obj1", "obj2"} {
|
||||
var result TestObject
|
||||
err = repo.FindOneByFilter(ctx, repository.Query().Comparison(repository.Field("name"), builder.Eq, name), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, result.Tags, "common")
|
||||
}
|
||||
})
|
||||
}
|
||||
188
api/pkg/db/internal/mongo/repositoryimp/repository_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package repositoryimp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp/builderimp"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type TestObject struct {
|
||||
storable.Base `bson:",inline" json:",inline"`
|
||||
Name string `bson:"name"`
|
||||
Tags []string `bson:"tags"`
|
||||
}
|
||||
|
||||
func (t *TestObject) Collection() string {
|
||||
return "testObject"
|
||||
}
|
||||
|
||||
type AnotherObject struct {
|
||||
storable.Base `bson:",inline" json:",inline"`
|
||||
Description string `bson:"description"`
|
||||
}
|
||||
|
||||
func (a *AnotherObject) Collection() string {
|
||||
return "anotherObject"
|
||||
}
|
||||
|
||||
func terminate(ctx context.Context, t *testing.T, container *mongodb.MongoDBContainer) {
|
||||
err := container.Terminate(ctx)
|
||||
require.NoError(t, err, "failed to terminate MongoDB container")
|
||||
}
|
||||
|
||||
func disconnect(ctx context.Context, t *testing.T, client *mongo.Client) {
|
||||
err := client.Disconnect(ctx)
|
||||
require.NoError(t, err, "failed to disconnect from MongoDB")
|
||||
}
|
||||
|
||||
func TestMongoRepository_Get(t *testing.T) {
|
||||
// Use a context with timeout, so that if container startup or DB operations hang,
|
||||
// the test won't run indefinitely.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Get_Success", func(t *testing.T) {
|
||||
testObj := &TestObject{Name: "testName"}
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := &TestObject{}
|
||||
err = repository.Get(ctx, testObj.ID, result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testObj.Name, result.Name)
|
||||
assert.Equal(t, testObj.ID, result.ID)
|
||||
})
|
||||
|
||||
t.Run("Get_NotFound", func(t *testing.T) {
|
||||
nonExistentID := primitive.NewObjectID()
|
||||
result := &TestObject{}
|
||||
|
||||
err := repository.Get(ctx, nonExistentID, result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Get_InvalidID", func(t *testing.T) {
|
||||
invalidID := primitive.ObjectID{} // zero value
|
||||
result := &TestObject{}
|
||||
|
||||
err := repository.Get(ctx, invalidID, result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrInvalidArg))
|
||||
})
|
||||
|
||||
t.Run("Get_DifferentTypes", func(t *testing.T) {
|
||||
anotherObj := &AnotherObject{Description: "testDescription"}
|
||||
err := repository.Insert(ctx, anotherObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := &AnotherObject{}
|
||||
err = repository.Get(ctx, anotherObj.ID, result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, anotherObj.Description, result.Description)
|
||||
assert.Equal(t, anotherObj.ID, result.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_ListIDs(t *testing.T) {
|
||||
// Use a context with timeout, so that if container startup or DB operations hang,
|
||||
// the test won't run indefinitely.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("ListIDs_Success", func(t *testing.T) {
|
||||
// Insert test data
|
||||
testObjs := []*TestObject{
|
||||
{Name: "testName1"},
|
||||
{Name: "testName2"},
|
||||
{Name: "testName3"},
|
||||
}
|
||||
for _, obj := range testObjs {
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Define a query to match all objects
|
||||
query := builderimp.NewQueryImp()
|
||||
|
||||
// Call ListIDs
|
||||
ids, err := repository.ListIDs(ctx, query)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert the IDs are correct
|
||||
require.Len(t, ids, len(testObjs))
|
||||
for _, obj := range testObjs {
|
||||
assert.Contains(t, ids, obj.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ListIDs_EmptyResult", func(t *testing.T) {
|
||||
// Define a query that matches no objects
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "nonExistentName")
|
||||
|
||||
// Call ListIDs
|
||||
ids, err := repository.ListIDs(ctx, query)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert no IDs are returned
|
||||
assert.Empty(t, ids)
|
||||
})
|
||||
}
|
||||
21
api/pkg/db/internal/mongo/rolesdb/db.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package rolesdb
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type RolesDB struct {
|
||||
template.DBImp[*model.RoleDescription]
|
||||
}
|
||||
|
||||
func Create(logger mlogger.Logger, db *mongo.Database) (*RolesDB, error) {
|
||||
p := &RolesDB{
|
||||
DBImp: *template.Create[*model.RoleDescription](logger, mservice.Roles, db),
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
15
api/pkg/db/internal/mongo/rolesdb/list.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package rolesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *RolesDB) List(ctx context.Context, organizationRef primitive.ObjectID, cursor *model.ViewCursor) ([]model.RoleDescription, error) {
|
||||
filter := repository.OrgFilter(organizationRef)
|
||||
return mutil.GetObjects[model.RoleDescription](ctx, db.Logger, filter, cursor, db.Repository)
|
||||
}
|
||||
15
api/pkg/db/internal/mongo/rolesdb/roles.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package rolesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *RolesDB) Roles(ctx context.Context, refs []primitive.ObjectID) ([]model.RoleDescription, error) {
|
||||
filter := repository.Query().In(repository.IDField(), refs)
|
||||
return mutil.GetObjects[model.RoleDescription](ctx, db.Logger, filter, nil, db.Repository)
|
||||
}
|
||||
18
api/pkg/db/internal/mongo/transactionimp/factory.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package transactionimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/transaction"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type MongoTransactionFactory struct {
|
||||
client *mongo.Client
|
||||
}
|
||||
|
||||
func (mtf *MongoTransactionFactory) CreateTransaction() transaction.Transaction {
|
||||
return Create(mtf.client)
|
||||
}
|
||||
|
||||
func CreateFactory(client *mongo.Client) transaction.Factory {
|
||||
return &MongoTransactionFactory{client: client}
|
||||
}
|
||||
30
api/pkg/db/internal/mongo/transactionimp/transaction.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package transactionimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/transaction"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type MongoTransaction struct {
|
||||
client *mongo.Client
|
||||
}
|
||||
|
||||
func (mt *MongoTransaction) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
|
||||
session, err := mt.client.StartSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer session.EndSession(ctx)
|
||||
|
||||
callback := func(sessCtx mongo.SessionContext) (any, error) {
|
||||
return cb(sessCtx)
|
||||
}
|
||||
|
||||
return session.WithTransaction(ctx, callback)
|
||||
}
|
||||
|
||||
func Create(client *mongo.Client) *MongoTransaction {
|
||||
return &MongoTransaction{client: client}
|
||||
}
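// A minimal usage sketch, assuming transaction.Callback has the form
// func(ctx context.Context) (any, error); the repositories named here are
// placeholders, not part of this package:
//
//	tx := transactionimp.CreateFactory(client).CreateTransaction()
//	_, err := tx.Execute(ctx, func(sessCtx context.Context) (any, error) {
//	    if err := accounts.Insert(sessCtx, acc, nil); err != nil {
//	        return nil, err
//	    }
//	    return nil, orders.Insert(sessCtx, ord, nil)
//	})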
|
||||
118
api/pkg/db/internal/mongo/tseriesimp/tseries.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package tseriesimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
rdecoder "github.com/tech/sendico/pkg/db/repository/decoder"
|
||||
tsoptions "github.com/tech/sendico/pkg/db/tseries/options"
|
||||
tspoint "github.com/tech/sendico/pkg/db/tseries/point"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type TimeSeries struct {
|
||||
options tsoptions.Options
|
||||
collection *mongo.Collection
|
||||
}
|
||||
|
||||
func NewMongoTimeSeriesCollection(ctx context.Context, db *mongo.Database, tsOpts *tsoptions.Options) (*TimeSeries, error) {
|
||||
if tsOpts == nil {
|
||||
return nil, merrors.InvalidArgument("nil time-series options provided")
|
||||
}
|
||||
// Configure time-series options
|
||||
granularity := tsOpts.Granularity.String()
|
||||
ts := &options.TimeSeriesOptions{
|
||||
TimeField: tsOpts.TimeField,
|
||||
Granularity: &granularity,
|
||||
}
|
||||
if tsOpts.MetaField != "" {
|
||||
ts.MetaField = &tsOpts.MetaField
|
||||
}
|
||||
|
||||
// Collection options
|
||||
collOpts := options.CreateCollection().SetTimeSeriesOptions(ts)
|
||||
|
||||
// Set TTL if requested
|
||||
if tsOpts.ExpireAfter > 0 {
|
||||
secs := int64(tsOpts.ExpireAfter / time.Second)
|
||||
collOpts.SetExpireAfterSeconds(secs)
|
||||
}
|
||||
|
||||
if err := db.CreateCollection(ctx, tsOpts.Collection, collOpts); err != nil {
|
||||
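// Error code 48 is NamespaceExists: the time-series collection was already
// created, so treat it as success and fall through to reuse the existing collection.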
if cmdErr, ok := err.(mongo.CommandError); !ok || cmdErr.Code != 48 {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &TimeSeries{collection: db.Collection(tsOpts.Collection), options: *tsOpts}, nil
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) Aggregate(ctx context.Context, pipeline builder.Pipeline, decoder rdecoder.DecodingFunc) error {
|
||||
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
|
||||
return collection.Aggregate(ctx, pipeline.Build())
|
||||
}
|
||||
return ts.executeQuery(ctx, decoder, queryFunc)
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) Insert(ctx context.Context, timePoint tspoint.TimePoint) error {
|
||||
_, err := ts.collection.InsertOne(ctx, timePoint)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) InsertMany(ctx context.Context, timePoints []tspoint.TimePoint) error {
|
||||
docs := make([]any, len(timePoints))
|
||||
for i, p := range timePoints {
|
||||
docs[i] = p
|
||||
}
|
||||
|
||||
// the InsertMany result is not needed here; only the error is checked
|
||||
_, err := ts.collection.InsertMany(ctx, docs)
|
||||
return err
|
||||
}
|
||||
|
||||
type QueryFunc func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error)
|
||||
|
||||
func (ts *TimeSeries) executeQuery(ctx context.Context, decoder rdecoder.DecodingFunc, queryFunc QueryFunc) error {
|
||||
cursor, err := queryFunc(ctx, ts.collection)
|
||||
if errors.Is(err, mongo.ErrNoDocuments) {
|
||||
return merrors.NoData("no_items_in_array")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
if err := cursor.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = decoder(cursor); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) Query(ctx context.Context, decoder rdecoder.DecodingFunc, query builder.Query, from, to *time.Time) error {
|
||||
timeLimitedQuery := query
|
||||
if from != nil {
|
||||
timeLimitedQuery = timeLimitedQuery.And(repository.Query().Comparison(repository.Field(ts.options.TimeField), builder.Gte, *from))
|
||||
}
|
||||
if to != nil {
|
||||
timeLimitedQuery = timeLimitedQuery.And(repository.Query().Comparison(repository.Field(ts.options.TimeField), builder.Lte, *to))
|
||||
}
|
||||
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
|
||||
return collection.Find(ctx, timeLimitedQuery.BuildQuery(), timeLimitedQuery.BuildOptions())
|
||||
}
|
||||
return ts.executeQuery(ctx, decoder, queryFunc)
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) Name() string {
|
||||
return ts.collection.Name()
|
||||
}
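// Illustrative construction and use of the time-series helper. The collection,
// field names and retention below are placeholders, and Granularity is assumed
// to have a usable default in tsoptions.Options:
//
//	ts, err := tseriesimp.NewMongoTimeSeriesCollection(ctx, db, &tsoptions.Options{
//	    Collection:  "metrics",
//	    TimeField:   "ts",
//	    MetaField:   "deviceId",
//	    ExpireAfter: 30 * 24 * time.Hour,
//	})
//	if err != nil {
//	    return err
//	}
//	err = ts.Insert(ctx, point) // point implements tspoint.TimePoint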
|
||||
19
api/pkg/db/invitation/invitation.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package invitation
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type DB interface {
|
||||
auth.ProtectedDB[*model.Invitation]
|
||||
GetPublic(ctx context.Context, invitationRef primitive.ObjectID) (*model.PublicInvitation, error)
|
||||
Accept(ctx context.Context, invitationRef primitive.ObjectID) error
|
||||
Decline(ctx context.Context, invitationRef primitive.ObjectID) error
|
||||
List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Invitation, error)
|
||||
DeleteCascade(ctx context.Context, accountRef, statusRef primitive.ObjectID) error
|
||||
SetArchived(ctx context.Context, accountRef, organizationRef, statusRef primitive.ObjectID, archived, cascade bool) error
|
||||
}
|
||||
17
api/pkg/db/organization/organization.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package organization
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// DB is the interface which must be implemented by all db drivers
|
||||
type DB interface {
|
||||
auth.ProtectedDB[*model.Organization]
|
||||
List(ctx context.Context, accountRef primitive.ObjectID, cursor *model.ViewCursor) ([]model.Organization, error)
|
||||
ListOwned(ctx context.Context, accountRef primitive.ObjectID) ([]model.Organization, error)
|
||||
SetArchived(ctx context.Context, accountRef, organizationRef primitive.ObjectID, archived, cascade bool) error
|
||||
}
|
||||
17
api/pkg/db/policy/policy.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type DB interface {
|
||||
template.DB[*model.PolicyDescription]
|
||||
All(ctx context.Context, organizationRef primitive.ObjectID) ([]model.PolicyDescription, error)
|
||||
Policies(ctx context.Context, refs []primitive.ObjectID) ([]model.PolicyDescription, error)
|
||||
GetBuiltInPolicy(ctx context.Context, resourceType mservice.Type, policy *model.PolicyDescription) error
|
||||
}
|
||||
17
api/pkg/db/refreshtokens/refreshtokens.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package refreshtokens
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
type DB interface {
|
||||
template.DB[*model.RefreshToken]
|
||||
Revoke(ctx context.Context, accountRef primitive.ObjectID, session *model.SessionIdentifier) error
|
||||
RevokeAll(ctx context.Context, accountRef primitive.ObjectID, deviceID string) error
|
||||
GetByCRT(ctx context.Context, t *model.ClientRefreshToken) (*model.RefreshToken, error)
|
||||
GetClient(ctx context.Context, clientID string) (*model.Client, error)
|
||||
}
|
||||
78
api/pkg/db/repository/abfilter.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// AccountBoundFilter provides factory methods for creating account-bound filters
|
||||
type AccountBoundFilter struct{}
|
||||
|
||||
// NewAccountBoundFilter creates a new AccountBoundFilter instance
|
||||
func NewAccountBoundFilter() *AccountBoundFilter {
|
||||
return &AccountBoundFilter{}
|
||||
}
|
||||
|
||||
// WithoutOrg creates a filter for account-bound objects without organization filter
|
||||
// This filter finds objects where:
|
||||
// - accountRef matches the provided accountRef, OR
|
||||
// - accountRef is nil/null, OR
|
||||
// - accountRef field doesn't exist
|
||||
func (f *AccountBoundFilter) WithoutOrg(accountRef primitive.ObjectID) builder.Query {
|
||||
return Query().Or(
|
||||
AccountFilter(accountRef),
|
||||
Filter(model.AccountRefField, nil),
|
||||
Exists(AccountField(), false),
|
||||
)
|
||||
}
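// For illustration, and assuming model.AccountRefField resolves to "accountRef",
// the resulting filter is intended to look roughly like:
//
//	{ "$or": [
//	    { "accountRef": <accountRef> },
//	    { "accountRef": null },
//	    { "accountRef": { "$exists": false } },
//	] }
//
// so documents written before the accountRef field existed stay visible.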
|
||||
|
||||
// WithOrg creates a filter for account-bound objects with organization filter
|
||||
// This filter finds objects where:
|
||||
// - accountRef matches the provided accountRef, OR
|
||||
// - accountRef is nil/null, OR
|
||||
// - accountRef field doesn't exist
|
||||
// AND combines with organization filter
|
||||
func (f *AccountBoundFilter) WithOrg(accountRef, organizationRef primitive.ObjectID) builder.Query {
|
||||
return Query().And(
|
||||
OrgFilter(organizationRef),
|
||||
f.WithoutOrg(accountRef),
|
||||
)
|
||||
}
|
||||
|
||||
// WithQuery creates a filter for account-bound objects with additional query and organization filter
|
||||
func (f *AccountBoundFilter) WithQuery(accountRef, organizationRef primitive.ObjectID, additionalQuery builder.Query) builder.Query {
|
||||
accountQuery := f.WithOrg(accountRef, organizationRef)
|
||||
return additionalQuery.And(accountQuery)
|
||||
}
|
||||
|
||||
// WithQueryNoOrg creates a filter for account-bound objects with additional query but no org filter
|
||||
func (f *AccountBoundFilter) WithQueryNoOrg(accountRef primitive.ObjectID, additionalQuery builder.Query) builder.Query {
|
||||
accountQuery := f.WithoutOrg(accountRef)
|
||||
return additionalQuery.And(accountQuery)
|
||||
}
|
||||
|
||||
// Global instance for convenience
|
||||
var DefaultAccountBoundFilter = NewAccountBoundFilter()
|
||||
|
||||
// Convenience functions that use the global factory instance
|
||||
|
||||
// WithOrg is a convenience function that uses the default factory
|
||||
func WithOrg(accountRef, organizationRef primitive.ObjectID) builder.Query {
|
||||
return DefaultAccountBoundFilter.WithOrg(accountRef, organizationRef)
|
||||
}
|
||||
|
||||
// WithoutOrg is a convenience function that uses the default factory
|
||||
func WithoutOrg(accountRef primitive.ObjectID) builder.Query {
|
||||
return DefaultAccountBoundFilter.WithoutOrg(accountRef)
|
||||
}
|
||||
|
||||
// WithQuery is a convenience function that uses the default factory
|
||||
func WithQuery(accountRef, organizationRef primitive.ObjectID, additionalQuery builder.Query) builder.Query {
|
||||
return DefaultAccountBoundFilter.WithQuery(accountRef, organizationRef, additionalQuery)
|
||||
}
|
||||
|
||||
// WithQueryNoOrg is a convenience function that uses the default factory
|
||||
func WithQueryNoOrg(accountRef primitive.ObjectID, additionalQuery builder.Query) builder.Query {
|
||||
return DefaultAccountBoundFilter.WithQueryNoOrg(accountRef, additionalQuery)
|
||||
}
|
||||
11
api/pkg/db/repository/builder/accumulator.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package builder
|
||||
|
||||
import "go.mongodb.org/mongo-driver/bson"
|
||||
|
||||
type Accumulator interface {
|
||||
Build() bson.D
|
||||
}
|
||||
|
||||
type GroupAccumulator interface {
|
||||
Build() bson.D
|
||||
}
|
||||
8
api/pkg/db/repository/builder/alias.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package builder
|
||||
|
||||
import "go.mongodb.org/mongo-driver/bson"
|
||||
|
||||
type Alias interface {
|
||||
Field() Field
|
||||
Build() bson.D
|
||||
}
|
||||
7
api/pkg/db/repository/builder/array.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package builder
|
||||
|
||||
import "go.mongodb.org/mongo-driver/bson"
|
||||
|
||||
type Array interface {
|
||||
Build() bson.A
|
||||
}
|
||||
5
api/pkg/db/repository/builder/expression.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package builder
|
||||
|
||||
type Expression interface {
|
||||
Build() any
|
||||
}
|
||||
7
api/pkg/db/repository/builder/field.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package builder
|
||||
|
||||
type Field interface {
|
||||
Dot(field string) Field
|
||||
CopyWith(field string) Field
|
||||
Build() string
|
||||
}
|
||||
16
api/pkg/db/repository/builder/keyword.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package builder
|
||||
|
||||
type MongoKeyword string
|
||||
|
||||
const (
|
||||
MKAs MongoKeyword = "as"
|
||||
MKForeignField MongoKeyword = "foreignField"
|
||||
MKFrom MongoKeyword = "from"
|
||||
MKIncludeArrayIndex MongoKeyword = "includeArrayIndex"
|
||||
MKLet MongoKeyword = "let"
|
||||
MKLocalField MongoKeyword = "localField"
|
||||
MKPath MongoKeyword = "path"
|
||||
MKPipeline MongoKeyword = "pipeline"
|
||||
MKPreserveNullAndEmptyArrays MongoKeyword = "preserveNullAndEmptyArrays"
|
||||
MKNewRoot MongoKeyword = "newRoot"
|
||||
)
|
||||
57
api/pkg/db/repository/builder/operators.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package builder
|
||||
|
||||
type MongoOperation string
|
||||
|
||||
const (
|
||||
// Comparison operators
|
||||
Gt MongoOperation = "$gt"
|
||||
Lt MongoOperation = "$lt"
|
||||
Gte MongoOperation = "$gte"
|
||||
Lte MongoOperation = "$lte"
|
||||
Eq MongoOperation = "$eq"
|
||||
Ne MongoOperation = "$ne"
|
||||
In MongoOperation = "$in"
|
||||
NotIn MongoOperation = "$nin"
|
||||
Exists MongoOperation = "$exists"
|
||||
|
||||
// Logical operators
|
||||
And MongoOperation = "$and"
|
||||
Or MongoOperation = "$or"
|
||||
Not MongoOperation = "$not"
|
||||
|
||||
AddToSet MongoOperation = "$addToSet"
|
||||
Avg MongoOperation = "$avg"
|
||||
Pull MongoOperation = "$pull"
|
||||
Count MongoOperation = "$count"
|
||||
Cond MongoOperation = "$cond"
|
||||
Each MongoOperation = "$each"
|
||||
Expr MongoOperation = "$expr"
|
||||
First MongoOperation = "$first"
|
||||
Group MongoOperation = "$group"
|
||||
IfNull MongoOperation = "$ifNull"
|
||||
Limit MongoOperation = "$limit"
|
||||
Literal MongoOperation = "$literal"
|
||||
Lookup MongoOperation = "$lookup"
|
||||
Match MongoOperation = "$match"
|
||||
Max MongoOperation = "$max"
|
||||
Min MongoOperation = "$min"
|
||||
Push MongoOperation = "$push"
|
||||
Project MongoOperation = "$project"
|
||||
Set MongoOperation = "$set"
|
||||
Inc MongoOperation = "$inc"
|
||||
Unset MongoOperation = "$unset"
|
||||
Rename MongoOperation = "$rename"
|
||||
ReplaceRoot MongoOperation = "$replaceRoot"
|
||||
SetUnion MongoOperation = "$setUnion"
|
||||
Size MongoOperation = "$size"
|
||||
Sort MongoOperation = "$sort"
|
||||
Skip MongoOperation = "$skip"
|
||||
Sum MongoOperation = "$sum"
|
||||
Type MongoOperation = "$type"
|
||||
Unwind MongoOperation = "$unwind"
|
||||
|
||||
Add MongoOperation = "$add"
|
||||
Subtract MongoOperation = "$subtract"
|
||||
Multiply MongoOperation = "$multiply"
|
||||
Divide MongoOperation = "$divide"
|
||||
)
|
||||
16
api/pkg/db/repository/builder/patch.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package builder
|
||||
|
||||
import "go.mongodb.org/mongo-driver/bson"
|
||||
|
||||
// Patch defines operations for constructing partial update documents.
|
||||
// Each builder method returns the same Patch instance to allow chaining.
|
||||
type Patch interface {
|
||||
Set(field Field, value any) Patch
|
||||
Inc(field Field, value any) Patch
|
||||
Unset(field Field) Patch
|
||||
Rename(field Field, newName string) Patch
|
||||
Push(field Field, value any) Patch
|
||||
Pull(field Field, value any) Patch
|
||||
AddToSet(field Field, value any) Patch
|
||||
Build() bson.D
|
||||
}
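// Illustrative chained build; the constructors repository.Patch and
// repository.Field come from the repository package in this change, and the
// field names are placeholders:
//
//	update := repository.Patch().
//	    Set(repository.Field("name"), "new name").
//	    Inc(repository.Field("loginCount"), 1).
//	    Unset(repository.Field("legacyFlag")).
//	    Build() // bson.D update document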
|
||||
24
api/pkg/db/repository/builder/pipeline.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type Pipeline interface {
|
||||
Match(filter Query) Pipeline
|
||||
Lookup(from mservice.Type, localField, foreignField, as Field) Pipeline
|
||||
LookupWithPipeline(
|
||||
from mservice.Type,
|
||||
pipeline Pipeline, // nested pipeline executed against the joined collection
|
||||
as Field,
|
||||
let *map[string]Field, // optional e.g. {"projRef": Field("$_id")}
|
||||
) Pipeline
|
||||
// unwind with functional options
|
||||
Unwind(path Field, opts ...UnwindOption) Pipeline
|
||||
Count(field Field) Pipeline
|
||||
Group(groupBy Alias, accumulators ...GroupAccumulator) Pipeline
|
||||
Project(projections ...Projection) Pipeline
|
||||
ReplaceRoot(newRoot Expression) Pipeline
|
||||
Build() mongo.Pipeline
|
||||
}
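// Illustrative pipeline assembled through this interface; the joined service
// (mservice.Roles) and the field names are placeholders:
//
//	pl := repository.Pipeline().
//	    Match(repository.OrgFilter(orgRef)).
//	    Lookup(mservice.Roles, repository.Field("roleRef"), repository.IDField(), repository.Field("role")).
//	    Unwind(repository.Field("role"), WithPreserveNullAndEmptyArrays()).
//	    Build()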
|
||||
7
api/pkg/db/repository/builder/projection.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package builder
|
||||
|
||||
import "go.mongodb.org/mongo-driver/bson"
|
||||
|
||||
type Projection interface {
|
||||
Build() bson.D
|
||||
}
|
||||
24
api/pkg/db/repository/builder/query.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package builder
|
||||
|
||||
import (
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type Query interface {
|
||||
Filter(field Field, value any) Query
|
||||
And(filters ...Query) Query
|
||||
Or(filters ...Query) Query
|
||||
Expression(value Expression) Query
|
||||
Comparison(field Field, operator MongoOperation, value any) Query
|
||||
RegEx(field Field, pattern, options string) Query
|
||||
In(field Field, values ...any) Query
|
||||
NotIn(field Field, values ...any) Query
|
||||
Sort(field Field, ascending bool) Query
|
||||
Limit(limit *int64) Query
|
||||
Offset(offset *int64) Query
|
||||
Archived(isArchived *bool) Query
|
||||
BuildPipeline() bson.D
|
||||
BuildQuery() bson.D
|
||||
BuildOptions() *options.FindOptions
|
||||
}
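// Illustrative composition; constructors come from the repository package and
// the field names are placeholders:
//
//	limit := int64(20)
//	q := repository.Query().
//	    Filter(repository.OrgField(), orgRef).
//	    Comparison(repository.Field("createdAt"), Gte, since).
//	    Sort(repository.Field("createdAt"), false).
//	    Limit(&limit)
//	filter, opts := q.BuildQuery(), q.BuildOptions()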
|
||||
23
api/pkg/db/repository/builder/unwind.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package builder
|
||||
|
||||
// UnwindOption is a functional option for configuring the $unwind stage.
|
||||
type UnwindOption func(*UnwindOpts)
|
||||
|
||||
type UnwindOpts struct {
|
||||
PreserveNullAndEmptyArrays bool
|
||||
IncludeArrayIndex string
|
||||
}
|
||||
|
||||
// WithPreserveNullAndEmptyArrays tells $unwind to keep docs where the array is null/empty.
|
||||
func WithPreserveNullAndEmptyArrays() UnwindOption {
|
||||
return func(o *UnwindOpts) {
|
||||
o.PreserveNullAndEmptyArrays = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithIncludeArrayIndex adds an array-index field named idxField to each unwound doc.
|
||||
func WithIncludeArrayIndex(idxField string) UnwindOption {
|
||||
return func(o *UnwindOpts) {
|
||||
o.IncludeArrayIndex = idxField
|
||||
}
|
||||
}
|
||||
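For context, a minimal sketch of how an Unwind implementation might fold these options into a single UnwindOpts value; the resolver function below is hypothetical and not part of this commit:

package example

import "github.com/tech/sendico/pkg/db/repository/builder"

// resolveUnwindOptions applies each functional option in order, yielding the
// final configuration an $unwind stage would be built from.
func resolveUnwindOptions(opts ...builder.UnwindOption) builder.UnwindOpts {
	resolved := builder.UnwindOpts{}
	for _, opt := range opts {
		opt(&resolved)
	}
	return resolved
}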
5
api/pkg/db/repository/builder/value.go
Normal file
@@ -0,0 +1,5 @@
package builder

type Value interface {
	Build() any
}
273
api/pkg/db/repository/builders.go
Normal file
@@ -0,0 +1,273 @@
package repository

import (
	"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp/builderimp"
	"github.com/tech/sendico/pkg/db/repository/builder"
	"github.com/tech/sendico/pkg/db/storable"
	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func Query() builder.Query {
	return builderimp.NewQueryImp()
}

func Filter(field string, value any) builder.Query {
	return Query().Filter(Field(field), value)
}

func Field(baseName string) builder.Field {
	return builderimp.NewFieldImp(baseName)
}

func Ref(field builder.Field) builder.Field {
	return builderimp.NewRefFieldImp(field)
}

func RootRef() builder.Field {
	return builderimp.NewRootRef()
}

func RemoveRef() builder.Field {
	return builderimp.NewRemoveRef()
}

func Pipeline() builder.Pipeline {
	return builderimp.NewPipelineImp()
}

func IDField() builder.Field {
	return Field(storable.IDField)
}

func NameField() builder.Field {
	return Field(model.NameField)
}

func DescriptionField() builder.Field {
	return Field(model.DescriptionField)
}

func IsArchivedField() builder.Field {
	return Field(storable.IsArchivedField)
}

func IDFilter(ref primitive.ObjectID) builder.Query {
	return Query().Filter(IDField(), ref)
}

func ArchivedFilter() builder.Query {
	return IsArchivedFilter(true)
}

func NotArchivedFilter() builder.Query {
	return IsArchivedFilter(false)
}

func IsArchivedFilter(isArchived bool) builder.Query {
	return Query().Filter(IsArchivedField(), isArchived)
}

func OrgField() builder.Field {
	return Field(storable.OrganizationRefField)
}

func OrgFilter(ref primitive.ObjectID) builder.Query {
	return Query().Filter(OrgField(), ref)
}

func ProjectField() builder.Field {
	return Field("projectRef")
}

func ProjectFilter(ref primitive.ObjectID) builder.Query {
	return Query().Filter(ProjectField(), ref)
}

func AccountField() builder.Field {
	return Field(model.AccountRefField)
}

func AccountFilter(ref primitive.ObjectID) builder.Query {
	return Query().Filter(AccountField(), ref)
}

func StatusRefField() builder.Field {
	return Field("statusRef")
}

func StatusRefFilter(ref primitive.ObjectID) builder.Query {
	return Query().Filter(StatusRefField(), ref)
}

func PriorityRefField() builder.Field {
	return Field("priorityRef")
}

func PriorityRefFilter(ref primitive.ObjectID) builder.Query {
	return Query().Filter(PriorityRefField(), ref)
}

func IndexField() builder.Field {
	return Field("index")
}

func IndexFilter(index int) builder.Query {
	return Query().Filter(IndexField(), index)
}

func TagRefsField() builder.Field {
	return Field(model.TagRefsField)
}

func IndexOpFilter(index int, operation builder.MongoOperation) builder.Query {
	return Query().Comparison(IndexField(), operation, index)
}

func Patch() builder.Patch {
	return builderimp.NewPatchImp()
}

func Accumulator(operator builder.MongoOperation, value any) builder.Accumulator {
	return builderimp.NewAccumulator(operator, value)
}

func GroupAccumulator(field builder.Field, acc builder.Accumulator) builder.GroupAccumulator {
	return builderimp.NewGroupAccumulator(field, acc)
}

func Literal(value any) builder.Expression {
	return builderimp.NewLiteralExpression(value)
}

func Projection(alias builder.Alias) builder.Projection {
	return builderimp.NewAliasProjection(alias)
}

func IncludeField(field builder.Field) builder.Projection {
	return builderimp.IncludeField(field)
}

func ExcludeField(field builder.Field) builder.Projection {
	return builderimp.ExcludeField(field)
}

func ProjectionExpr(field builder.Field, expr builder.Expression) builder.Projection {
	return builderimp.NewProjectionExpr(field, expr)
}

func NullAlias(lhs builder.Field) builder.Alias {
	return builderimp.NewNullAlias(lhs)
}

func SimpleAlias(lhs, rhs builder.Field) builder.Alias {
	return builderimp.NewSimpleAlias(lhs, rhs)
}

func ComplexAlias(lhs builder.Field, rhs []builder.Alias) builder.Alias {
	return builderimp.NewComplexAlias(lhs, rhs)
}

func Aliases(aliases ...builder.Alias) builder.Alias {
	return builderimp.NewAliases(aliases...)
}

func AddToSet(value builder.Expression) builder.Expression {
	return builderimp.AddToSet(value)
}

func Size(value builder.Expression) builder.Expression {
	return builderimp.Size(value)
}

func InRef(value builder.Field) builder.Expression {
	return builderimp.InRef(value)
}

func In(values ...any) builder.Expression {
	return builderimp.In(values)
}

func Cond(condition builder.Expression, ifTrue, ifFalse any) builder.Expression {
	return builderimp.NewCond(condition, ifTrue, ifFalse)
}

func And(exprs ...builder.Expression) builder.Expression {
	return builderimp.NewAnd(exprs...)
}

func Or(exprs ...builder.Expression) builder.Expression {
	return builderimp.NewOr(exprs...)
}

func Type(expr builder.Expression) builder.Expression {
	return builderimp.NewType(expr)
}

func Not(expression builder.Expression) builder.Expression {
	return builderimp.NewNot(expression)
}

func Sum(expression builder.Expression) builder.Expression {
	return builderimp.NewSum(expression)
}

func Assign(field builder.Field, expression builder.Expression) builder.Projection {
	return builderimp.NewAssignment(field, expression)
}

func SetUnion(exprs ...builder.Expression) builder.Expression {
	return builderimp.NewSetUnion(exprs...)
}

func Eq(left, right builder.Expression) builder.Expression {
	return builderimp.Eq(left, right)
}

func Gt(left, right builder.Expression) builder.Expression {
	return builderimp.Gt(left, right)
}

func Lt(left, right builder.Expression) builder.Expression {
	return builderimp.NewLt(left, right)
}

func Array(expressions ...builder.Expression) builder.Array {
	return builderimp.NewArray(expressions...)
}

func IfNull(cond, replacement builder.Expression) builder.Expression {
	return builderimp.NewIfNull(cond, replacement)
}

func Each(exprs ...builder.Expression) builder.Expression {
	return builderimp.NewEach(exprs...)
}

func Push(expression builder.Expression) builder.Expression {
	return builderimp.NewPush(expression)
}

func Min(expression builder.Expression) builder.Expression {
	return builderimp.NewMin(expression)
}

func Ne(left, right builder.Expression) builder.Expression {
	return builderimp.Ne(left, right)
}

func Compute(field builder.Field, expression builder.Expression) builder.Expression {
	return builderimp.NewCompute(field, expression)
}

func First(expr builder.Expression) builder.Expression {
	return builderimp.First(expr)
}

func Value(value any) builder.Value {
	return builderimp.NewValue(value)
}

func Exists(field builder.Field, exists bool) builder.Query {
	return Query().Comparison(field, builder.Exists, exists)
}
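A brief sketch of how these helpers compose in practice, assuming And combines the receiver with the given filters (the concrete semantics live in the builderimp implementation, which is not shown in this diff):

package example

import (
	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/db/repository/builder"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

// archiveProjectItems pairs a filter (non-archived documents of one project)
// with a patch that flips their archived flag.
func archiveProjectItems(projectRef primitive.ObjectID) (builder.Query, builder.Patch) {
	filter := repository.ProjectFilter(projectRef).
		And(repository.NotArchivedFilter())
	patch := repository.Patch().
		Set(repository.IsArchivedField(), true)
	return filter, patch
}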
19
api/pkg/db/repository/cursor.go
Normal file
@@ -0,0 +1,19 @@
package repository

import (
	"github.com/tech/sendico/pkg/db/repository/builder"
	"github.com/tech/sendico/pkg/model"
)

// ApplyCursor adds pagination and archival filters to the provided query.
func ApplyCursor(query builder.Query, cursor *model.ViewCursor) builder.Query {
	if cursor == nil {
		return query
	}

	query = query.Limit(cursor.Limit)
	query = query.Offset(cursor.Offset)
	query = query.Archived(cursor.IsArchived)

	return query
}
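The intended call pattern, sketched under the assumption that callers build their base filter first and then let the view cursor layer paging on top:

package example

import (
	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/db/repository/builder"
	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

// pagedOrgQuery scopes a query to one organization and then applies the
// caller-supplied cursor (limit, offset, archived flag) in one place.
func pagedOrgQuery(orgRef primitive.ObjectID, cursor *model.ViewCursor) builder.Query {
	base := repository.OrgFilter(orgRef)
	return repository.ApplyCursor(base, cursor)
}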
5
api/pkg/db/repository/decoder/decoder.go
Normal file
@@ -0,0 +1,5 @@
package repository

import "go.mongodb.org/mongo-driver/mongo"

type DecodingFunc = func(r *mongo.Cursor) error
93
api/pkg/db/repository/filter_factory_test.go
Normal file
@@ -0,0 +1,93 @@
package repository

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

func TestAccountBoundFilter_WithOrg(t *testing.T) {
	factory := NewAccountBoundFilter()
	accountRef := primitive.NewObjectID()
	orgRef := primitive.NewObjectID()

	query := factory.WithOrg(accountRef, orgRef)

	// Test that the query is not nil
	assert.NotNil(t, query)
}

func TestAccountBoundFilter_WithoutOrg(t *testing.T) {
	factory := NewAccountBoundFilter()
	accountRef := primitive.NewObjectID()

	query := factory.WithoutOrg(accountRef)

	// Test that the query is not nil
	assert.NotNil(t, query)
}

func TestAccountBoundFilter_WithQuery(t *testing.T) {
	factory := NewAccountBoundFilter()
	accountRef := primitive.NewObjectID()
	orgRef := primitive.NewObjectID()
	additionalQuery := Query().Filter(Field("status"), "active")

	query := factory.WithQuery(accountRef, orgRef, additionalQuery)

	// Test that the query is not nil
	assert.NotNil(t, query)
}

func TestAccountBoundFilter_WithQueryNoOrg(t *testing.T) {
	factory := NewAccountBoundFilter()
	accountRef := primitive.NewObjectID()
	additionalQuery := Query().Filter(Field("status"), "active")

	query := factory.WithQueryNoOrg(accountRef, additionalQuery)

	// Test that the query is not nil
	assert.NotNil(t, query)
}

func TestDefaultAccountBoundFilter(t *testing.T) {
	// Test that the default factory is not nil
	assert.NotNil(t, DefaultAccountBoundFilter)

	// Test that it's the correct type
	assert.IsType(t, &AccountBoundFilter{}, DefaultAccountBoundFilter)
}

func TestConvenienceFunctions(t *testing.T) {
	accountRef := primitive.NewObjectID()
	orgRef := primitive.NewObjectID()
	additionalQuery := Query().Filter(Field("status"), "active")

	// Test convenience functions
	query1 := WithOrg(accountRef, orgRef)
	assert.NotNil(t, query1)

	query2 := WithoutOrg(accountRef)
	assert.NotNil(t, query2)

	query3 := WithQuery(accountRef, orgRef, additionalQuery)
	assert.NotNil(t, query3)

	query4 := WithQueryNoOrg(accountRef, additionalQuery)
	assert.NotNil(t, query4)
}

func TestFilterFactoryConsistency(t *testing.T) {
	factory := NewAccountBoundFilter()
	accountRef := primitive.NewObjectID()
	orgRef := primitive.NewObjectID()

	// Test that factory methods and convenience functions produce the same result
	query1 := factory.WithOrg(accountRef, orgRef)
	query2 := WithOrg(accountRef, orgRef)

	// Both should be valid queries
	assert.NotNil(t, query1)
	assert.NotNil(t, query2)
}
21
api/pkg/db/repository/index/index.go
Normal file
@@ -0,0 +1,21 @@
package repository

type Sort int8

const (
	Asc  Sort = 1
	Desc Sort = -1
)

type Key struct {
	Field string
	Sort  Sort      // 1 or -1. 0 means "use Type".
	Type  IndexType // optional: "text", "2dsphere", ...
}

type Definition struct {
	Keys   []Key  // mandatory, at least one element
	Unique bool   // unique constraint?
	TTL    *int32 // seconds; nil means "no TTL"
	Name   string // optional explicit name
}
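A small sketch of a Definition as a repository might declare it (via the ri alias used in repository.go below); the email field and index name are illustrative:

package example

import ri "github.com/tech/sendico/pkg/db/repository/index"

// uniqueEmailIndex declares a unique ascending index on an "email" field.
func uniqueEmailIndex() *ri.Definition {
	return &ri.Definition{
		Keys: []ri.Key{
			{Field: "email", Sort: ri.Asc},
		},
		Unique: true,
		Name:   "uniq_email",
	}
}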
36
api/pkg/db/repository/index/types.go
Normal file
@@ -0,0 +1,36 @@
package repository

// IndexType represents a supported MongoDB index type.
type IndexType string

const (
	// IndexTypeNotSet is the default index type.
	IndexTypeNotSet IndexType = ""

	// IndexTypeSingleField is a single-field index.
	IndexTypeSingleField IndexType = "single"

	// IndexTypeCompound is a compound index on multiple fields.
	IndexTypeCompound IndexType = "compound"

	// IndexTypeMultikey is an index on array fields (created automatically when needed).
	IndexTypeMultikey IndexType = "multikey"

	// IndexTypeText is a text index for full-text search.
	IndexTypeText IndexType = "text"

	// IndexTypeGeo2D is a legacy 2D geospatial index for planar geometry.
	IndexTypeGeo2D IndexType = "2d"

	// IndexTypeGeo2DSphere is a 2dsphere geospatial index for GeoJSON data.
	IndexTypeGeo2DSphere IndexType = "2dsphere"

	// IndexTypeHashed is a hashed index for sharding and efficient equality queries.
	IndexTypeHashed IndexType = "hashed"

	// IndexTypeWildcard is a wildcard index to index all fields or subpaths.
	IndexTypeWildcard IndexType = "wildcard"

	// IndexTypeClustered is a clustered index that orders the collection on the index key.
	IndexTypeClustered IndexType = "clustered"
)
46
api/pkg/db/repository/repository.go
Normal file
@@ -0,0 +1,46 @@
package repository

import (
	"context"

	"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
	"github.com/tech/sendico/pkg/db/repository/builder"
	rd "github.com/tech/sendico/pkg/db/repository/decoder"
	ri "github.com/tech/sendico/pkg/db/repository/index"
	"github.com/tech/sendico/pkg/db/storable"
	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
)

type (
	// FilterQuery selects documents to operate on.
	FilterQuery = builder.Query
	// PatchDoc defines field/value modifications for partial updates.
	PatchDoc = builder.Patch
)

type Repository interface {
	Aggregate(ctx context.Context, pipeline builder.Pipeline, decoder rd.DecodingFunc) error
	Insert(ctx context.Context, obj storable.Storable, getFilter builder.Query) error
	InsertMany(ctx context.Context, objects []storable.Storable) error
	Get(ctx context.Context, id primitive.ObjectID, result storable.Storable) error
	FindOneByFilter(ctx context.Context, filter builder.Query, result storable.Storable) error
	FindManyByFilter(ctx context.Context, filter builder.Query, decoder rd.DecodingFunc) error
	Update(ctx context.Context, obj storable.Storable) error
	// Patch applies partial updates defined by patch to the document identified by id.
	Patch(ctx context.Context, id primitive.ObjectID, patch PatchDoc) error
	// PatchMany applies partial updates defined by patch to all documents matching filter and returns the number of updated documents.
	PatchMany(ctx context.Context, filter FilterQuery, patch PatchDoc) (int, error)
	Delete(ctx context.Context, id primitive.ObjectID) error
	DeleteMany(ctx context.Context, query builder.Query) error
	CreateIndex(def *ri.Definition) error
	ListIDs(ctx context.Context, query builder.Query) ([]primitive.ObjectID, error)
	ListPermissionBound(ctx context.Context, query builder.Query) ([]model.PermissionBoundStorable, error)
	ListAccountBound(ctx context.Context, query builder.Query) ([]model.AccountBoundStorable, error)
	Collection() string
}

func CreateMongoRepository(db *mongo.Database, collection string) Repository {
	return repositoryimp.NewMongoRepository(db, collection)
}
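A short usage sketch tying the pieces together: construct the Mongo-backed repository and apply a partial update. The "items" collection name and the renameDocument helper are illustrative only:

package example

import (
	"context"

	"github.com/tech/sendico/pkg/db/repository"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
)

// renameDocument patches a single document's name field in the "items"
// collection through the repository abstraction rather than raw driver calls.
func renameDocument(ctx context.Context, db *mongo.Database, id primitive.ObjectID, name string) error {
	repo := repository.CreateMongoRepository(db, "items")
	patch := repository.Patch().Set(repository.NameField(), name)
	return repo.Patch(ctx, id, patch)
}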
15
api/pkg/db/role/role.go
Normal file
@@ -0,0 +1,15 @@
package role

import (
	"context"

	"github.com/tech/sendico/pkg/db/template"
	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

type DB interface {
	template.DB[*model.RoleDescription]
	Roles(ctx context.Context, refs []primitive.ObjectID) ([]model.RoleDescription, error)
	List(ctx context.Context, organizationRef primitive.ObjectID, cursor *model.ViewCursor) ([]model.RoleDescription, error)
}
39
api/pkg/db/storable/id.go
Normal file
@@ -0,0 +1,39 @@
package storable

import (
	"time"

	"go.mongodb.org/mongo-driver/bson/primitive"
)

const (
	IDField              = "_id"
	PermissionRefField   = "permissionRef"
	OrganizationRefField = "organizationRef"
	IsArchivedField      = "isArchived"
	UpdatedAtField       = "updatedAt"
)

type Base struct {
	ID        primitive.ObjectID `bson:"_id" json:"id"`
	CreatedAt time.Time          `bson:"createdAt" json:"createdAt"` // Timestamp for when the document was created
	UpdatedAt time.Time          `bson:"updatedAt" json:"updatedAt"` // Timestamp for when the document was last updated (optional)
}

func (b *Base) GetID() *primitive.ObjectID {
	return &b.ID
}

func (b *Base) SetID(objID primitive.ObjectID) {
	b.ID = objID
	b.CreatedAt = time.Now()
	b.UpdatedAt = time.Now()
}

func (b *Base) Update() {
	b.UpdatedAt = time.Now()
}

func (b *Base) Collection() string {
	return "base"
}
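A sketch of how a concrete model would embed Base; the Ticket type and its collection name are hypothetical and only illustrate the intended embedding pattern:

package example

import "github.com/tech/sendico/pkg/db/storable"

// Ticket embeds storable.Base to inherit GetID, SetID, and Update, and only
// overrides Collection to point at its own collection.
type Ticket struct {
	storable.Base `bson:",inline"`
	Title         string `bson:"title" json:"title"`
}

func (t *Ticket) Collection() string {
	return "tickets"
}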
11
api/pkg/db/storable/ref.go
Normal file
@@ -0,0 +1,11 @@
package storable

import "go.mongodb.org/mongo-driver/bson/primitive"

const (
	RefField = "ref"
)

type Ref struct {
	Ref primitive.ObjectID `bson:"ref" json:"ref"`
}
10
api/pkg/db/storable/storable.go
Normal file
@@ -0,0 +1,10 @@
package storable

import "go.mongodb.org/mongo-driver/bson/primitive"

type Storable interface {
	GetID() *primitive.ObjectID
	SetID(objID primitive.ObjectID)
	Update()
	Collection() string
}
16
api/pkg/db/tag/tag.go
Normal file
@@ -0,0 +1,16 @@
package tag

import (
	"context"

	"github.com/tech/sendico/pkg/auth"
	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/bson/primitive"
)

type DB interface {
	auth.ProtectedDB[*model.Tag]
	List(ctx context.Context, accountRef, organizationRef, parentRef primitive.ObjectID, cursor *model.ViewCursor) ([]model.Tag, error)
	All(ctx context.Context, organizationRef primitive.ObjectID, limit, offset *int64) ([]model.Tag, error)
	SetArchived(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID, archived, cascade bool) error
}
Some files were not shown because too many files have changed in this diff.