mirror of
https://gitlab.com/suyu-emu/suyu.git
synced 2024-03-15 23:15:44 +00:00
ssl: remove ResultVal use
This commit is contained in:
parent
84cb20bc72
commit
83eee1d226
|
@ -54,7 +54,7 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
|
||||||
RegisterHandlers(functions);
|
RegisterHandlers(functions);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) {
|
static std::string ResolveImpl(const std::string& fqdn_in) {
|
||||||
// The real implementation makes various substitutions.
|
// The real implementation makes various substitutions.
|
||||||
// For now we just return the string as-is, which is good enough when not
|
// For now we just return the string as-is, which is good enough when not
|
||||||
// connecting to real Nintendo servers.
|
// connecting to real Nintendo servers.
|
||||||
|
@ -64,13 +64,10 @@ static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) {
|
||||||
|
|
||||||
static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) {
|
static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) {
|
||||||
const auto res = ResolveImpl(fqdn_in);
|
const auto res = ResolveImpl(fqdn_in);
|
||||||
if (res.Failed()) {
|
if (res.size() >= fqdn_out.size()) {
|
||||||
return res.Code();
|
|
||||||
}
|
|
||||||
if (res->size() >= fqdn_out.size()) {
|
|
||||||
return ResultOverflow;
|
return ResultOverflow;
|
||||||
}
|
}
|
||||||
std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1);
|
std::memcpy(fqdn_out.data(), res.c_str(), res.size() + 1);
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
#include "common/string_util.h"
|
#include "common/string_util.h"
|
||||||
|
|
||||||
#include "core/core.h"
|
#include "core/core.h"
|
||||||
|
#include "core/hle/result.h"
|
||||||
#include "core/hle/service/ipc_helpers.h"
|
#include "core/hle/service/ipc_helpers.h"
|
||||||
#include "core/hle/service/server_manager.h"
|
#include "core/hle/service/server_manager.h"
|
||||||
#include "core/hle/service/service.h"
|
#include "core/hle/service/service.h"
|
||||||
|
@ -141,12 +142,12 @@ private:
|
||||||
bool did_set_host_name = false;
|
bool did_set_host_name = false;
|
||||||
bool did_handshake = false;
|
bool did_handshake = false;
|
||||||
|
|
||||||
ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
|
Result SetSocketDescriptorImpl(s32* out_fd, s32 fd) {
|
||||||
LOG_DEBUG(Service_SSL, "called, fd={}", fd);
|
LOG_DEBUG(Service_SSL, "called, fd={}", fd);
|
||||||
ASSERT(!did_handshake);
|
ASSERT(!did_handshake);
|
||||||
auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
|
auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
|
||||||
ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
|
ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
|
||||||
s32 ret_fd;
|
|
||||||
// Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
|
// Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
|
||||||
if (do_not_close_socket) {
|
if (do_not_close_socket) {
|
||||||
auto res = bsd->DuplicateSocketImpl(fd);
|
auto res = bsd->DuplicateSocketImpl(fd);
|
||||||
|
@ -156,9 +157,9 @@ private:
|
||||||
}
|
}
|
||||||
fd = *res;
|
fd = *res;
|
||||||
fd_to_close = fd;
|
fd_to_close = fd;
|
||||||
ret_fd = fd;
|
*out_fd = fd;
|
||||||
} else {
|
} else {
|
||||||
ret_fd = -1;
|
*out_fd = -1;
|
||||||
}
|
}
|
||||||
std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
|
std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
|
||||||
if (!sock.has_value()) {
|
if (!sock.has_value()) {
|
||||||
|
@ -167,7 +168,7 @@ private:
|
||||||
}
|
}
|
||||||
socket = std::move(*sock);
|
socket = std::move(*sock);
|
||||||
backend->SetSocket(socket);
|
backend->SetSocket(socket);
|
||||||
return ret_fd;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Result SetHostNameImpl(const std::string& hostname) {
|
Result SetHostNameImpl(const std::string& hostname) {
|
||||||
|
@ -247,34 +248,36 @@ private:
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<std::vector<u8>> ReadImpl(size_t size) {
|
Result ReadImpl(std::vector<u8>* out_data, size_t size) {
|
||||||
ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
|
ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
|
||||||
std::vector<u8> res(size);
|
size_t actual_size{};
|
||||||
ResultVal<size_t> actual = backend->Read(res);
|
Result res = backend->Read(&actual_size, *out_data);
|
||||||
if (actual.Failed()) {
|
if (res != ResultSuccess) {
|
||||||
return actual.Code();
|
return res;
|
||||||
}
|
}
|
||||||
res.resize(*actual);
|
out_data->resize(actual_size);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> WriteImpl(std::span<const u8> data) {
|
Result WriteImpl(size_t* out_size, std::span<const u8> data) {
|
||||||
ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
|
ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
|
||||||
return backend->Write(data);
|
return backend->Write(out_size, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<s32> PendingImpl() {
|
Result PendingImpl(s32* out_pending) {
|
||||||
LOG_WARNING(Service_SSL, "(STUBBED) called.");
|
LOG_WARNING(Service_SSL, "(STUBBED) called.");
|
||||||
return 0;
|
*out_pending = 0;
|
||||||
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetSocketDescriptor(HLERequestContext& ctx) {
|
void SetSocketDescriptor(HLERequestContext& ctx) {
|
||||||
IPC::RequestParser rp{ctx};
|
IPC::RequestParser rp{ctx};
|
||||||
const s32 fd = rp.Pop<s32>();
|
const s32 in_fd = rp.Pop<s32>();
|
||||||
const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
|
s32 out_fd{-1};
|
||||||
|
const Result res = SetSocketDescriptorImpl(&out_fd, in_fd);
|
||||||
IPC::ResponseBuilder rb{ctx, 3};
|
IPC::ResponseBuilder rb{ctx, 3};
|
||||||
rb.Push(res.Code());
|
rb.Push(res);
|
||||||
rb.Push<s32>(res.ValueOr(-1));
|
rb.Push<s32>(out_fd);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetHostName(HLERequestContext& ctx) {
|
void SetHostName(HLERequestContext& ctx) {
|
||||||
|
@ -313,14 +316,15 @@ private:
|
||||||
};
|
};
|
||||||
static_assert(sizeof(OutputParameters) == 0x8);
|
static_assert(sizeof(OutputParameters) == 0x8);
|
||||||
|
|
||||||
const Result res = DoHandshakeImpl();
|
Result res = DoHandshakeImpl();
|
||||||
OutputParameters out{};
|
OutputParameters out{};
|
||||||
if (res == ResultSuccess) {
|
if (res == ResultSuccess) {
|
||||||
auto certs = backend->GetServerCerts();
|
std::vector<std::vector<u8>> certs;
|
||||||
if (certs.Succeeded()) {
|
res = backend->GetServerCerts(&certs);
|
||||||
const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
|
if (res == ResultSuccess) {
|
||||||
|
const std::vector<u8> certs_buf = SerializeServerCerts(certs);
|
||||||
ctx.WriteBuffer(certs_buf);
|
ctx.WriteBuffer(certs_buf);
|
||||||
out.certs_count = static_cast<u32>(certs->size());
|
out.certs_count = static_cast<u32>(certs.size());
|
||||||
out.certs_size = static_cast<u32>(certs_buf.size());
|
out.certs_size = static_cast<u32>(certs_buf.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -330,29 +334,32 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
void Read(HLERequestContext& ctx) {
|
void Read(HLERequestContext& ctx) {
|
||||||
const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
|
std::vector<u8> output_bytes;
|
||||||
|
const Result res = ReadImpl(&output_bytes, ctx.GetWriteBufferSize());
|
||||||
IPC::ResponseBuilder rb{ctx, 3};
|
IPC::ResponseBuilder rb{ctx, 3};
|
||||||
rb.Push(res.Code());
|
rb.Push(res);
|
||||||
if (res.Succeeded()) {
|
if (res == ResultSuccess) {
|
||||||
rb.Push(static_cast<u32>(res->size()));
|
rb.Push(static_cast<u32>(output_bytes.size()));
|
||||||
ctx.WriteBuffer(*res);
|
ctx.WriteBuffer(output_bytes);
|
||||||
} else {
|
} else {
|
||||||
rb.Push(static_cast<u32>(0));
|
rb.Push(static_cast<u32>(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Write(HLERequestContext& ctx) {
|
void Write(HLERequestContext& ctx) {
|
||||||
const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
|
size_t write_size{0};
|
||||||
|
const Result res = WriteImpl(&write_size, ctx.ReadBuffer());
|
||||||
IPC::ResponseBuilder rb{ctx, 3};
|
IPC::ResponseBuilder rb{ctx, 3};
|
||||||
rb.Push(res.Code());
|
rb.Push(res);
|
||||||
rb.Push(static_cast<u32>(res.ValueOr(0)));
|
rb.Push(static_cast<u32>(write_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Pending(HLERequestContext& ctx) {
|
void Pending(HLERequestContext& ctx) {
|
||||||
const ResultVal<s32> res = PendingImpl();
|
s32 pending_size{0};
|
||||||
|
const Result res = PendingImpl(&pending_size);
|
||||||
IPC::ResponseBuilder rb{ctx, 3};
|
IPC::ResponseBuilder rb{ctx, 3};
|
||||||
rb.Push(res.Code());
|
rb.Push(res);
|
||||||
rb.Push<s32>(res.ValueOr(0));
|
rb.Push<s32>(pending_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetSessionCacheMode(HLERequestContext& ctx) {
|
void SetSessionCacheMode(HLERequestContext& ctx) {
|
||||||
|
@ -438,13 +445,14 @@ private:
|
||||||
void CreateConnection(HLERequestContext& ctx) {
|
void CreateConnection(HLERequestContext& ctx) {
|
||||||
LOG_WARNING(Service_SSL, "called");
|
LOG_WARNING(Service_SSL, "called");
|
||||||
|
|
||||||
auto backend_res = CreateSSLConnectionBackend();
|
std::unique_ptr<SSLConnectionBackend> backend;
|
||||||
|
const Result res = CreateSSLConnectionBackend(&backend);
|
||||||
|
|
||||||
IPC::ResponseBuilder rb{ctx, 2, 0, 1};
|
IPC::ResponseBuilder rb{ctx, 2, 0, 1};
|
||||||
rb.Push(backend_res.Code());
|
rb.Push(res);
|
||||||
if (backend_res.Succeeded()) {
|
if (res == ResultSuccess) {
|
||||||
rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
|
rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
|
||||||
std::move(*backend_res));
|
std::move(backend));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,11 +35,11 @@ public:
|
||||||
virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
|
virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
|
||||||
virtual Result SetHostName(const std::string& hostname) = 0;
|
virtual Result SetHostName(const std::string& hostname) = 0;
|
||||||
virtual Result DoHandshake() = 0;
|
virtual Result DoHandshake() = 0;
|
||||||
virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
|
virtual Result Read(size_t* out_size, std::span<u8> data) = 0;
|
||||||
virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
|
virtual Result Write(size_t* out_size, std::span<const u8> data) = 0;
|
||||||
virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
|
virtual Result GetServerCerts(std::vector<std::vector<u8>>* out_certs) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
|
Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend);
|
||||||
|
|
||||||
} // namespace Service::SSL
|
} // namespace Service::SSL
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
namespace Service::SSL {
|
namespace Service::SSL {
|
||||||
|
|
||||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend) {
|
||||||
LOG_ERROR(Service_SSL,
|
LOG_ERROR(Service_SSL,
|
||||||
"Can't create SSL connection because no SSL backend is available on this platform");
|
"Can't create SSL connection because no SSL backend is available on this platform");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
|
|
|
@ -105,31 +105,30 @@ public:
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return HandleReturn("SSL_do_handshake", 0, ret).Code();
|
return HandleReturn("SSL_do_handshake", 0, ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> Read(std::span<u8> data) override {
|
Result Read(size_t* out_size, std::span<u8> data) override {
|
||||||
size_t actual;
|
const int ret = SSL_read_ex(ssl, data.data(), data.size(), out_size);
|
||||||
const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual);
|
return HandleReturn("SSL_read_ex", out_size, ret);
|
||||||
return HandleReturn("SSL_read_ex", actual, ret);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> Write(std::span<const u8> data) override {
|
Result Write(size_t* out_size, std::span<const u8> data) override {
|
||||||
size_t actual;
|
const int ret = SSL_write_ex(ssl, data.data(), data.size(), out_size);
|
||||||
const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual);
|
return HandleReturn("SSL_write_ex", out_size, ret);
|
||||||
return HandleReturn("SSL_write_ex", actual, ret);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
|
Result HandleReturn(const char* what, size_t* actual, int ret) {
|
||||||
const int ssl_err = SSL_get_error(ssl, ret);
|
const int ssl_err = SSL_get_error(ssl, ret);
|
||||||
CheckOpenSSLErrors();
|
CheckOpenSSLErrors();
|
||||||
switch (ssl_err) {
|
switch (ssl_err) {
|
||||||
case SSL_ERROR_NONE:
|
case SSL_ERROR_NONE:
|
||||||
return actual;
|
return ResultSuccess;
|
||||||
case SSL_ERROR_ZERO_RETURN:
|
case SSL_ERROR_ZERO_RETURN:
|
||||||
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
|
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
|
||||||
// DoHandshake special-cases this, but for Read and Write:
|
// DoHandshake special-cases this, but for Read and Write:
|
||||||
return size_t(0);
|
*actual = 0;
|
||||||
|
return ResultSuccess;
|
||||||
case SSL_ERROR_WANT_READ:
|
case SSL_ERROR_WANT_READ:
|
||||||
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
|
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
|
||||||
return ResultWouldBlock;
|
return ResultWouldBlock;
|
||||||
|
@ -139,20 +138,20 @@ public:
|
||||||
default:
|
default:
|
||||||
if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
|
if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
|
||||||
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
|
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
|
||||||
return size_t(0);
|
*actual = 0;
|
||||||
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
|
LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
Result GetServerCerts(std::vector<std::vector<u8>>* out_certs) override {
|
||||||
STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
|
STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
|
||||||
if (!chain) {
|
if (!chain) {
|
||||||
LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
|
LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
std::vector<std::vector<u8>> ret;
|
|
||||||
int count = sk_X509_num(chain);
|
int count = sk_X509_num(chain);
|
||||||
ASSERT(count >= 0);
|
ASSERT(count >= 0);
|
||||||
for (int i = 0; i < count; i++) {
|
for (int i = 0; i < count; i++) {
|
||||||
|
@ -161,10 +160,10 @@ public:
|
||||||
unsigned char* buf = nullptr;
|
unsigned char* buf = nullptr;
|
||||||
int len = i2d_X509(x509, &buf);
|
int len = i2d_X509(x509, &buf);
|
||||||
ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
|
ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
|
||||||
ret.emplace_back(buf, buf + len);
|
out_certs->emplace_back(buf, buf + len);
|
||||||
OPENSSL_free(buf);
|
OPENSSL_free(buf);
|
||||||
}
|
}
|
||||||
return ret;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
~SSLConnectionBackendOpenSSL() {
|
~SSLConnectionBackendOpenSSL() {
|
||||||
|
@ -253,13 +252,13 @@ public:
|
||||||
std::shared_ptr<Network::SocketBase> socket;
|
std::shared_ptr<Network::SocketBase> socket;
|
||||||
};
|
};
|
||||||
|
|
||||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend) {
|
||||||
auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
|
auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
|
||||||
const Result res = conn->Init();
|
|
||||||
if (res.IsFailure()) {
|
R_TRY(conn->Init());
|
||||||
return res;
|
|
||||||
}
|
*out_backend = std::move(conn);
|
||||||
return conn;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -299,21 +299,22 @@ public:
|
||||||
return ResultSuccess;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> Read(std::span<u8> data) override {
|
Result Read(size_t* out_size, std::span<u8> data) override {
|
||||||
|
*out_size = 0;
|
||||||
if (handshake_state != HandshakeState::Connected) {
|
if (handshake_state != HandshakeState::Connected) {
|
||||||
LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
|
LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
if (data.size() == 0 || got_read_eof) {
|
if (data.size() == 0 || got_read_eof) {
|
||||||
return size_t(0);
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
while (1) {
|
while (1) {
|
||||||
if (!cleartext_read_buf.empty()) {
|
if (!cleartext_read_buf.empty()) {
|
||||||
const size_t read_size = std::min(cleartext_read_buf.size(), data.size());
|
*out_size = std::min(cleartext_read_buf.size(), data.size());
|
||||||
std::memcpy(data.data(), cleartext_read_buf.data(), read_size);
|
std::memcpy(data.data(), cleartext_read_buf.data(), *out_size);
|
||||||
cleartext_read_buf.erase(cleartext_read_buf.begin(),
|
cleartext_read_buf.erase(cleartext_read_buf.begin(),
|
||||||
cleartext_read_buf.begin() + read_size);
|
cleartext_read_buf.begin() + *out_size);
|
||||||
return read_size;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
if (!ciphertext_read_buf.empty()) {
|
if (!ciphertext_read_buf.empty()) {
|
||||||
SecBuffer empty{
|
SecBuffer empty{
|
||||||
|
@ -366,7 +367,8 @@ public:
|
||||||
case SEC_I_CONTEXT_EXPIRED:
|
case SEC_I_CONTEXT_EXPIRED:
|
||||||
// Server hung up by sending close_notify.
|
// Server hung up by sending close_notify.
|
||||||
got_read_eof = true;
|
got_read_eof = true;
|
||||||
return size_t(0);
|
*out_size = 0;
|
||||||
|
return ResultSuccess;
|
||||||
default:
|
default:
|
||||||
LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
|
LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
|
||||||
Common::NativeErrorToString(ret));
|
Common::NativeErrorToString(ret));
|
||||||
|
@ -379,18 +381,21 @@ public:
|
||||||
}
|
}
|
||||||
if (ciphertext_read_buf.empty()) {
|
if (ciphertext_read_buf.empty()) {
|
||||||
got_read_eof = true;
|
got_read_eof = true;
|
||||||
return size_t(0);
|
*out_size = 0;
|
||||||
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> Write(std::span<const u8> data) override {
|
Result Write(size_t* out_size, std::span<const u8> data) override {
|
||||||
|
*out_size = 0;
|
||||||
|
|
||||||
if (handshake_state != HandshakeState::Connected) {
|
if (handshake_state != HandshakeState::Connected) {
|
||||||
LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
|
LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
if (data.size() == 0) {
|
if (data.size() == 0) {
|
||||||
return size_t(0);
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
|
data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
|
||||||
if (!cleartext_write_buf.empty()) {
|
if (!cleartext_write_buf.empty()) {
|
||||||
|
@ -402,7 +407,7 @@ public:
|
||||||
LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
|
LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
return WriteAlreadyEncryptedData();
|
return WriteAlreadyEncryptedData(out_size);
|
||||||
} else {
|
} else {
|
||||||
cleartext_write_buf.assign(data.begin(), data.end());
|
cleartext_write_buf.assign(data.begin(), data.end());
|
||||||
}
|
}
|
||||||
|
@ -448,21 +453,21 @@ public:
|
||||||
tmp_data_buf.end());
|
tmp_data_buf.end());
|
||||||
ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
|
ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
|
||||||
trailer_buf.end());
|
trailer_buf.end());
|
||||||
return WriteAlreadyEncryptedData();
|
return WriteAlreadyEncryptedData(out_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> WriteAlreadyEncryptedData() {
|
Result WriteAlreadyEncryptedData(size_t* out_size) {
|
||||||
const Result r = FlushCiphertextWriteBuf();
|
const Result r = FlushCiphertextWriteBuf();
|
||||||
if (r != ResultSuccess) {
|
if (r != ResultSuccess) {
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
// write buf is empty
|
// write buf is empty
|
||||||
const size_t cleartext_bytes_written = cleartext_write_buf.size();
|
*out_size = cleartext_write_buf.size();
|
||||||
cleartext_write_buf.clear();
|
cleartext_write_buf.clear();
|
||||||
return cleartext_bytes_written;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
Result GetServerCerts(std::vector<std::vector<u8>>* out_certs) override {
|
||||||
PCCERT_CONTEXT returned_cert = nullptr;
|
PCCERT_CONTEXT returned_cert = nullptr;
|
||||||
const SECURITY_STATUS ret =
|
const SECURITY_STATUS ret =
|
||||||
QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
|
QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
|
||||||
|
@ -473,16 +478,15 @@ public:
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
PCCERT_CONTEXT some_cert = nullptr;
|
PCCERT_CONTEXT some_cert = nullptr;
|
||||||
std::vector<std::vector<u8>> certs;
|
|
||||||
while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
|
while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
|
||||||
certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
|
out_certs->emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
|
||||||
static_cast<u8*>(some_cert->pbCertEncoded) +
|
static_cast<u8*>(some_cert->pbCertEncoded) +
|
||||||
some_cert->cbCertEncoded);
|
some_cert->cbCertEncoded);
|
||||||
}
|
}
|
||||||
std::reverse(certs.begin(),
|
std::reverse(out_certs->begin(),
|
||||||
certs.end()); // Windows returns certs in reverse order from what we want
|
out_certs->end()); // Windows returns certs in reverse order from what we want
|
||||||
CertFreeCertificateContext(returned_cert);
|
CertFreeCertificateContext(returned_cert);
|
||||||
return certs;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
~SSLConnectionBackendSchannel() {
|
~SSLConnectionBackendSchannel() {
|
||||||
|
@ -532,13 +536,13 @@ public:
|
||||||
size_t read_buf_fill_size = 0;
|
size_t read_buf_fill_size = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend) {
|
||||||
auto conn = std::make_unique<SSLConnectionBackendSchannel>();
|
auto conn = std::make_unique<SSLConnectionBackendSchannel>();
|
||||||
const Result res = conn->Init();
|
|
||||||
if (res.IsFailure()) {
|
R_TRY(conn->Init());
|
||||||
return res;
|
|
||||||
}
|
*out_backend = std::move(conn);
|
||||||
return conn;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace Service::SSL
|
} // namespace Service::SSL
|
||||||
|
|
|
@ -103,24 +103,20 @@ public:
|
||||||
return HandleReturn("SSLHandshake", 0, status).Code();
|
return HandleReturn("SSLHandshake", 0, status).Code();
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> Read(std::span<u8> data) override {
|
Result Read(size_t* out_size, std::span<u8> data) override {
|
||||||
size_t actual;
|
OSStatus status = SSLRead(context, data.data(), data.size(), &out_size);
|
||||||
OSStatus status = SSLRead(context, data.data(), data.size(), &actual);
|
return HandleReturn("SSLRead", out_size, status);
|
||||||
;
|
|
||||||
return HandleReturn("SSLRead", actual, status);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> Write(std::span<const u8> data) override {
|
Result Write(size_t* out_size, std::span<const u8> data) override {
|
||||||
size_t actual;
|
OSStatus status = SSLWrite(context, data.data(), data.size(), &out_size);
|
||||||
OSStatus status = SSLWrite(context, data.data(), data.size(), &actual);
|
return HandleReturn("SSLWrite", out_size, status);
|
||||||
;
|
|
||||||
return HandleReturn("SSLWrite", actual, status);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) {
|
Result HandleReturn(const char* what, size_t* actual, OSStatus status) {
|
||||||
switch (status) {
|
switch (status) {
|
||||||
case 0:
|
case 0:
|
||||||
return actual;
|
return ResultSuccess;
|
||||||
case errSSLWouldBlock:
|
case errSSLWouldBlock:
|
||||||
return ResultWouldBlock;
|
return ResultWouldBlock;
|
||||||
default: {
|
default: {
|
||||||
|
@ -136,22 +132,21 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
Result GetServerCerts(std::vector<std::vector<u8>>* out_certs) override {
|
||||||
CFReleaser<SecTrustRef> trust;
|
CFReleaser<SecTrustRef> trust;
|
||||||
OSStatus status = SSLCopyPeerTrust(context, &trust.ptr);
|
OSStatus status = SSLCopyPeerTrust(context, &trust.ptr);
|
||||||
if (status) {
|
if (status) {
|
||||||
LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status));
|
LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status));
|
||||||
return ResultInternalError;
|
return ResultInternalError;
|
||||||
}
|
}
|
||||||
std::vector<std::vector<u8>> ret;
|
|
||||||
for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) {
|
for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) {
|
||||||
SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i);
|
SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i);
|
||||||
CFReleaser<CFDataRef> data(SecCertificateCopyData(cert));
|
CFReleaser<CFDataRef> data(SecCertificateCopyData(cert));
|
||||||
ASSERT_OR_EXECUTE(data, { return ResultInternalError; });
|
ASSERT_OR_EXECUTE(data, { return ResultInternalError; });
|
||||||
const u8* ptr = CFDataGetBytePtr(data);
|
const u8* ptr = CFDataGetBytePtr(data);
|
||||||
ret.emplace_back(ptr, ptr + CFDataGetLength(data));
|
out_certs->emplace_back(ptr, ptr + CFDataGetLength(data));
|
||||||
}
|
}
|
||||||
return ret;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
|
static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
|
||||||
|
@ -210,13 +205,13 @@ private:
|
||||||
std::shared_ptr<Network::SocketBase> socket;
|
std::shared_ptr<Network::SocketBase> socket;
|
||||||
};
|
};
|
||||||
|
|
||||||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
Result CreateSSLConnectionBackend(std::unique_ptr<SSLConnectionBackend>* out_backend) {
|
||||||
auto conn = std::make_unique<SSLConnectionBackendSecureTransport>();
|
auto conn = std::make_unique<SSLConnectionBackendSecureTransport>();
|
||||||
const Result res = conn->Init();
|
|
||||||
if (res.IsFailure()) {
|
R_TRY(conn->Init());
|
||||||
return res;
|
|
||||||
}
|
*out_backend = std::move(conn);
|
||||||
return conn;
|
return ResultSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace Service::SSL
|
} // namespace Service::SSL
|
||||||
|
|
Loading…
Reference in a new issue