diff --git a/src/KeyConnector/ServiceCollectionExtensions.cs b/src/KeyConnector/ServiceCollectionExtensions.cs index 16ee78d..e45e644 100644 --- a/src/KeyConnector/ServiceCollectionExtensions.cs +++ b/src/KeyConnector/ServiceCollectionExtensions.cs @@ -2,6 +2,7 @@ using System.Security.Claims; using Bit.KeyConnector.Repositories; using Bit.KeyConnector.Services; +using Bit.KeyConnector.Services.Pkcs11; using IdentityModel; using IdentityServer4.AccessTokenValidation; using JsonFlatFileDataStore; @@ -29,6 +30,7 @@ namespace Bit.KeyConnector else if (rsaKeyProvider == "pkcs11") { services.AddSingleton(); + services.AddSingleton(); } var certificateProvider = settings.Certificate.Provider?.ToLowerInvariant(); diff --git a/src/KeyConnector/Services/Pkcs11/IPkcs11InteropFactory.cs b/src/KeyConnector/Services/Pkcs11/IPkcs11InteropFactory.cs new file mode 100644 index 0000000..142d161 --- /dev/null +++ b/src/KeyConnector/Services/Pkcs11/IPkcs11InteropFactory.cs @@ -0,0 +1,26 @@ +using Net.Pkcs11Interop.Common; +using Net.Pkcs11Interop.HighLevelAPI; +using Net.Pkcs11Interop.HighLevelAPI.Factories; +using Net.Pkcs11Interop.HighLevelAPI.MechanismParams; + +namespace Bit.KeyConnector.Services.Pkcs11; + +public interface IPkcs11InteropFactory +{ + /// + IPkcs11Library LoadPkcs11Library(string libraryPath, AppType appType); + /// + ICkRsaPkcsOaepParams CreateCkRsaPkcsOaepParams(ulong hashAlg, ulong mgf, ulong source, byte[] sourceData); + /// + IMechanism CreateMechanism(CKM type); + /// + IMechanism CreateMechanism(CKM type, IMechanismParams parameters); + /// + public IObjectAttribute CreateObjectAttribute(CKA type, CKO value); + /// + public IObjectAttribute CreateObjectAttribute(CKA type, bool value); + /// + public IObjectAttribute CreateObjectAttribute(CKA type, ulong value); + /// + public IObjectAttribute CreateObjectAttribute(CKA type, string value); +} diff --git a/src/KeyConnector/Services/Pkcs11/Pkcs11InteropFactory.cs b/src/KeyConnector/Services/Pkcs11/Pkcs11InteropFactory.cs new file mode 100644 index 0000000..6250e97 --- /dev/null +++ b/src/KeyConnector/Services/Pkcs11/Pkcs11InteropFactory.cs @@ -0,0 +1,70 @@ +using System; +using System.Reflection; +using System.Runtime.InteropServices; +using Net.Pkcs11Interop.Common; +using Net.Pkcs11Interop.HighLevelAPI; +using Net.Pkcs11Interop.HighLevelAPI.MechanismParams; + +namespace Bit.KeyConnector.Services.Pkcs11; + +public class Pkcs11InteropFactory : IPkcs11InteropFactory +{ + private readonly Pkcs11InteropFactories _factories; + + public Pkcs11InteropFactory() + { + if (Platform.IsLinux) + { + // https://github.com/Pkcs11Interop/Pkcs11Interop/issues/239 + NativeLibrary.SetDllImportResolver(typeof(Pkcs11InteropFactories).Assembly, CustomDllImportResolver); + } + _factories = new Pkcs11InteropFactories(); + } + + public IPkcs11Library LoadPkcs11Library(string libraryPath, AppType appType) + { + return _factories.Pkcs11LibraryFactory.LoadPkcs11Library(_factories, libraryPath, appType); + } + + public ICkRsaPkcsOaepParams CreateCkRsaPkcsOaepParams(ulong hashAlg, ulong mgf, ulong source, byte[] sourceData) + { + return _factories.MechanismParamsFactory.CreateCkRsaPkcsOaepParams(hashAlg, mgf, source, sourceData); + } + + public IMechanism CreateMechanism(CKM type) + { + return _factories.MechanismFactory.Create(type); + } + + public IMechanism CreateMechanism(CKM type, IMechanismParams parameters) + { + return _factories.MechanismFactory.Create(type, parameters); + } + + public IObjectAttribute CreateObjectAttribute(CKA type, CKO value) + { + return _factories.ObjectAttributeFactory.Create(type, value); + } + + public IObjectAttribute CreateObjectAttribute(CKA type, bool value) + { + return _factories.ObjectAttributeFactory.Create(type, value); + } + + public IObjectAttribute CreateObjectAttribute(CKA type, ulong value) + { + return _factories.ObjectAttributeFactory.Create(type, value); + } + + public IObjectAttribute CreateObjectAttribute(CKA type, string value) + { + return _factories.ObjectAttributeFactory.Create(type, value); + } + + + private static IntPtr CustomDllImportResolver(string libraryName, Assembly assembly, DllImportSearchPath? dllImportSearchPath) + { + var mappedLibraryName = (libraryName == "libdl") ? "libdl.so.2" : libraryName; + return NativeLibrary.Load(mappedLibraryName, assembly, dllImportSearchPath); + } +} diff --git a/src/KeyConnector/Services/Pkcs11RsaKeyService.cs b/src/KeyConnector/Services/Pkcs11/Pkcs11RsaKeyService.cs similarity index 63% rename from src/KeyConnector/Services/Pkcs11RsaKeyService.cs rename to src/KeyConnector/Services/Pkcs11/Pkcs11RsaKeyService.cs index 5895e64..696f13f 100644 --- a/src/KeyConnector/Services/Pkcs11RsaKeyService.cs +++ b/src/KeyConnector/Services/Pkcs11/Pkcs11RsaKeyService.cs @@ -1,16 +1,18 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using System.Security.Cryptography.X509Certificates; using System.Threading.Tasks; using Net.Pkcs11Interop.Common; using Net.Pkcs11Interop.HighLevelAPI; -namespace Bit.KeyConnector.Services +namespace Bit.KeyConnector.Services.Pkcs11 { public class Pkcs11RsaKeyService : IRsaKeyService { private readonly ICertificateProviderService _certificateProviderService; private readonly ICryptoFunctionService _cryptoFunctionService; + private readonly IPkcs11InteropFactory _pkcs11InteropFactory; private readonly KeyConnectorSettings _settings; private X509Certificate2 _certificate; @@ -18,10 +20,12 @@ namespace Bit.KeyConnector.Services public Pkcs11RsaKeyService( ICertificateProviderService certificateProviderService, ICryptoFunctionService cryptoFunctionService, + IPkcs11InteropFactory pkcs11LibraryFactory, KeyConnectorSettings settings) { _certificateProviderService = certificateProviderService; _cryptoFunctionService = cryptoFunctionService; + _pkcs11InteropFactory = pkcs11LibraryFactory; _settings = settings; } @@ -39,19 +43,19 @@ namespace Bit.KeyConnector.Services { if (data == null) { - return null; + return Task.FromResult(null); } using var library = LoadLibrary(); using var session = CreateNewSession(library); var privateKey = GetPrivateKey(session); - var mechanismParams = session.Factories.MechanismParamsFactory.CreateCkRsaPkcsOaepParams( + var mechanismParams = _pkcs11InteropFactory.CreateCkRsaPkcsOaepParams( ConvertUtils.UInt64FromCKM(CKM.CKM_SHA_1), ConvertUtils.UInt64FromCKG(CKG.CKG_MGF1_SHA1), ConvertUtils.UInt64FromUInt32(CKZ.CKZ_DATA_SPECIFIED), null); - var mechanism = session.Factories.MechanismFactory.Create(CKM.CKM_RSA_PKCS_OAEP, mechanismParams); + var mechanism = _pkcs11InteropFactory.CreateMechanism(CKM.CKM_RSA_PKCS_OAEP, mechanismParams); var plainData = session.Decrypt(mechanism, privateKey, data); session.Logout(); @@ -62,14 +66,14 @@ namespace Bit.KeyConnector.Services { if (data == null) { - return null; + return Task.FromResult(null); } using var library = LoadLibrary(); using var session = CreateNewSession(library); var privateKey = GetPrivateKey(session); - var mechanism = session.Factories.MechanismFactory.Create(CKM.CKM_SHA256_RSA_PKCS); + var mechanism = _pkcs11InteropFactory.CreateMechanism(CKM.CKM_SHA256_RSA_PKCS); var signature = session.Sign(mechanism, privateKey, data); session.Logout(); @@ -105,104 +109,97 @@ namespace Bit.KeyConnector.Services { var attributes = new List { - session.Factories.ObjectAttributeFactory.Create(CKA.CKA_CLASS, CKO.CKO_PRIVATE_KEY), - session.Factories.ObjectAttributeFactory.Create(CKA.CKA_TOKEN, true) + _pkcs11InteropFactory.CreateObjectAttribute(CKA.CKA_CLASS, CKO.CKO_PRIVATE_KEY), + _pkcs11InteropFactory.CreateObjectAttribute(CKA.CKA_TOKEN, true) }; if (_settings.RsaKey.Pkcs11PrivateKeyId.HasValue) { - attributes.Add(session.Factories.ObjectAttributeFactory.Create(CKA.CKA_ID, + attributes.Add(_pkcs11InteropFactory.CreateObjectAttribute(CKA.CKA_ID, _settings.RsaKey.Pkcs11PrivateKeyId.Value)); } else { - attributes.Add(session.Factories.ObjectAttributeFactory.Create(CKA.CKA_LABEL, + attributes.Add(_pkcs11InteropFactory.CreateObjectAttribute(CKA.CKA_LABEL, _settings.RsaKey.Pkcs11PrivateKeyLabel)); } var objects = session.FindAllObjects(attributes); - if (objects.Count == 0) - { - throw new System.Exception("Private key not found."); - } - else if (objects.Count > 1) + return objects.Count switch { - throw new System.Exception("More than one private key was found. Use a more specific identifier."); - } - - return objects.Single(); + 0 => throw new System.Exception("Private key not found."), + > 1 => throw new System.Exception( + "More than one private key was found. Use a more specific identifier."), + _ => objects.Single() + }; } + private IPkcs11Library LoadLibrary() { var libPath = _settings.RsaKey.Pkcs11LibraryPath; if (string.IsNullOrWhiteSpace(libPath)) { var provider = _settings.RsaKey.Pkcs11Provider?.ToLowerInvariant(); - if (provider == "yubihsm") - { - libPath = "/usr/lib/x86_64-linux-gnu/pkcs11/yubihsm_pkcs11.so"; - } - else if (provider == "opensc") - { - libPath = "/usr/lib/x86_64-linux-gnu/opensc-pkcs11.so"; - } - else + libPath = provider switch { - throw new System.Exception("Please provide a library path or known provider."); - } + "yubihsm" => "/usr/lib/x86_64-linux-gnu/pkcs11/yubihsm_pkcs11.so", + "opensc" => "/usr/lib/x86_64-linux-gnu/opensc-pkcs11.so", + _ => throw new Exception("Please provide a library path or known provider.") + }; } - var factories = new Pkcs11InteropFactories(); - var library = factories.Pkcs11LibraryFactory.LoadPkcs11Library(factories, libPath, AppType.MultiThreaded); + var library = _pkcs11InteropFactory.LoadPkcs11Library(libPath, AppType.MultiThreaded); if (library == null) { - throw new System.Exception("Cannot load library."); + throw new Exception("Cannot load library."); } + return library; } private ISession CreateNewSession(IPkcs11Library library) { + var slotTokenSerialNumber = _settings.RsaKey.Pkcs11SlotTokenSerialNumber?.ToLowerInvariant(); + var userTypeSetting = _settings.RsaKey.Pkcs11LoginUserType?.ToLowerInvariant(); + var loginPin = _settings.RsaKey.Pkcs11LoginPin; + ISlot chosenSlot = null; var slots = library.GetSlotList(SlotsType.WithOrWithoutTokenPresent); - var serialNumber = _settings.RsaKey.Pkcs11SlotTokenSerialNumber?.ToLowerInvariant(); foreach (var slot in slots) { var slotInfo = slot.GetSlotInfo(); - if (slotInfo.SlotFlags.TokenPresent) + if (!slotInfo.SlotFlags.TokenPresent) + { + continue; + } + + try { - try + var tokenInfo = slot.GetTokenInfo(); + if (tokenInfo?.SerialNumber?.ToLowerInvariant() == slotTokenSerialNumber) { - var tokenInfo = slot.GetTokenInfo(); - if (tokenInfo?.SerialNumber?.ToLowerInvariant() == serialNumber) - { - chosenSlot = slot; - break; - } + chosenSlot = slot; + break; } - catch (Pkcs11Exception) { } } + catch (Pkcs11Exception) {} } if (chosenSlot == null) { - throw new System.Exception("Cannot locate token slot."); + throw new Exception("Cannot locate token slot."); } // TODO: read only? var session = chosenSlot.OpenSession(SessionType.ReadWrite); - var userType = CKU.CKU_USER; - var userTypeSetting = _settings.RsaKey.Pkcs11LoginUserType?.ToLowerInvariant(); - if (userTypeSetting == "so") - { - userType = CKU.CKU_SO; - } - else if (userTypeSetting == "context_specific") + var userType = userTypeSetting switch { - userType = CKU.CKU_CONTEXT_SPECIFIC; - } - session.Login(userType, _settings.RsaKey.Pkcs11LoginPin); + "so" => CKU.CKU_SO, + "context_specific" => CKU.CKU_CONTEXT_SPECIFIC, + _ => CKU.CKU_USER + }; + session.Login(userType, loginPin); return session; } } diff --git a/test/KeyConnector.Tests/KeyConnector.Tests.csproj b/test/KeyConnector.Tests/KeyConnector.Tests.csproj index 394823c..4633b81 100644 --- a/test/KeyConnector.Tests/KeyConnector.Tests.csproj +++ b/test/KeyConnector.Tests/KeyConnector.Tests.csproj @@ -12,6 +12,7 @@ all + all diff --git a/test/KeyConnector.Tests/Services/Pkcs11RsaKeyServiceTests.cs b/test/KeyConnector.Tests/Services/Pkcs11RsaKeyServiceTests.cs new file mode 100644 index 0000000..2f6ce01 --- /dev/null +++ b/test/KeyConnector.Tests/Services/Pkcs11RsaKeyServiceTests.cs @@ -0,0 +1,571 @@ +using Bit.KeyConnector.Services; +using Xunit; +using NSubstitute; +using Bit.KeyConnector; +using System.Threading.Tasks; +using Net.Pkcs11Interop.HighLevelAPI; +using Net.Pkcs11Interop.Common; +using System; +using System.Linq; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Bit.KeyConnector.Services.Pkcs11; +using Net.Pkcs11Interop.HighLevelAPI.MechanismParams; + +namespace KeyConnector.Tests.Services; + +internal delegate void InitializationModifier( + ICertificateProviderService certificateProviderService, + ICryptoFunctionService cryptoFunctionService, + IPkcs11InteropFactory pkcs11InteropFactory, + KeyConnectorSettings keyConnectorSettings +); + +public class Pkcs11RsaKeyServiceTests +{ + private readonly byte[] _data = "data"u8.ToArray(); + private readonly byte[] _encryptedData = "encryptedData"u8.ToArray(); + private readonly byte[] _decryptedData = "decryptedData"u8.ToArray(); + private readonly byte[] _signature = "signature"u8.ToArray(); + + private static Pkcs11RsaKeyService InitializeService(InitializationModifier modifier = null) + { + var certificateProviderService = Substitute.For(); + var cryptoFunctionService = Substitute.For(); + var pkcs11InteropFactory = Substitute.For(); + var settings = new KeyConnectorSettings(); + modifier?.Invoke(certificateProviderService, cryptoFunctionService, pkcs11InteropFactory, settings); + + return new Pkcs11RsaKeyService(certificateProviderService, cryptoFunctionService, pkcs11InteropFactory, + settings); + } + + // EncryptAsync Tests + + [Fact] + public async Task EncryptAsync_ReturnsEncryptedData() + { + // Create certificate + var cert = CreateCertificate(); + var publicKey = cert.GetRSAPublicKey()!.ExportSubjectPublicKeyInfo(); + + // Initialize sut + var sut = InitializeService((certProvider, cryptoService, _, _) => + { + certProvider.GetCertificateAsync().Returns(cert); + cryptoService.RsaEncryptAsync( + _data, + Arg.Is(input => publicKey.SequenceEqual(input))) + .Returns(_encryptedData); + }); + + // Act + var result = await sut.EncryptAsync(_data); + + // Assert + Assert.Equal("encryptedData"u8.ToArray(), result); + } + + [Fact] + public async Task EncryptAsync_ReturnsNull_GivenNullData() + { + var sut = InitializeService(); + Assert.Null(await sut.EncryptAsync(null)); + } + + // DecryptAsync Tests + + [Fact] + public async Task DecryptAsync_ReturnsDecryptedData() + { + // Create mocks + var library = Substitute.For(); + var slot = library.AddSlot("chosenSerialNumber"); + var session = slot.AddSession(); + var privateKey = session.AddPrivateKey(); + + var mechanismParams = Substitute.For(); + var mechanism = Substitute.For(); + + + session.Decrypt(mechanism, privateKey, _data).Returns(_decryptedData); + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", + Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + Pkcs11LoginUserType = "userType", + Pkcs11LoginPin = "loginPin" + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + interopFactory.CreateCkRsaPkcsOaepParams(default, default, default, default) + .ReturnsForAnyArgs(mechanismParams); + interopFactory.CreateMechanism(Arg.Any(), mechanismParams).Returns(mechanism); + }); + + // Act + var result = await sut.DecryptAsync(_data); + + // Assert + Assert.Equal("decryptedData"u8.ToArray(), result); + } + + [Fact] + public async Task DecryptAsync_LogsOutOfSession_GivenValidData() + { + // Create mocks + var library = Substitute.For(); + var slot = library.AddSlot("chosenSerialNumber"); + var session = slot.AddSession(); + var privateKey = session.AddPrivateKey(); + + var mechanismParams = Substitute.For(); + var mechanism = Substitute.For(); + + session.Decrypt(mechanism, privateKey, _data).Returns(_decryptedData); + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", + Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + Pkcs11LoginUserType = "userType", + Pkcs11LoginPin = "loginPin" + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + interopFactory.CreateCkRsaPkcsOaepParams(default, default, default, default) + .ReturnsForAnyArgs(mechanismParams); + interopFactory.CreateMechanism(Arg.Any(), mechanismParams).Returns(mechanism); + }); + + // Act + await sut.DecryptAsync(_data); + + // Assert + session.Received().Logout(); + } + + [Theory] + [InlineData("yubihsm", "/usr/lib/x86_64-linux-gnu/pkcs11/yubihsm_pkcs11.so")] + [InlineData("opensc", "/usr/lib/x86_64-linux-gnu/opensc-pkcs11.so")] + public async Task DecryptAsync_UsesCorrectProviderPath_WhenLoadingLibrary(string provider, string path) + { + // Create mocks + var library = Substitute.For(); + IPkcs11InteropFactory pkcs11InteropFactory = null; + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings { Pkcs11Provider = provider }; + interopFactory.LoadPkcs11Library(path, Arg.Any()).Returns(library); + pkcs11InteropFactory = interopFactory; + }); + + // Act + try + { + await sut.DecryptAsync(_data); + } + catch + { + } + + // Assert + pkcs11InteropFactory.Received().LoadPkcs11Library(path, Arg.Any()); + } + + [Fact] + public async Task DecryptAsync_ThrowsException_WhenSlotHasNoTokenPresent() + { + // Create mocks + var library = Substitute.For(); + library.AddSlot("chosenSerialNumber", false); + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + }); + + // Assert + var exception = await Assert.ThrowsAsync(async () => await sut.DecryptAsync(_data)); + Assert.Contains("Cannot locate token slot.", exception.Message); + } + + [Fact] + public async Task DecryptAsync_ChoosesCorrectSlot_GivenASlotSerialNumber() + { + // Create mocks + var library = Substitute.For(); + + var slot1 = Substitute.For(); + slot1.GetSlotInfo().SlotFlags.TokenPresent.Returns(true); + slot1.GetTokenInfo().SerialNumber.Returns("chosenSerialNumber"); + var slot2 = Substitute.For(); + slot2.GetSlotInfo().SlotFlags.TokenPresent.Returns(true); + slot2.GetTokenInfo().SerialNumber.Returns("wrongSerialNumber"); + library.GetSlotList(default).ReturnsForAnyArgs([slot1, slot2]); + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + }); + + // Act + try + { + await sut.DecryptAsync(_data); + } + catch + { + } + + // Assert + slot1.Received().OpenSession(Arg.Any()); + } + + [Theory] + [InlineData(null, CKU.CKU_USER)] + [InlineData("so", CKU.CKU_SO)] + [InlineData("context_specific", CKU.CKU_CONTEXT_SPECIFIC)] + public async Task DecryptAsync_UsesCorrectUserType_GivenInSettings(string userTypeSetting, CKU expectedUserType) + { + // Create mocks + var library = Substitute.For(); + var slot = library.AddSlot("chosenSerialNumber"); + var session = slot.AddSession(); + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", + Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + Pkcs11LoginUserType = userTypeSetting, + Pkcs11LoginPin = "loginPin" + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + }); + + // Act + try + { + await sut.DecryptAsync(_data); + } + catch + { + } + + // Assert + session.Received().Login(expectedUserType, "loginPin"); + } + + [Theory] + [InlineData(123UL, CKA.CKA_ID)] + [InlineData(null, CKA.CKA_LABEL)] + public async Task DecryptAsync_UsesPrivateKeyId_IfAvailableInSettings(ulong? id, CKA expectedType) + { + // Create mocks + var library = Substitute.For(); + var slot = library.AddSlot("chosenSerialNumber"); + var session = slot.AddSession(); + session.AddPrivateKey(); + + IPkcs11InteropFactory pkcs11InteropFactory = null; + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", + Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + Pkcs11LoginUserType = "userType", + Pkcs11LoginPin = "loginPin", + Pkcs11PrivateKeyId = id, + Pkcs11PrivateKeyLabel = "label" + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + pkcs11InteropFactory = interopFactory; + }); + + // Act + await sut.DecryptAsync(_data); + + // Assert + if (id is not null) + { + pkcs11InteropFactory.Received().CreateObjectAttribute(expectedType, id.Value); + } + else + { + pkcs11InteropFactory.Received().CreateObjectAttribute(expectedType, "label"); + } + } + + [Fact] + public async Task DecryptAsync_ReturnsNull_GivenNullData() + { + var sut = InitializeService(); + Assert.Null(await sut.DecryptAsync(null)); + } + + // SignAsync Tests + + [Fact] + public async Task SignAsync_ReturnsSignature() + { + // Create mocks + var library = Substitute.For(); + var slot = library.AddSlot("chosenSerialNumber"); + var session = slot.AddSession(); + var privateKey = session.AddPrivateKey(); + var mechanism = Substitute.For(); + + session.Sign(mechanism, privateKey, _data).Returns(_signature); + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", + Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + Pkcs11LoginUserType = "userType", + Pkcs11LoginPin = "loginPin" + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + interopFactory.CreateMechanism(Arg.Any()).Returns(mechanism); + }); + + // Act + var result = await sut.SignAsync(_data); + + // Assert + Assert.Equal("signature"u8.ToArray(), result); + } + + [Fact] + public async Task SignAsync_LogsOutOfSession_GivenValidData() + { + // Create mocks + var library = Substitute.For(); + var slot = library.AddSlot("chosenSerialNumber"); + var session = slot.AddSession(); + var privateKey = session.AddPrivateKey(); + var mechanism = Substitute.For(); + + session.Sign(mechanism, privateKey, _data).Returns(_signature); + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", + Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + Pkcs11LoginUserType = "userType", + Pkcs11LoginPin = "loginPin" + }; + + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + interopFactory.CreateMechanism(Arg.Any()).Returns(mechanism); + }); + + // Act + await sut.SignAsync(_data); + + // Assert + session.Received().Logout(); + } + + [Theory] + [InlineData("yubihsm", "/usr/lib/x86_64-linux-gnu/pkcs11/yubihsm_pkcs11.so")] + [InlineData("opensc", "/usr/lib/x86_64-linux-gnu/opensc-pkcs11.so")] + public async Task SignAsync_UsesCorrectProviderPath_WhenLoadingLibrary(string provider, string path) + { + // Create mocks + var library = Substitute.For(); + IPkcs11InteropFactory pkcs11InteropFactory = null; + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings { Pkcs11Provider = provider }; + interopFactory.LoadPkcs11Library(path, Arg.Any()).Returns(library); + pkcs11InteropFactory = interopFactory; + }); + + // Act + try + { + await sut.SignAsync(_data); + } + catch + { + } + + // Assert + pkcs11InteropFactory.Received().LoadPkcs11Library(path, Arg.Any()); + } + + [Fact] + public async Task SignAsync_ThrowsException_WhenSlotHasNoTokenPresent() + { + // Create mocks + var library = Substitute.For(); + library.AddSlot("chosenSerialNumber", false); + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + }); + + // Assert + var exception = await Assert.ThrowsAsync(async () => await sut.SignAsync(_data)); + Assert.Contains("Cannot locate token slot.", exception.Message); + } + + [Fact] + public async Task SignAsync_ChoosesCorrectSlot_GivenASlotSerialNumber() + { + // Create mocks + var library = Substitute.For(); + + var slot1 = Substitute.For(); + slot1.GetSlotInfo().SlotFlags.TokenPresent.Returns(true); + slot1.GetTokenInfo().SerialNumber.Returns("chosenSerialNumber"); + var slot2 = Substitute.For(); + slot2.GetSlotInfo().SlotFlags.TokenPresent.Returns(true); + slot2.GetTokenInfo().SerialNumber.Returns("wrongSerialNumber"); + library.GetSlotList(default).ReturnsForAnyArgs([slot1, slot2]); + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + }); + + // Act + try + { + await sut.SignAsync(_data); + } + catch + { + } + + // Assert + slot1.Received().OpenSession(Arg.Any()); + } + + [Theory] + [InlineData(123UL, CKA.CKA_ID)] + [InlineData(null, CKA.CKA_LABEL)] + public async Task SignAsync_UsesPrivateKeyId_IfAvailableInSettings(ulong? id, CKA expectedType) + { + // Create mocks + var library = Substitute.For(); + var slot = library.AddSlot("chosenSerialNumber"); + var session = slot.AddSession(); + session.AddPrivateKey(); + + IPkcs11InteropFactory pkcs11InteropFactory = null; + + // Initialize sut + var sut = InitializeService((_, _, interopFactory, settings) => + { + settings.RsaKey = new KeyConnectorSettings.RsaKeySettings + { + Pkcs11Provider = "yubihsm", + Pkcs11SlotTokenSerialNumber = "chosenSerialNumber", + Pkcs11LoginUserType = "userType", + Pkcs11LoginPin = "loginPin", + Pkcs11PrivateKeyId = id, + Pkcs11PrivateKeyLabel = "label" + }; + interopFactory.LoadPkcs11Library(default, default).ReturnsForAnyArgs(library); + pkcs11InteropFactory = interopFactory; + }); + + // Act + await sut.SignAsync(_data); + + // Assert + if (id is not null) + { + pkcs11InteropFactory.Received().CreateObjectAttribute(expectedType, id.Value); + } + else + { + pkcs11InteropFactory.Received().CreateObjectAttribute(expectedType, "label"); + } + } + + [Fact] + public async Task SignAsync_ReturnsNull_GivenNullData() + { + var sut = InitializeService(); + Assert.Null(await sut.SignAsync(null)); + } + + private static X509Certificate2 CreateCertificate() + { + var rsa = RSA.Create(2048); + var req = new CertificateRequest("CN=Experimental Issuing Authority", rsa, HashAlgorithmName.SHA256, + RSASignaturePadding.Pkcs1); + var cert = req.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5)); + return cert; + } +} + +internal static class Extensions +{ + public static ISlot AddSlot(this IPkcs11Library library, string serialNumber, bool tokenPresent = true) + { + var slot = Substitute.For(); + slot.GetSlotInfo().SlotFlags.TokenPresent.Returns(tokenPresent); + slot.GetTokenInfo().SerialNumber.Returns(serialNumber); + library.GetSlotList(default).ReturnsForAnyArgs([slot]); + return slot; + } + + public static ISession AddSession(this ISlot slot) + { + var session = Substitute.For(); + slot.OpenSession(default).ReturnsForAnyArgs(session); + return session; + } + + public static IObjectHandle AddPrivateKey(this ISession session) + { + var privateKey = Substitute.For(); + session.FindAllObjects(default).ReturnsForAnyArgs([privateKey]); + return privateKey; + } +}