Commit 271004ae authored by jan.koester's avatar jan.koester
Browse files
parents 1fc759cc a3d27b0a
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -225,7 +225,8 @@ namespace netplus {
    }


    rsa::rsa(){
    rsa::rsa() : n(64), e(64), d(64) {
        // Explicitly initialize bigInt members with capacity 64
    }

    // Copy constructor
+29 −3
Original line number Diff line number Diff line
@@ -263,19 +263,24 @@ namespace netplus {
                try {
                    // Peer closed?
                    if (events & (EPOLLRDHUP | EPOLLHUP | EPOLLERR)) {
                        std::cerr << "[EPOLL] Peer closed: events=" << events << std::endl;
                        needClose = true;
                    }
                    else {

                        // 1) TLS handshake
                        if (!c->csock->getHandshakeDone()) {
                        // 1) TLS handshake - loop until no more buffered data
                        while (!c->csock->getHandshakeDone()) {
                            std::cerr << "[EPOLL] Calling handshake_after_accept()" << std::endl;
                            c->csock->handshake_after_accept();
                            std::cerr << "[EPOLL] handshake_after_accept() returned" << std::endl;

                            if (c->csock->hasPendingWrite()) {
                                std::cerr << "[EPOLL] Flushing pending write data" << std::endl;
                                try {
                                    c->csock->flush_out();
                                } catch (NetException& e) {
                                    if (e.getErrorType() == NetException::Note) {
                                        std::cerr << "[EPOLL] flush_out threw Note" << std::endl;
                                        setpollEventsFd(fd,EPOLLOUT | EPOLLRDHUP | EPOLLONESHOT);
                                    } else {
                                        throw;
@@ -285,7 +290,18 @@ namespace netplus {

                            if (!c->csock->getHandshakeDone()) {
                                // still waiting
                                return;
                                // Check if there's buffered data to process first!
                                if (!c->csock->hasBufferedData()) {
                                    // No buffered data, re-arm socket for more events
                                    std::cerr << "[EPOLL] Handshake not done, re-arming socket for more events" << std::endl;
                                    int ev = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
                                    if (c->csock->hasPendingWrite())
                                        ev |= EPOLLOUT;
                                    setpollEventsFd(fd, ev);
                                    break; // Exit handshake loop, return to event handler
                                }
                                // else: has buffered data, loop again to process it
                                std::cerr << "[EPOLL] Buffered data available, continuing handshake" << std::endl;
                            }
                        }

@@ -298,6 +314,7 @@ namespace netplus {
                                rcv = c->csock->recvData(buf, 0);
                            } catch (NetException& e) {
                                if (e.getErrorType() == NetException::Note){
                                    rearm.disarm();
                                    setpollEventsFd(fd,EPOLLIN | EPOLLRDHUP | EPOLLONESHOT);
                                    return;
                                }
@@ -320,6 +337,7 @@ namespace netplus {
                                    c->csock->flush_out();
                                } catch (NetException& e) {
                                    if (e.getErrorType() == NetException::Note){
                                        rearm.disarm();
                                        setpollEventsFd(fd,EPOLLOUT | EPOLLRDHUP | EPOLLONESHOT);
                                        return;
                                    }
@@ -340,6 +358,7 @@ namespace netplus {

                                } catch (NetException& e) {
                                    if (e.getErrorType() == NetException::Note){
                                        rearm.disarm();
                                        setpollEventsFd(fd,EPOLLOUT | EPOLLRDHUP | EPOLLONESHOT);
                                        return;
                                    }
@@ -358,21 +377,28 @@ namespace netplus {
                    }

                } catch (NetException& e) {
                    std::cerr << "[EPOLL] Caught NetException: type=" << e.getErrorType() 
                              << " msg=" << e.what() << std::endl;
                    if (e.getErrorType() == NetException::Note){
                        // Only set EPOLLOUT if there's actually pending write data
                        std::cerr << "[EPOLL] It's a Note, re-arming socket" << std::endl;
                        int ev = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
                        if (c->csock->hasPendingWrite() || !c->SendData.empty())
                            ev |= EPOLLOUT;
                        rearm.disarm();
                        setpollEventsFd(fd, ev);
                        return;
                    }
                    std::cerr << "[EPOLL] Exception is not Note, will close" << std::endl;
                    needClose = true;
                } catch (...) {
                    std::cerr << "[EPOLL] Caught unknown exception" << std::endl;
                    needClose = true;
                }
            } // unlock event_mutex

            if (needClose) {
                std::cerr << "[EPOLL] Closing connection" << std::endl;
                rearm.disarm();
                CloseEventHandler(fd, tid, args);
            }
+8 −0
Original line number Diff line number Diff line
@@ -181,6 +181,7 @@ namespace netplus {
		virtual bool hasPendingWrite() const { return false; }
		virtual void setPendingWrite(bool pending) {}  // Only used for SSL/IOCP
		virtual bool getHandshakeDone() { return true; }
		virtual bool hasBufferedData() const { return false; }
		virtual int  getSocketType() const { return _Type; }

		virtual void connect(const std::string& addr, int port, bool nonblock = false) = 0;
@@ -359,6 +360,9 @@ namespace netplus {
		// Public accessor for certificate (used by event loop for SSL accept)
		std::shared_ptr<netplus::x509cert> getCert() const { return _cert; }

		// Check if there's buffered data waiting to be processed
		bool hasBufferedData() const { return !_rx_tcp_buf.empty(); }

	private:
		// --- crypto helpers ---
		std::vector<uint8_t> _sha1_hash(const std::vector<uint8_t>& input);
@@ -512,12 +516,15 @@ namespace netplus {
		netplus::x509cert _peer_cert;
		std::shared_ptr<netplus::x509cert> _cert = nullptr;
		netplus::rsa _rsa;
		netplus::rsa _peer_rsa;  // Server's RSA public key (client mode)
		uint8_t _ec_priv[32] = {0};  // ECC P-256 private key (Big-Endian)
		bool _has_ec_key = false;
		std::string _hostname;
		bool     _secure_reneg = true;
		uint16_t _chosenSuite = 0x002F;
		bool     _ccs_received = false;
		bool     _cli_cke_sent = false;  // TLS 1.2 client: CKE already sent
		bool     _cli_ccs_fin_sent = false;  // TLS 1.2 client: CCS+Finished sent

		std::unique_ptr<netplus::aes> _aes = nullptr;
		std::unique_ptr<netplus::aes> _aes_recv = nullptr;
@@ -664,6 +671,7 @@ namespace netplus {
		HsState _hs_state = HsState::READ_CLIENT_HELLO;

		std::vector<uint8_t> _masterSecret;
		std::vector<uint8_t> _tls12_transcript_before_client_finished;  // Saved for server Finished PRF
		std::vector<uint8_t> _handshake_transcript;
		std::vector<uint8_t> _clientHelloRawBytes;

+116 −60
Original line number Diff line number Diff line
@@ -351,37 +351,9 @@ netplus::ssl::ssl(std::shared_ptr<netplus::x509cert> cert, int sock) :

std::vector<uint8_t> netplus::ssl::readTlsRecordAsync()
{
    auto throwSSL = [&](int type, const std::string& msg) -> void {
        NetException e;
        e[type] << "ssl::readTlsRecordAsync: " << msg;
        throw e;
    };

    // -------- --------------------------------------------------------
    // 1) ensure we have at least TLS record header (5 bytes)
    // ----------------------------------------------------------------
#ifdef Windows
    while (_rx_tcp_buf.size() < 5) {
        buffer tmpBuf(4096);
        try {
            size_t r = tcp::recvData(tmpBuf, 0);
            if (r > 0) {
                _rx_tcp_buf.insert(_rx_tcp_buf.end(), tmpBuf.data.buf, tmpBuf.data.buf + r);
                continue;
            }
            NetException n;
            n[NetException::Note] << "ssl: record incomplete (need header)";
            throw n;
        } catch (NetException& e) {
            if (e.getErrorType() == NetException::Note) {
                NetException n;
                n[NetException::Note] << "ssl: record incomplete (need header)";
                throw n;
            }
            throw;
        }
    }
#else
    while (_rx_tcp_buf.size() < 5) {
        buffer tmpBuf(4096);
        try {
@@ -400,7 +372,6 @@ std::vector<uint8_t> netplus::ssl::readTlsRecordAsync()
            throw;
        }
    }
#endif

    // ------------------------------------------------------------
    // 2) parse record length
@@ -410,28 +381,6 @@ std::vector<uint8_t> netplus::ssl::readTlsRecordAsync()
    // ------------------------------------------------------------
    // 3) read until full record present
    // ------------------------------------------------------------
#ifdef Windows
    while (_rx_tcp_buf.size() < total) {
        buffer tmpBuf(4096);
        try {
            size_t r = tcp::recvData(tmpBuf, 0);
            if (r > 0) {
                _rx_tcp_buf.insert(_rx_tcp_buf.end(), tmpBuf.data.buf, tmpBuf.data.buf + r);
                continue;
            }
            NetException n;
            n[NetException::Note] << "ssl: record incomplete";
            throw n;
        } catch (NetException& e) {
            if (e.getErrorType() == NetException::Note) {
                NetException n;
                n[NetException::Note] << "ssl: record incomplete";
                throw n;
            }
            throw;
        }
    }
#else
    while (_rx_tcp_buf.size() < total) {
        buffer tmpBuf(4096);
        try {
@@ -450,7 +399,6 @@ std::vector<uint8_t> netplus::ssl::readTlsRecordAsync()
            throw;
        }
    }
#endif

    // ------------------------------------------------------------
    // 4) extract full record
@@ -1611,6 +1559,9 @@ void netplus::ssl::_tls13_send_finished(bool handshake_keys){

void netplus::ssl::handshake_after_accept(){

    std::cerr << "[SSL] ===== handshake_after_accept ENTER state=" << (int)_hs_state << std::endl;
    std::cerr.flush();

    auto throwSSL = [&](int type, const std::string& msg) {
        netplus::NetException e;
        e[type] << "ssl::accept: " << msg;
@@ -1680,6 +1631,7 @@ void netplus::ssl::handshake_after_accept(){
    // ------------------------------------------------------------

    // Run as far as possible in one call, until IO would block (Note thrown)
    try {
    for (;;) {
        std::cerr << "[SSL] handshake_after_accept loop: state=" << (int)_hs_state << std::endl;
        std::cerr.flush();
@@ -2179,7 +2131,12 @@ void netplus::ssl::handshake_after_accept(){
        // WAIT_CCS
        // ============================================================
        case HsState::WAIT_CCS: {
            std::cerr << "[SSL] WAIT_CCS: waiting for ChangeCipherSpec record" << std::endl;
            std::cerr.flush();
            
            if (_ccs_received) {
                std::cerr << "[SSL] WAIT_CCS: CCS already received, transitioning to WAIT_FIN" << std::endl;
                std::cerr.flush();
                _recv_seq = 0;
                _hs_state = HsState::WAIT_FIN;
                continue; // direkt WAIT_FIN probieren
@@ -2190,17 +2147,23 @@ void netplus::ssl::handshake_after_accept(){
                rec = readTlsRecordAsync();
            } catch (netplus::NetException& e) {
                if (e.getErrorType() == netplus::NetException::Note) {
                    TLSDBG("readTlsRecordAsync NOTE -> would block (wait EPOLLIN)");
                    std::cerr << "[SSL] WAIT_CCS: No data yet, returning" << std::endl;
                    std::cerr.flush();
                    return;
                }
                TLSDBG("readTlsRecordAsync ERROR");
                throw;
            }

            std::cerr << "[SSL] WAIT_CCS: Got record, size=" << rec.size() << " type=" << (int)rec[0] << std::endl;
            std::cerr.flush();

            if (rec.size() != 6 || rec[0] != 0x14 || rec[5] != 0x01)
                throwSSL(NetException::Error, "Expected CCS");

            std::cerr << "[SSL] WAIT_CCS: CCS received correctly, transitioning to WAIT_FIN" << std::endl;
            std::cerr.flush();

            _ccs_received = true;   // bleibt true bis WAIT_FIN erfolgreich war
            _recv_seq = 0;
            _hs_state = HsState::WAIT_FIN;
@@ -2240,10 +2203,13 @@ void netplus::ssl::handshake_after_accept(){
                throwSSL(NetException::Error, "client Finished verify_data mismatch");
            }

            // ✅ add client's Finished handshake message to transcript
            // ✅ Add client's Finished handshake message to transcript
            _handshake_transcript.insert(_handshake_transcript.end(),
                                        finPT.begin(), finPT.begin() + 4 + 12);

            // ✅ Save transcript AFTER adding client Finished (for server Finished PRF)
            _tls12_transcript_before_client_finished = _handshake_transcript;

            _hs_state = HsState::SEND_CCS_FIN;
            continue;  // process SEND_CCS_FIN immediately
        }
@@ -2260,8 +2226,9 @@ void netplus::ssl::handshake_after_accept(){

                _send_seq = 0; // ok - begin new crypto epoch after CCS

                std::vector<uint8_t> th2 = sha256_hash(_handshake_transcript);
                std::vector<uint8_t> verifyServer = _prf(_masterSecret, "server finished", th2, 12);
                // ✅ Use transcript after receiving all handshake messages (including client Finished) for server Finished PRF
                std::vector<uint8_t> th_server = sha256_hash(_tls12_transcript_before_client_finished);
                std::vector<uint8_t> verifyServer = _prf(_masterSecret, "server finished", th_server, 12);

                std::vector<uint8_t> sFin;
                sFin.reserve(4 + 12);
@@ -2297,9 +2264,19 @@ void netplus::ssl::handshake_after_accept(){
        }

        case HsState::WAIT_CKE: {
            std::cerr << "[SSL] WAIT_CKE: calling _fetchNextHandshakePlain()" << std::endl;
            std::cerr.flush();
            
            // ✅ Always read full handshake message (handles fragmentation)
            std::vector<uint8_t> msg = _fetchNextHandshakePlain();
            if (msg.empty()) return;  // would block
            std::cerr << "[SSL] WAIT_CKE: _fetchNextHandshakePlain returned " << msg.size() << " bytes" << std::endl;
            std::cerr.flush();
            
            if (msg.empty()) {
                std::cerr << "[SSL] WAIT_CKE: No data yet, returning" << std::endl;
                std::cerr.flush();
                return;  // would block
            }

            // msg = [type][len24][body...]
            if (msg.size() < 4)
@@ -2334,22 +2311,37 @@ void netplus::ssl::handshake_after_accept(){
                    " kBytes=" + std::to_string(kBytes));
            }

            std::cerr << "[SSL] WAIT_CKE: Decrypting RSA ciphertext (" << encLen << " bytes)..." << std::endl;
            std::cerr.flush();
            
            rsa::bigInt cipher = rsa::bigIntFromBytesBE(msg.data() + off, encLen);
            rsa::bigInt plainBI = _rsa.decrypt(cipher);

            std::cerr << "[SSL] WAIT_CKE: RSA decryption successful" << std::endl;
            std::cerr.flush();

            std::vector<uint8_t> pkcs1 = rsa::bigIntToBytesBE(plainBI, kBytes);
            std::vector<uint8_t> preMaster = extractPreMasterFromPkcs1(pkcs1);

            std::cerr << "[SSL] WAIT_CKE: Extracted premaster (" << preMaster.size() << " bytes)" << std::endl;
            std::cerr.flush();

            // master_secret = PRF(PMS, "master secret", client_random || server_random)
            std::vector<uint8_t> msSeed = _clientRandom;
            msSeed.insert(msSeed.end(), _serverRandom.begin(), _serverRandom.end());
            _masterSecret = _prf(preMaster, "master secret", msSeed, 48);

            std::cerr << "[SSL] WAIT_CKE: Master secret derived" << std::endl;
            std::cerr.flush();

            // key_block = PRF(master, "key expansion", server_random || client_random)
            std::vector<uint8_t> kbSeed = _serverRandom;
            kbSeed.insert(kbSeed.end(), _clientRandom.begin(), _clientRandom.end());
            std::vector<uint8_t> keyBlock = _prf(_masterSecret, "key expansion", kbSeed, 72);

            std::cerr << "[SSL] WAIT_CKE: Key block derived" << std::endl;
            std::cerr.flush();

            size_t k = 0;
            _client_mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20); k += 20;
            _mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20);        k += 20;
@@ -2360,6 +2352,9 @@ void netplus::ssl::handshake_after_accept(){
            _aes_recv = std::make_unique<aes>(clientKey);
            _aes      = std::make_unique<aes>(serverKey);

            std::cerr << "[SSL] WAIT_CKE: AES ciphers initialized, transitioning to WAIT_CCS" << std::endl;
            std::cerr.flush();

            _hs_state = HsState::WAIT_CCS;
            return;
        }
@@ -2624,6 +2619,28 @@ void netplus::ssl::handshake_after_accept(){
            throwSSL(netplus::NetException::Error, "invalid handshake state");
        }
    }
    } catch (NetException& e) {
        if (e.getErrorType() == NetException::Note) {
            std::cerr << "[SSL] ===== handshake_after_accept EXIT with Note (would block)" << std::endl;
            std::cerr.flush();
            throw; // re-throw Note (wait for IO)
        }
        std::cerr << "[SSL] ===== handshake_after_accept EXIT with NetException: type=" << e.getErrorType() 
                  << " msg=" << e.what() << std::endl;
        std::cerr.flush();
        throw;
    } catch (std::exception& e) {
        std::cerr << "[SSL] ===== handshake_after_accept EXIT with std::exception: " << e.what() << std::endl;
        std::cerr.flush();
        throw;
    } catch (...) {
        std::cerr << "[SSL] ===== handshake_after_accept EXIT with unknown exception" << std::endl;
        std::cerr.flush();
        throw;
    }
    
    std::cerr << "[SSL] ===== handshake_after_accept EXIT OK state=" << (int)_hs_state << " done=" << _handshakeDone << std::endl;
    std::cerr.flush();
}

void netplus::ssl::_tls13_insert_message_hash_if_needed(){
@@ -3564,9 +3581,48 @@ void netplus::ssl::handshake_after_connect(){
                    throw e;
                }

                // parse cert chain (your existing code)
                // -> _peer_cert = ...
                // -> extract pubkey, etc.
                // Parse Certificate message:
                // [0] = type (0x0b)
                // [1..3] = length24 of entire body
                // [4..6] = certificates_list_length (3 bytes)
                // Then for each certificate:
                //   [3 bytes] = certificate_length
                //   [cert_len bytes] = DER certificate
                if (certMsg.size() < 7) {
                    NetException e;
                    e[NetException::Error] << "Certificate message too short";
                    throw e;
                }

                size_t off = 4; // skip handshake header
                uint32_t listLen = (uint32_t(certMsg[off]) << 16) | (uint32_t(certMsg[off+1]) << 8) | uint32_t(certMsg[off+2]);
                off += 3;

                if (off + listLen > certMsg.size()) {
                    NetException e;
                    e[NetException::Error] << "Certificate list truncated";
                    throw e;
                }

                // Parse first certificate (server's leaf cert)
                if (listLen >= 3) {
                    uint32_t certLen = (uint32_t(certMsg[off]) << 16) | (uint32_t(certMsg[off+1]) << 8) | uint32_t(certMsg[off+2]);
                    off += 3;

                    if (off + certLen <= certMsg.size()) {
                        std::vector<uint8_t> derCert(certMsg.begin() + off, certMsg.begin() + off + certLen);
                        _peer_cert.loadFromBuffer(derCert);

                        // Extract RSA public key from certificate
                        netplus::rsa peerRsa;
                        if (!_peer_cert.extractPublicKey(peerRsa)) {
                            NetException e;
                            e[NetException::Error] << "Failed to extract RSA public key from server certificate";
                            throw e;
                        }
                        _peer_rsa = peerRsa;
                    }
                }

                _hs_state = HsState::CLI_WAIT_SHD;
                break;