You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
522 lines
22 KiB
522 lines
22 KiB
// FIXME: Update this file to be null safe and then delete the line below |
|
#nullable disable |
|
|
|
using System.Data.Common; |
|
using AutoMapper; |
|
using AutoMapper.QueryableExtensions; |
|
using Bit.Core.AdminConsole.Enums.Provider; |
|
using Bit.Core.Billing.Constants; |
|
using Bit.Core.Billing.Enums; |
|
using Bit.Core.Enums; |
|
using Bit.Core.Models.Data.Organizations; |
|
using Bit.Core.Models.Data.Organizations.OrganizationUsers; |
|
using Bit.Core.Repositories; |
|
using LinqToDB.Tools; |
|
using Microsoft.EntityFrameworkCore; |
|
using Microsoft.Extensions.DependencyInjection; |
|
using Microsoft.Extensions.Logging; |
|
using Organization = Bit.Infrastructure.EntityFramework.AdminConsole.Models.Organization; |
|
|
|
namespace Bit.Infrastructure.EntityFramework.Repositories; |
|
|
|
public class OrganizationRepository : Repository<Core.AdminConsole.Entities.Organization, Organization, Guid>, IOrganizationRepository |
|
{ |
|
/// <summary>Logger available to this repository and its subclasses.</summary>
protected readonly ILogger<OrganizationRepository> _logger;

/// <summary>
/// Creates the repository, handing the scope factory, the mapper and the
/// <c>Organizations</c> set selector to the base repository.
/// </summary>
/// <param name="serviceScopeFactory">Factory forwarded to the base repository.</param>
/// <param name="mapper">Mapper forwarded to the base repository.</param>
/// <param name="logger">Logger stored in <see cref="_logger"/>.</param>
public OrganizationRepository(
    IServiceScopeFactory serviceScopeFactory,
    IMapper mapper,
    ILogger<OrganizationRepository> logger)
    : base(serviceScopeFactory, mapper, db => db.Organizations)
    => _logger = logger;
|
|
|
/// <summary>
/// Looks up the first organization whose <c>GatewayCustomerId</c> equals
/// <paramref name="gatewayCustomerId"/>.
/// </summary>
/// <param name="gatewayCustomerId">Gateway customer id to match.</param>
/// <returns>The first match, or <c>null</c> when no organization matches.</returns>
public async Task<Core.AdminConsole.Entities.Organization> GetByGatewayCustomerIdAsync(string gatewayCustomerId)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    return await GetDbSet(db)
        .FirstOrDefaultAsync(org => org.GatewayCustomerId == gatewayCustomerId);
}
|
|
|
/// <summary>
/// Looks up the first organization whose <c>GatewaySubscriptionId</c> equals
/// <paramref name="gatewaySubscriptionId"/>.
/// </summary>
/// <param name="gatewaySubscriptionId">Gateway subscription id to match.</param>
/// <returns>The first match, or <c>null</c> when no organization matches.</returns>
public async Task<Core.AdminConsole.Entities.Organization> GetByGatewaySubscriptionIdAsync(string gatewaySubscriptionId)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    return await GetDbSet(db)
        .FirstOrDefaultAsync(org => org.GatewaySubscriptionId == gatewaySubscriptionId);
}
|
|
|
/// <summary>
/// Looks up the first organization whose <c>Identifier</c> equals
/// <paramref name="identifier"/>.
/// </summary>
/// <param name="identifier">Identifier to match.</param>
/// <returns>The first match, or <c>null</c> when no organization matches.</returns>
public async Task<Core.AdminConsole.Entities.Organization> GetByIdentifierAsync(string identifier)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    return await GetDbSet(db)
        .FirstOrDefaultAsync(org => org.Identifier == identifier);
}
|
|
|
/// <summary>
/// Loads every organization whose <c>Enabled</c> flag is set and maps the
/// results to core organization entities.
/// </summary>
/// <returns>The mapped list of enabled organizations.</returns>
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetManyByEnabledAsync()
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    var enabledQuery =
        from org in GetDbSet(db)
        where org.Enabled
        select org;

    var enabledOrganizations = await enabledQuery.ToListAsync();
    return Mapper.Map<List<Core.AdminConsole.Entities.Organization>>(enabledOrganizations);
}
|
|
|
/// <summary>
/// Loads the organizations reached through organization-user rows whose
/// <c>UserId</c> equals <paramref name="userId"/>, mapped to core entities.
/// </summary>
/// <param name="userId">User whose organizations are requested.</param>
/// <returns>The mapped list of organizations.</returns>
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetManyByUserIdAsync(Guid userId)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    // Flatten every organization's membership rows down to those belonging to the user.
    var memberships = GetDbSet(db)
        .SelectMany(org => org.OrganizationUsers.Where(member => member.UserId == userId));

    var organizations = await memberships
        .Include(member => member.Organization)
        .Select(member => member.Organization)
        .ToListAsync();

    return Mapper.Map<List<Core.AdminConsole.Entities.Organization>>(organizations);
}
|
|
|
/// <summary>
/// Searches organizations by optional name fragment, optional member email and
/// optional paid state, ordered by creation date and paged with
/// <paramref name="skip"/>/<paramref name="take"/>.
/// </summary>
/// <param name="name">Substring the name must contain; <c>null</c> disables the filter.</param>
/// <param name="userEmail">Exact email of an organization user; <c>null</c> disables the filter.</param>
/// <param name="paid">
/// <c>true</c> keeps organizations with a non-blank gateway subscription id,
/// <c>false</c> keeps those whose gateway subscription id is null, <c>null</c> disables the filter.
/// </param>
/// <param name="skip">Number of rows to skip.</param>
/// <param name="take">Maximum number of rows to return.</param>
/// <returns>The mapped page of organizations.</returns>
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> SearchAsync(string name, string userEmail,
    bool? paid, int skip, int take)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    var byName = GetDbSet(db)
        .Where(org => name == null || org.Name.Contains(name));

    var byMemberEmail = byName
        .Where(org => userEmail == null || org.OrganizationUsers.Any(member => member.Email == userEmail));

    var byPaidState = byMemberEmail
        .Where(org => paid == null ||
            (paid == true && !string.IsNullOrWhiteSpace(org.GatewaySubscriptionId)) ||
            (paid == false && org.GatewaySubscriptionId == null));

    var page = await byPaidState
        .OrderBy(org => org.CreationDate)
        .Skip(skip)
        .Take(take)
        .ToListAsync();

    return Mapper.Map<List<Core.AdminConsole.Entities.Organization>>(page);
}
|
|
|
/// <summary>
/// Projects every organization into an <see cref="OrganizationAbility"/>
/// by copying its feature flags.
/// </summary>
/// <returns>One ability record per organization.</returns>
public async Task<ICollection<OrganizationAbility>> GetManyAbilitiesAsync()
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    var abilities =
        from org in GetDbSet(db)
        select new OrganizationAbility
        {
            Enabled = org.Enabled,
            Id = org.Id,
            Use2fa = org.Use2fa,
            UseEvents = org.UseEvents,
            UsersGetPremium = org.UsersGetPremium,
            // "Using" 2FA requires both the flag and a non-null provider configuration.
            Using2fa = org.Use2fa && org.TwoFactorProviders != null,
            UseSso = org.UseSso,
            UseKeyConnector = org.UseKeyConnector,
            UseResetPassword = org.UseResetPassword,
            UseScim = org.UseScim,
            UseCustomPermissions = org.UseCustomPermissions,
            UsePolicies = org.UsePolicies,
            LimitCollectionCreation = org.LimitCollectionCreation,
            LimitCollectionDeletion = org.LimitCollectionDeletion,
            LimitItemDeletion = org.LimitItemDeletion,
            AllowAdminAccessToAllCollectionItems = org.AllowAdminAccessToAllCollectionItems,
            UseRiskInsights = org.UseRiskInsights,
            UseOrganizationDomains = org.UseOrganizationDomains,
            UseAdminSponsoredFamilies = org.UseAdminSponsoredFamilies,
            UseAutomaticUserConfirmation = org.UseAutomaticUserConfirmation,
            UseDisableSmAdsForUsers = org.UseDisableSmAdsForUsers,
            UsePhishingBlocker = org.UsePhishingBlocker,
            UseMyItems = org.UseMyItems
        };

    return await abilities.ToListAsync();
}
|
|
|
#nullable enable |
|
/// <summary>
/// Builds the <see cref="OrganizationAbility"/> for the organization with the given id.
/// </summary>
/// <param name="organizationId">Id of the organization to look up.</param>
/// <returns>The ability record, or <c>null</c> when no organization has that id.</returns>
public async Task<OrganizationAbility?> GetAbilityAsync(Guid organizationId)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    var abilityQuery =
        from org in GetDbSet(db)
        where org.Id == organizationId
        select new OrganizationAbility(org);

    return await abilityQuery.SingleOrDefaultAsync();
}
|
#nullable disable |
|
|
|
/// <summary>
/// Searches organizations that are not linked to any provider and are not on a
/// disallowed plan (free, custom or families), optionally filtered by name and
/// by owner email prefix, newest first.
/// </summary>
/// <param name="name">Substring of the organization name; blank disables the filter.</param>
/// <param name="ownerEmail">Prefix of an owner's email; blank disables the filter.</param>
/// <param name="skip">Number of rows to skip.</param>
/// <param name="take">Maximum number of rows to return.</param>
/// <returns>The requested page of organizations.</returns>
/// <remarks>
/// Both paths order by <c>CreationDate</c> descending and then by <c>Id</c>
/// descending. Previously only the owner-email path had the <c>Id</c> tie-breaker,
/// so paging without an owner email could repeat or skip rows that share a
/// creation date.
/// </remarks>
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> SearchUnassignedToProviderAsync(string name, string ownerEmail, int skip, int take)
{
    using var scope = ServiceScopeFactory.CreateScope();

    var dbContext = GetDatabaseContext(scope);

    var disallowedPlanTypes = new List<PlanType>
    {
        PlanType.Free,
        PlanType.Custom,
        PlanType.FamiliesAnnually2019,
        PlanType.FamiliesAnnually2025,
        PlanType.FamiliesAnnually
    };

    var query =
        from o in dbContext.Organizations
        where o.PlanType.NotIn(disallowedPlanTypes) &&
              !dbContext.ProviderOrganizations.Any(po => po.OrganizationId == o.Id) &&
              (string.IsNullOrWhiteSpace(name) || EF.Functions.Like(o.Name, $"%{name}%"))
        select o;

    if (!string.IsNullOrWhiteSpace(ownerEmail))
    {
        if (dbContext.Database.IsNpgsql())
        {
            // PostgreSQL path: ILike with an explicit "default" collation on the email column.
            query = from o in query
                    join ou in dbContext.OrganizationUsers
                        on o.Id equals ou.OrganizationId
                    join u in dbContext.Users
                        on ou.UserId equals u.Id
                    where ou.Type == OrganizationUserType.Owner && EF.Functions.ILike(EF.Functions.Collate(u.Email, "default"), $"{ownerEmail}%")
                    select o;
        }
        else
        {
            query = from o in query
                    join ou in dbContext.OrganizationUsers
                        on o.Id equals ou.OrganizationId
                    join u in dbContext.Users
                        on ou.UserId equals u.Id
                    where ou.Type == OrganizationUserType.Owner && EF.Functions.Like(u.Email, $"{ownerEmail}%")
                    select o;
        }
    }

    // Skip/Take needs a total ordering; Id breaks ties between equal creation dates.
    return await query
        .OrderByDescending(o => o.CreationDate)
        .ThenByDescending(o => o.Id)
        .Skip(skip)
        .Take(take)
        .ToArrayAsync();
}
|
|
|
/// <summary>
/// Delegates to <c>OrganizationUpdateStorage</c> for the given organization id.
/// </summary>
/// <param name="id">Id of the organization passed through to the helper.</param>
public async Task UpdateStorageAsync(Guid id) => await OrganizationUpdateStorage(id);
|
|
|
/// <summary>
/// Deletes an organization together with the rows that reference it.
/// </summary>
/// <param name="organization">Organization to delete; only its <c>Id</c> is read.</param>
/// <remarks>
/// Work happens in two transactions: the first removes the organization's
/// ciphers, the second removes the remaining dependent rows, clears sponsorship
/// references and removes the organization row itself.
/// <para>
/// <c>SaveChangesAsync</c> runs before the second transaction is committed, so
/// the tracked changes (sponsorship updates and the organization removal) are
/// part of that transaction. Previously the commit came first, leaving those
/// changes outside it. Both transactions are disposed via <c>await using</c>,
/// so an exception before a commit leaves that transaction uncommitted.
/// </para>
/// </remarks>
public override async Task DeleteAsync(Core.AdminConsole.Entities.Organization organization)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var dbContext = GetDatabaseContext(scope);

    await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organization.Id);

    // Organization-owned ciphers are removed in their own transaction.
    await using (var deleteCiphersTransaction = await dbContext.Database.BeginTransactionAsync())
    {
        await dbContext.Ciphers.Where(c => c.UserId == null && c.OrganizationId == organization.Id)
            .ExecuteDeleteAsync();
        await deleteCiphersTransaction.CommitAsync();
    }

    await using var organizationDeleteTransaction = await dbContext.Database.BeginTransactionAsync();

    await dbContext.AuthRequests.Where(ar => ar.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.SsoUsers.Where(su => su.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.SsoConfigs.Where(sc => sc.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.CollectionUsers.Where(cu => cu.OrganizationUser.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.UserProjectAccessPolicy.Where(ap => ap.OrganizationUser.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.UserServiceAccountAccessPolicy.Where(ap => ap.OrganizationUser.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.UserSecretAccessPolicy.Where(ap => ap.OrganizationUser.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.OrganizationUsers.Where(ou => ou.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.ProviderOrganizations.Where(po => po.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.OrganizationIntegrations.Where(oi => oi.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();

    await dbContext.GroupServiceAccountAccessPolicy.Where(ap => ap.GrantedServiceAccount.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.Project.Where(p => p.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.Secret.Where(s => s.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.ApiKeys.Where(ak => ak.ServiceAccount.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.ServiceAccount.Where(sa => sa.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();

    await dbContext.NotificationStatuses.Where(ns => ns.Notification.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.Notifications.Where(n => n.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();

    // The section below corresponds to 3 stored procedures in SQL Server that are only called from here.
    await dbContext.OrganizationApiKeys.Where(oa => oa.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();
    await dbContext.OrganizationConnections.Where(oc => oc.OrganizationId == organization.Id)
        .ExecuteDeleteAsync();

    // Sponsorship rows are kept; only their reference to this organization is cleared.
    var sponsoringOrgs = await dbContext.OrganizationSponsorships
        .Where(os => os.SponsoringOrganizationId == organization.Id)
        .ToListAsync();
    sponsoringOrgs.ForEach(os => os.SponsoringOrganizationId = null);

    var sponsoredOrgs = await dbContext.OrganizationSponsorships
        .Where(os => os.SponsoredOrganizationId == organization.Id)
        .ToListAsync();
    sponsoredOrgs.ForEach(os => os.SponsoredOrganizationId = null);

    var orgEntity = await dbContext.FindAsync<Organization>(organization.Id);
    dbContext.Remove(orgEntity);

    // Persist tracked changes while the transaction is still open, then commit everything together.
    await dbContext.SaveChangesAsync();
    await organizationDeleteTransaction.CommitAsync();
}
|
|
|
/// <summary>
/// Looks up the first organization whose <c>LicenseKey</c> equals
/// <paramref name="licenseKey"/>.
/// </summary>
/// <param name="licenseKey">License key to match.</param>
/// <returns>The first match, or <c>null</c> when no organization matches.</returns>
public async Task<Core.AdminConsole.Entities.Organization> GetByLicenseKeyAsync(string licenseKey)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    return await GetDbSet(db)
        .Where(org => org.LicenseKey == licenseKey)
        .FirstOrDefaultAsync();
}
|
|
|
/// <summary>
/// Loads the organization with the given id projected to
/// <see cref="SelfHostedOrganizationDetails"/> through the mapper configuration,
/// using a split query.
/// </summary>
/// <param name="id">Id of the organization to load.</param>
/// <returns>The projected details, or <c>null</c> when no organization has that id.</returns>
public async Task<SelfHostedOrganizationDetails> GetSelfHostedOrganizationDetailsById(Guid id)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    var detailsQuery = db.Organizations
        .Where(org => org.Id == id)
        .AsSplitQuery()
        .ProjectTo<SelfHostedOrganizationDetails>(Mapper.ConfigurationProvider);

    return await detailsQuery.SingleOrDefaultAsync();
}
|
|
|
/// <summary>
/// Returns the distinct email addresses of users who are confirmed owners of
/// the given organization.
/// </summary>
/// <param name="organizationId">Id of the organization whose owners are requested.</param>
/// <returns>One entry per distinct owner email.</returns>
public async Task<IEnumerable<string>> GetOwnerEmailAddressesById(Guid organizationId)
{
    _logger.LogInformation("AC-1758: Executing GetOwnerEmailAddressesById (Entity Framework)");

    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    var ownerEmails =
        from user in db.Users
        join membership in db.OrganizationUsers on user.Id equals membership.UserId
        where
            membership.OrganizationId == organizationId &&
            membership.Type == OrganizationUserType.Owner &&
            membership.Status == OrganizationUserStatusType.Confirmed
        // Grouping by email collapses duplicate addresses into a single entry.
        group user by user.Email
        into byEmail
        select byEmail.Key;

    return await ownerEmails.ToListAsync();
}
|
|
|
/// <summary>
/// Returns the enabled organizations that the user belongs to (with a status
/// other than Invited) and that have a verified domain equal to the domain
/// part of the user's email.
/// </summary>
/// <param name="userId">Id of the user to check.</param>
/// <returns>
/// The matching organizations; an empty list when the user does not exist.
/// </returns>
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetByVerifiedUserEmailDomainAsync(Guid userId)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    var user = await db.Users
        .Where(candidate => candidate.Id == userId)
        .FirstOrDefaultAsync();

    if (user is null)
    {
        return new List<Core.AdminConsole.Entities.Organization>();
    }

    // The domain is whatever follows the last '@' in the user's email.
    var userWithDomain = new { UserId = user.Id, EmailDomain = user.Email.Split('@')[^1] };

    var claimingOrganizations =
        from org in db.Organizations
        join membership in db.OrganizationUsers on org.Id equals membership.OrganizationId
        join domain in db.OrganizationDomains on membership.OrganizationId equals domain.OrganizationId
        where membership.UserId == userWithDomain.UserId &&
              domain.DomainName == userWithDomain.EmailDomain &&
              domain.VerifiedDate != null &&
              org.Enabled == true &&
              membership.Status != OrganizationUserStatusType.Invited
        select org;

    return await claimingOrganizations.ToArrayAsync();
}
|
|
|
/// <summary>
/// Returns the organizations that the user is a confirmed owner of and that
/// satisfy the provider-eligibility filter: enabled, gateway customer and
/// subscription ids present, more than zero seats, status Created, Secrets
/// Manager not in use, and a plan type permitted for the provider type.
/// </summary>
/// <param name="userId">Id of the owning user.</param>
/// <param name="providerType">
/// Provider type deciding the permitted plans: enterprise and teams plans for
/// <c>Msp</c>, enterprise plans for <c>BusinessUnit</c>, none for any other value.
/// </param>
/// <returns>The organizations matching every condition.</returns>
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetAddableToProviderByUserIdAsync(
    Guid userId,
    ProviderType providerType)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    var planTypes = providerType switch
    {
        ProviderType.Msp => PlanConstants.EnterprisePlanTypes.Concat(PlanConstants.TeamsPlanTypes),
        ProviderType.BusinessUnit => PlanConstants.EnterprisePlanTypes,
        _ => []
    };

    var addableOrganizations =
        from membership in db.OrganizationUsers
        join org in db.Organizations on membership.OrganizationId equals org.Id
        where
            membership.UserId == userId &&
            membership.Type == OrganizationUserType.Owner &&
            membership.Status == OrganizationUserStatusType.Confirmed &&
            org.Enabled &&
            org.GatewayCustomerId != null &&
            org.GatewaySubscriptionId != null &&
            org.Seats > 0 &&
            org.Status == OrganizationStatusType.Created &&
            !org.UseSecretsManager &&
            org.PlanType.In(planTypes)
        select org;

    return await addableOrganizations.ToArrayAsync();
}
|
|
|
/// <summary>
/// Loads the organizations whose id is contained in <paramref name="ids"/>.
/// </summary>
/// <param name="ids">Ids of the organizations to load.</param>
/// <returns>The organizations found for the supplied ids.</returns>
public async Task<ICollection<Core.AdminConsole.Entities.Organization>> GetManyByIdsAsync(IEnumerable<Guid> ids)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    return await db.Organizations
        .Where(org => ids.Contains(org.Id))
        .ToArrayAsync();
}
|
|
|
/// <summary>
/// Counts the seats occupied in an organization: organization users with a
/// status value of zero or greater, and admin-initiated sponsorships that pass
/// the deletion and validity filters below.
/// </summary>
/// <param name="organizationId">Id of the organization to count seats for.</param>
/// <returns>The user count and the sponsorship count.</returns>
public async Task<OrganizationSeatCounts> GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    var occupiedByUsers = await db.OrganizationUsers
        .CountAsync(ou => ou.OrganizationId == organizationId && ou.Status >= 0);

    // A sponsorship counts when it is admin-initiated, is either not flagged for deletion or
    // flagged but still valid, and is either unredeemed or redeemed without having expired.
    var occupiedBySponsorships = await db.OrganizationSponsorships
        .CountAsync(os => os.SponsoringOrganizationId == organizationId &&
            os.IsAdminInitiated &&
            (os.ToDelete == false || (os.ToDelete == true && os.ValidUntil != null && os.ValidUntil > DateTime.UtcNow)) &&
            (os.SponsoredOrganizationId == null || (os.SponsoredOrganizationId != null && (os.ValidUntil == null || os.ValidUntil > DateTime.UtcNow))));

    var seatCounts = new OrganizationSeatCounts
    {
        Users = occupiedByUsers,
        Sponsored = occupiedBySponsorships
    };

    return seatCounts;
}
|
|
|
/// <summary>
/// Loads the organizations flagged with <c>SyncSeats</c> that have a non-null
/// seat count.
/// </summary>
/// <returns>The organizations matching both conditions.</returns>
public async Task<IEnumerable<Core.AdminConsole.Entities.Organization>> GetOrganizationsForSubscriptionSyncAsync()
{
    using var scope = ServiceScopeFactory.CreateScope();
    await using var db = GetDatabaseContext(scope);

    var pendingSync =
        from org in db.Organizations
        where org.SyncSeats == true && org.Seats != null
        select org;

    return await pendingSync.ToArrayAsync();
}
|
|
|
/// <summary>
/// Clears the <c>SyncSeats</c> flag and updates <c>RevisionDate</c> for every
/// organization whose id is in <paramref name="successfulOrganizations"/>,
/// using a single bulk update.
/// </summary>
/// <param name="successfulOrganizations">Ids of the organizations to update.</param>
/// <param name="syncDate">Date whose date component is written to <c>RevisionDate</c>.</param>
public async Task UpdateSuccessfulOrganizationSyncStatusAsync(IEnumerable<Guid> successfulOrganizations, DateTime syncDate)
{
    using var scope = ServiceScopeFactory.CreateScope();
    await using var db = GetDatabaseContext(scope);

    var syncedOrganizations = db.Organizations
        .Where(org => successfulOrganizations.Contains(org.Id));

    // NOTE(review): syncDate.Date drops the time of day, so RevisionDate is stored as midnight
    // of the sync date, whereas IncrementSeatCountAsync writes a full timestamp — confirm intended.
    await syncedOrganizations.ExecuteUpdateAsync(setters => setters
        .SetProperty(org => org.SyncSeats, false)
        .SetProperty(org => org.RevisionDate, syncDate.Date));
}
|
|
|
/// <summary>
/// Adds <paramref name="increaseAmount"/> to the organization's seat count in
/// a single bulk update, flags it for seat sync and sets its revision date.
/// </summary>
/// <param name="organizationId">Id of the organization to update.</param>
/// <param name="increaseAmount">Amount added to the current <c>Seats</c> value.</param>
/// <param name="requestDate">Value written to <c>RevisionDate</c>.</param>
public async Task IncrementSeatCountAsync(Guid organizationId, int increaseAmount, DateTime requestDate)
{
    using var scope = ServiceScopeFactory.CreateScope();
    await using var db = GetDatabaseContext(scope);

    var target = db.Organizations.Where(org => org.Id == organizationId);

    // The increment is expressed against the column value, not a value read beforehand.
    await target.ExecuteUpdateAsync(setters => setters
        .SetProperty(org => org.Seats, org => org.Seats + increaseAmount)
        .SetProperty(org => org.SyncSeats, true)
        .SetProperty(org => org.RevisionDate, requestDate));
}
|
|
|
/// <summary>
/// Applies the initialization fields of <paramref name="organization"/> to the
/// stored organization and runs <paramref name="confirmOwnerAction"/> inside
/// the same database transaction.
/// </summary>
/// <param name="organization">
/// Source of the <c>Enabled</c>, <c>Status</c>, <c>PublicKey</c>, <c>PrivateKey</c>
/// and <c>RevisionDate</c> values; its <c>Id</c> selects the stored row.
/// </param>
/// <param name="confirmOwnerAction">
/// Callback invoked with the open connection and transaction after the
/// organization changes have been saved and before the commit.
/// </param>
/// <exception cref="InvalidOperationException">
/// No stored organization has the given id.
/// </exception>
public async Task InitializeOrganizationAsync(Core.AdminConsole.Entities.Organization organization, Func<DbConnection, DbTransaction, Task> confirmOwnerAction)
{
    using var scope = ServiceScopeFactory.CreateScope();
    var db = GetDatabaseContext(scope);

    // Open the raw connection so the same transaction can be shared with the callback.
    var connection = db.Database.GetDbConnection();
    await connection.OpenAsync();
    await using var transaction = await connection.BeginTransactionAsync();
    await db.Database.UseTransactionAsync(transaction);

    try
    {
        var stored = await db.Organizations.FindAsync(organization.Id)
            ?? throw new InvalidOperationException($"Organization {organization.Id} was not found during initialization.");

        stored.Enabled = organization.Enabled;
        stored.Status = organization.Status;
        stored.PublicKey = organization.PublicKey;
        stored.PrivateKey = organization.PrivateKey;
        stored.RevisionDate = organization.RevisionDate;

        await db.SaveChangesAsync();

        await confirmOwnerAction(connection, transaction);

        await transaction.CommitAsync();
    }
    catch (Exception ex)
    {
        // Any failure, including the not-found case above, is logged, rolled back and rethrown.
        _logger.LogError(ex,
            "Failed to initialize organization. Rolling back transaction.");
        await transaction.RollbackAsync();
        throw;
    }
}
|
}
|
|
|