Commit 96cdcffc authored by jan.koester's avatar jan.koester
Browse files

added sni support

parent 7e7f40a3
Loading
Loading
Loading
Loading
+11 −5
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@
#include <string>
#include <memory>
#include <deque>
#include <map>
#include <cstring>
#include <algorithm>

@@ -307,20 +308,20 @@ namespace netplus {
	class ssl : public tcp {
	public:
		
		struct CertvificateBundle {
		struct CertificateBundle {
			netplus::x509cert    cert;
			std::vector<uint8_t> privateKeyDer;
		};

		using tcp::operator=;

		ssl(const netplus::x509cert& cert);
		ssl(const netplus::x509cert& cert, int sock);
		ssl(const std::map<std::string, CertificateBundle>& certs);
		ssl(const std::map<std::string, CertificateBundle>& certs, int sock);
		ssl(std::shared_ptr<netplus::x509cert> cert, int sock);

		virtual ~ssl() = default;

		ssl(const netplus::x509cert& cert,const std::string& addr, int port, int maxconnections, int sockopts);
		ssl(const std::map<std::string, CertificateBundle>& certs, const std::string& addr, int port, int maxconnections, int sockopts);

		void accept(std::unique_ptr<socket>& csock,bool nonblock) override;
		void handshake_after_accept() override;
@@ -536,6 +537,11 @@ namespace netplus {
		bool     _cli_cke_sent = false;  // TLS 1.2 client: CKE already sent
		bool     _cli_ccs_fin_sent = false;  // TLS 1.2 client: CCS+Finished sent

		// --- SNI (Server Name Indication) Support ---
		std::map<std::string, CertificateBundle> _cert_map;  // domain -> CertificateBundle mapping
		std::string _requested_hostname;  // Hostname from client SNI extension
		CertificateBundle* _selected_cert_bundle = nullptr;  // Pointer to selected bundle from map

		std::unique_ptr<netplus::aes> _aes = nullptr;
		std::unique_ptr<netplus::aes> _aes_recv = nullptr;

@@ -643,7 +649,7 @@ namespace netplus {
        uint8_t 						  _client_priv_ecdhe[32] = {0};       // client ephemeral scalar (BE)
        uint8_t 						  _server_priv_ecdhe[32] = {0};
		std::vector<uint8_t> _ecdhe_shared;     // 32 bytes shared secret
// server ephemeral scalar (BE)

		enum class HsState {
			// server
			READ_CLIENT_HELLO,
+126 −6
Original line number Diff line number Diff line
@@ -322,30 +322,134 @@ namespace netplus {

};

netplus::ssl::ssl(const netplus::x509cert &cert) :
// Hilfsfunktion zum Parsen von SNI (server_name) Extension aus ClientHello
static bool extractSNIFromClientHello(const std::vector<uint8_t>& ch, std::string& out_hostname) {
    if (ch.size() < 44) return false;  // Minimum size for ClientHello
    if (ch[0] != 0x01) return false;   // Type must be ClientHello

    auto readU16 = [&](size_t& p) -> uint16_t {
        if (p + 2 > ch.size()) return 0;
        uint16_t x = (uint16_t(ch[p]) << 8) | ch[p+1];
        p += 2;
        return x;
    };
    auto readU8 = [&](size_t& p) -> uint8_t {
        if (p + 1 > ch.size()) return 0;
        return ch[p++];
    };

    size_t p = 4;  // Skip type (1) and length (3)

    // legacy_version (2 bytes)
    (void)readU16(p);

    // random (32 bytes)
    if (p + 32 > ch.size()) return false;
    p += 32;

    // session_id
    uint8_t sidLen = readU8(p);
    if (p + sidLen > ch.size()) return false;
    p += sidLen;

    // cipher_suites
    uint16_t csLen = readU16(p);
    if (csLen < 2 || (csLen % 2) != 0) return false;
    if (p + csLen > ch.size()) return false;
    p += csLen;

    // compression_methods
    uint8_t compLen = readU8(p);
    if (compLen < 1) return false;
    if (p + compLen > ch.size()) return false;
    p += compLen;

    // extensions
    if (p + 2 > ch.size()) return false;
    uint16_t extLen = readU16(p);
    if (p + extLen > ch.size()) return false;

    size_t eend = p + extLen;

    while (p + 4 <= eend) {
        uint16_t et = (uint16_t(ch[p]) << 8) | ch[p+1];
        uint16_t el = (uint16_t(ch[p+2]) << 8) | ch[p+3];
        p += 4;

        if (p + el > eend) return false;

        // server_name extension (type 0x0000)
        if (et == 0x0000) {
            size_t ep = p;
            size_t eend_ext = p + el;

            // server_name_list_length
            if (ep + 2 > eend_ext) return false;
            uint16_t snLen = (uint16_t(ch[ep]) << 8) | ch[ep+1];
            ep += 2;

            if (ep + snLen > eend_ext) return false;

            // server_name (first entry)
            if (ep + 3 > eend_ext) return false;
            uint8_t nameType = ch[ep];  // should be 0 for host_name
            ep++;

            uint16_t nameLen = (uint16_t(ch[ep]) << 8) | ch[ep+1];
            ep += 2;

            if (nameType == 0 && ep + nameLen <= eend_ext) {
                out_hostname.assign((const char*)ch.data() + ep, nameLen);
                return true;
            }
            return false;
        }

        p += el;
    }

    return false;  // SNI not found
}

netplus::ssl::ssl(const std::map<std::string, CertificateBundle>& certs) :
    tcp(),
    _cert_map(certs),
    _aes(nullptr),
    _handshakeDone(false)
{
    _cert=std::make_shared<x509cert >(cert);
    // Wähle das erste Zertifikat als Standard aus
    if (!certs.empty()) {
        _selected_cert_bundle = const_cast<CertificateBundle*>(&certs.begin()->second);
        _cert = std::make_shared<x509cert>(_selected_cert_bundle->cert);
    }
    _Type=sockettype::SSL;
}

netplus::ssl::ssl(const netplus::x509cert &cert,int sock) :
netplus::ssl::ssl(const std::map<std::string, CertificateBundle>& certs, int sock) :
    tcp(sock),
    _cert_map(certs),
    _aes(nullptr),
    _handshakeDone(false)
{
    _cert=std::make_shared<x509cert >(cert);
    // Wähle das erste Zertifikat als Standard aus
    if (!certs.empty()) {
        _selected_cert_bundle = const_cast<CertificateBundle*>(&certs.begin()->second);
        _cert = std::make_shared<x509cert>(_selected_cert_bundle->cert);
    }
    _Type=sockettype::SSL;
};

netplus::ssl::ssl(const netplus::x509cert& cert,const std::string &addr,int port,int maxconnections,int sockopts) :
netplus::ssl::ssl(const std::map<std::string, CertificateBundle>& certs, const std::string &addr, int port, int maxconnections, int sockopts) :
    tcp(addr,port,maxconnections,sockopts),
    _cert_map(certs),
    _aes(nullptr),
    _handshakeDone(false)
{
     _cert=std::make_shared<x509cert >(cert);
    // Wähle das erste Zertifikat als Standard aus
    if (!certs.empty()) {
        _selected_cert_bundle = const_cast<CertificateBundle*>(&certs.begin()->second);
        _cert = std::make_shared<x509cert>(_selected_cert_bundle->cert);
    }
    _Type=sockettype::SSL;
}

@@ -1736,6 +1840,22 @@ void netplus::ssl::handshake_after_accept(){
            // ✅ CRITICAL: Save raw ClientHello bytes for transcript hash
            _clientHelloRawBytes = ch;

            // ✅ NEW: Versuche SNI zu extrahieren und das richtige Zertifikat auszuwählen
            std::string sni_hostname;
            if (extractSNIFromClientHello(ch, sni_hostname)) {
                _requested_hostname = sni_hostname;
                // Suche das Zertifikat in der Map
                auto it = _cert_map.find(sni_hostname);
                if (it != _cert_map.end()) {
                    _selected_cert_bundle = const_cast<CertificateBundle*>(&it->second);
                    _cert = std::make_shared<x509cert>(_selected_cert_bundle->cert);
                    SSL_LOG("[SNI] Selected cert for domain: " << sni_hostname);
                } else {
                    // Fallback: Verwende das erste Zertifikat
                    SSL_LOG("[SNI] No cert found for domain: " << sni_hostname << ", using default");
                }
            }

            auto readU16 = [&](size_t& p) -> uint16_t {
                if (p + 2 > ch.size()) throwSSL(NetException::Error, "parse underrun u16");
                uint16_t x = (uint16_t(ch[p]) << 8) | ch[p+1];
+11 −2
Original line number Diff line number Diff line
@@ -67,10 +67,19 @@ int main() {
        std::cout << "Certificate loaded successfully" << std::endl;
        std::cout.flush();

        // Create SSL socket listening on 127.0.0.1:8443
        // Create SSL socket listening on 127.0.0.1:8443 with SNI support
        std::cout << "Creating SSL socket..." << std::endl;
        std::cout.flush();
        ssl serverSock(cert, "127.0.0.1", 8443, 1024, -1);
        
        // Erstelle Zertifikat-Map für SNI
        std::map<std::string, netplus::ssl::CertificateBundle> certs;
        netplus::ssl::CertificateBundle bundle;
        bundle.cert = cert;
        bundle.privateKeyDer = {};
        certs["localhost"] = bundle;
        certs["127.0.0.1"] = bundle;
        
        ssl serverSock(certs, "127.0.0.1", 8443, 1024, -1);
        std::cout << "SSL socket created successfully" << std::endl;
        std::cout.flush();
        
+8 −3
Original line number Diff line number Diff line
@@ -150,10 +150,15 @@ private:

int main() {
    try {
        // du brauchst irgendeine ssl instanz.
        // Wenn dein ssl ctor ein cert braucht: erzeuge dummy cert etc.
        // Erstelle eine SSL-Instanz mit einem Zertifikat-Map für SNI-Support
        netplus::x509cert cert;
        netplus::ssl ssl(cert);
        std::map<std::string, netplus::ssl::CertificateBundle> certs;
        netplus::ssl::CertificateBundle bundle;
        bundle.cert = cert;
        bundle.privateKeyDer = {};
        certs["localhost"] = bundle;
        
        netplus::ssl ssl(certs);

        netplus::TestSSL T(ssl);
        T.run_tls13_record_loopback();