package organizationdb

import (
	"context"
	"errors"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/tech/sendico/pkg/auth"
	"github.com/tech/sendico/pkg/db/repository/builder"
	rd "github.com/tech/sendico/pkg/db/repository/decoder"
	ri "github.com/tech/sendico/pkg/db/repository/index"
	"github.com/tech/sendico/pkg/db/storable"
	"github.com/tech/sendico/pkg/db/template"
	"github.com/tech/sendico/pkg/merrors"
	"github.com/tech/sendico/pkg/model"
	"github.com/tech/sendico/pkg/mservice"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.uber.org/zap"
)

// TestOrganizationDB_SetArchived_TogglesState creates an organization, then
// archives and un-archives it, re-reading the stored record after every step to
// confirm that IsArchived tracks the most recent SetArchived call.
func TestOrganizationDB_SetArchived_TogglesState(t *testing.T) {
	ctx := context.Background()
	actor := primitive.NewObjectID()
	db := newTestOrganizationDB(t)

	org := &model.Organization{
		OrganizationBase: model.OrganizationBase{
			Describable: model.Describable{Name: "Sendico"},
			TimeZone:    "UTC",
		},
	}
	org.SetID(primitive.NewObjectID())
	require.NoError(t, db.Create(ctx, actor, *org.GetID(), org))

	// Read the ID only after Create, matching the original per-call evaluation
	// in case Create touches the object.
	orgRef := *org.GetID()

	var stored model.Organization
	require.NoError(t, db.Get(ctx, actor, orgRef, &stored))
	assert.False(t, stored.IsArchived())

	// Archive, then un-archive; each transition must be visible on re-read.
	for _, archived := range []bool{true, false} {
		require.NoError(t, db.SetArchived(ctx, actor, orgRef, archived, false))
		require.NoError(t, db.Get(ctx, actor, orgRef, &stored))
		assert.Equal(t, archived, stored.IsArchived())
	}
}

// TestOrganizationDB_SetArchived_UnknownOrganization checks that archiving an
// organization that was never created fails with an error wrapping
// merrors.ErrNoData.
func TestOrganizationDB_SetArchived_UnknownOrganization(t *testing.T) {
	ctx := context.Background()
	actor := primitive.NewObjectID()
	db := newTestOrganizationDB(t)
	missingRef := primitive.NewObjectID()

	err := db.SetArchived(ctx, actor, missingRef, true, false)
	require.Error(t, err)
	assert.True(t, errors.Is(err, merrors.ErrNoData))
}

// newTestOrganizationDB wires the real OrganizationDB implementation to an in-memory repository
// so the tests exercise actual SetArchived
behavior without external dependencies. func newTestOrganizationDB(t *testing.T) *OrganizationDB { t.Helper() repo := newMemoryOrganizationRepository() logger := zap.NewNop() dbImp := &template.DBImp[*model.Organization]{ Logger: logger, Repository: repo, } dbImp.SetDeleter(func(ctx context.Context, objectRef primitive.ObjectID) error { return repo.Delete(ctx, objectRef) }) return &OrganizationDB{ ProtectedDBImp: auth.ProtectedDBImp[*model.Organization]{ DBImp: dbImp, Enforcer: allowAllEnforcer{}, PermissionRef: primitive.NewObjectID(), Collection: mservice.Organizations, }, } } type allowAllEnforcer struct{} func (allowAllEnforcer) Enforce(context.Context, primitive.ObjectID, primitive.ObjectID, primitive.ObjectID, primitive.ObjectID, model.Action) (bool, error) { return true, nil } func (allowAllEnforcer) EnforceBatch(_ context.Context, objects []model.PermissionBoundStorable, _ primitive.ObjectID, _ model.Action) (map[primitive.ObjectID]bool, error) { result := make(map[primitive.ObjectID]bool, len(objects)) for _, obj := range objects { result[*obj.GetID()] = true } return result, nil } func (allowAllEnforcer) GetRoles(context.Context, primitive.ObjectID, primitive.ObjectID) ([]model.Role, error) { return nil, nil } func (allowAllEnforcer) GetPermissions(context.Context, primitive.ObjectID, primitive.ObjectID) ([]model.Role, []model.Permission, error) { return nil, nil, nil } type memoryOrganizationRepository struct { mu sync.RWMutex data map[primitive.ObjectID]*model.Organization order []primitive.ObjectID } func newMemoryOrganizationRepository() *memoryOrganizationRepository { return &memoryOrganizationRepository{ data: make(map[primitive.ObjectID]*model.Organization), } } func (m *memoryOrganizationRepository) Aggregate(context.Context, builder.Pipeline, rd.DecodingFunc) error { return merrors.NotImplemented("aggregate is not supported in memory repository") } func (m *memoryOrganizationRepository) Insert(_ context.Context, obj storable.Storable, _ 
builder.Query) error { m.mu.Lock() defer m.mu.Unlock() org, ok := obj.(*model.Organization) if !ok { return merrors.InvalidDataType("expected organization") } id := org.GetID() if id == nil || *id == primitive.NilObjectID { return merrors.InvalidArgument("organization ID must be set") } if _, exists := m.data[*id]; exists { return merrors.DataConflict("organization already exists") } m.data[*id] = cloneOrganization(org) m.order = append(m.order, *id) return nil } func (m *memoryOrganizationRepository) InsertMany(ctx context.Context, objects []storable.Storable) error { for _, obj := range objects { if err := m.Insert(ctx, obj, nil); err != nil { return err } } return nil } func (m *memoryOrganizationRepository) Get(_ context.Context, id primitive.ObjectID, result storable.Storable) error { m.mu.RLock() defer m.mu.RUnlock() org, ok := m.data[id] if !ok { return merrors.ErrNoData } dst, ok := result.(*model.Organization) if !ok { return merrors.InvalidDataType("expected organization result") } *dst = *cloneOrganization(org) return nil } func (m *memoryOrganizationRepository) FindOneByFilter(_ context.Context, query builder.Query, result storable.Storable) error { m.mu.RLock() defer m.mu.RUnlock() for _, id := range m.order { if org, ok := m.data[id]; ok && m.matchesQuery(query, org) { dst, okCast := result.(*model.Organization) if !okCast { return merrors.InvalidDataType("expected organization result") } *dst = *cloneOrganization(org) return nil } } return merrors.ErrNoData } func (m *memoryOrganizationRepository) FindManyByFilter(context.Context, builder.Query, rd.DecodingFunc) error { return merrors.NotImplemented("FindManyByFilter is not supported in memory repository") } func (m *memoryOrganizationRepository) Update(_ context.Context, obj storable.Storable) error { m.mu.Lock() defer m.mu.Unlock() org, ok := obj.(*model.Organization) if !ok { return merrors.InvalidDataType("expected organization") } id := org.GetID() if id == nil { return 
merrors.InvalidArgument("organization ID must be set") } if _, exists := m.data[*id]; !exists { return merrors.ErrNoData } m.data[*id] = cloneOrganization(org) return nil } func (m *memoryOrganizationRepository) Patch(context.Context, primitive.ObjectID, builder.Patch) error { return merrors.NotImplemented("Patch is not supported in memory repository") } func (m *memoryOrganizationRepository) PatchMany(context.Context, builder.Query, builder.Patch) (int, error) { return 0, merrors.NotImplemented("PatchMany is not supported in memory repository") } func (m *memoryOrganizationRepository) Delete(_ context.Context, id primitive.ObjectID) error { m.mu.Lock() defer m.mu.Unlock() if _, exists := m.data[id]; !exists { return merrors.ErrNoData } delete(m.data, id) return nil } func (m *memoryOrganizationRepository) DeleteMany(context.Context, builder.Query) error { return merrors.NotImplemented("DeleteMany is not supported in memory repository") } func (m *memoryOrganizationRepository) CreateIndex(*ri.Definition) error { return nil } func (m *memoryOrganizationRepository) ListIDs(_ context.Context, query builder.Query) ([]primitive.ObjectID, error) { m.mu.RLock() defer m.mu.RUnlock() var ids []primitive.ObjectID for _, id := range m.order { if org, ok := m.data[id]; ok && m.matchesQuery(query, org) { ids = append(ids, id) } } if len(ids) == 0 { return nil, merrors.ErrNoData } return ids, nil } func (m *memoryOrganizationRepository) ListPermissionBound(_ context.Context, query builder.Query) ([]model.PermissionBoundStorable, error) { m.mu.RLock() defer m.mu.RUnlock() var res []model.PermissionBoundStorable for _, org := range m.data { if m.matchesQuery(query, org) { res = append(res, cloneOrganization(org)) } } return res, nil } func (m *memoryOrganizationRepository) ListAccountBound(context.Context, builder.Query) ([]model.AccountBoundStorable, error) { return nil, merrors.NotImplemented("Account bound list not supported") } func (m *memoryOrganizationRepository) 
Collection() string { return mservice.Organizations } func (m *memoryOrganizationRepository) matchesQuery(query builder.Query, org *model.Organization) bool { if query == nil { return true } for _, elem := range query.BuildQuery() { switch elem.Key { case storable.IDField: id, ok := elem.Value.(primitive.ObjectID) if !ok || *org.GetID() != id { return false } case storable.IsArchivedField: value, ok := elem.Value.(bool) if ok && org.IsArchived() != value { return false } } } return true } func cloneOrganization(src *model.Organization) *model.Organization { dst := *src if len(src.Members) > 0 { dst.Members = append([]primitive.ObjectID{}, src.Members...) } return &dst }