Commit cdd1ca75 authored by Jan Köster's avatar Jan Köster
Browse files

edchsa added

parent c6f990a9
Loading
Loading
Loading
Loading
+231 −108
Original line number Diff line number Diff line
@@ -2122,9 +2122,14 @@ namespace netplus {
                        if (off + certLen <= certMsg.size()) {
                            std::vector<uint8_t> derCert(certMsg.begin() + off, certMsg.begin() + off + certLen);

                            // Extract RSA public key from server certificate
                            // Detect certificate key type and extract public key
                            netplus::x509cert serverCert;
                            serverCert.loadFromBuffer(derCert);
                            if (serverCert.isECKey()) {
                                server_cert_is_ecdsa = true;
                                // EC public key will be used in ECDHE; no per-cert key needed here
                            } else {
                                server_cert_is_ecdsa = false;
                                if (!serverCert.extractPublicKey(peer_rsa)) {
                                    NetException e;
                                    e[NetException::Error] << "Failed to extract RSA public key from server certificate";
@@ -2132,6 +2137,62 @@ namespace netplus {
                                }
                            }
                        }
                    }

                    hs_state = server_cert_is_ecdsa ? HsState::CLI_TLS12_WAIT_SKE : HsState::CLI_WAIT_SHD;
                    break;
                }

                // ---------------------------
                // 3b) TLS 1.2 ECDHE: parse ServerKeyExchange
                // ---------------------------
                case HsState::CLI_TLS12_WAIT_SKE:
                {
                    std::vector<uint8_t> msg;
                    try {
                        msg = fetchNextHandshakePlain();
                    } catch (NetException& e) {
                        if (e.getErrorType() == NetException::Note)
                            note("ssl::handshake_after_connect: wait SKE");
                        throw;
                    }

                    if (msg.empty())
                        note("ssl::handshake_after_connect: wait SKE");

                    if (msg[0] == 0x0e) {
                        // ServerHelloDone without ServerKeyExchange — shouldn't happen for ECDHE
                        // but accept it gracefully; transition to CKE with empty server key
                        hs_state = HsState::CLI_SEND_CKE;
                        break;
                    }

                    if (msg[0] != 0x0c) {
                        NetException e;
                        e[NetException::Error] << "expected ServerKeyExchange (0x0c) got type=" << int(msg[0]);
                        throw e;
                    }

                    // Parse ServerKeyExchange body (skip 4-byte handshake header)
                    // Format: curve_type(1) + named_curve(2) + pubkey_len(1) + pubkey(pubkey_len)
                    //         + hash_alg(1) + sig_alg(1) + sig_len(2) + sig(sig_len)
                    if (msg.size() < 4 + 4) {
                        NetException e;
                        e[NetException::Error] << "ServerKeyExchange too short";
                        throw e;
                    }
                    size_t off = 4; // skip handshake header
                    uint8_t curve_type = msg[off++]; // 0x03 = named_curve
                    (void)curve_type;
                    off += 2; // skip named_curve (0x0017 = secp256r1)
                    uint8_t pubkey_len = msg[off++];
                    if (off + pubkey_len > msg.size()) {
                        NetException e;
                        e[NetException::Error] << "ServerKeyExchange pubkey truncated";
                        throw e;
                    }
                    tls12_srv_ecdhe_pub.assign(msg.begin() + off, msg.begin() + off + pubkey_len);
                    // Signature is after the key — skip verification (like cert verification)

                    hs_state = HsState::CLI_WAIT_SHD;
                    break;
@@ -2170,13 +2231,43 @@ namespace netplus {
                case HsState::CLI_SEND_CKE:
                {
                    if (!cli_cke_sent) {
                        std::vector<uint8_t> preMaster;
                        std::vector<uint8_t> ckeBody;

                        if (server_cert_is_ecdsa && !tls12_srv_ecdhe_pub.empty()) {
                            // ECDHE-ECDSA: generate ephemeral client key pair and compute shared secret
                            uint8_t cli_priv[32];
                            for (int retry = 0; retry < 10; ++retry) {
                                fillRandomBytes(cli_priv, 32);
                                bool nz = false;
                                for (int j = 0; j < 32; ++j) if (cli_priv[j]) { nz = true; break; }
                                if (nz) break;
                            }
                            netplus::P256Point cli_pub_pt = netplus::scalar_mul_G(cli_priv);
                            std::vector<uint8_t> cli_pub = netplus::encode_tls_point(cli_pub_pt);

                            netplus::P256Point srv_pub;
                            if (!netplus::decode_tls_point(srv_pub, tls12_srv_ecdhe_pub.data(), tls12_srv_ecdhe_pub.size()))
                                throwSSL(NetException::Error, "ECDHE CKE: invalid server public key");

                            uint8_t shared32[32];
                            if (!netplus::ecdh_shared_secret(shared32, cli_priv, srv_pub))
                                throwSSL(NetException::Error, "ECDHE CKE: shared secret failed");

                            preMaster.assign(shared32, shared32 + 32);

                            // CKE body: point_length(1) + uncompressed EC point(65)
                            ckeBody.push_back(uint8_t(cli_pub.size()));
                            ckeBody.insert(ckeBody.end(), cli_pub.begin(), cli_pub.end());
                        } else {
                            // RSA key exchange
                            // 1) Build PreMasterSecret: version(2) + random(46) = 48 bytes
                        std::vector<uint8_t> preMaster(48);
                            preMaster.resize(48);
                            preMaster[0] = 0x03;
                            preMaster[1] = 0x03; // TLS 1.2
                            fillRandomBytes(preMaster.data() + 2, 46);

                        // 2) PKCS#1 v1.5 encrypt: 0x00 0x02 <non-zero padding> 0x00 <data>
                            // 2) PKCS#1 v1.5 encrypt
                            const size_t kBytes = (peer_rsa.n.bitLength() + 7) / 8;
                            if (kBytes < 11 + 48)
                                throwSSL(NetException::Error, "Server RSA key too small");
@@ -2185,50 +2276,69 @@ namespace netplus {
                            emBlock[0] = 0x00;
                            emBlock[1] = 0x02;
                            size_t padLen = kBytes - 3 - 48;
                        // Fill padding with non-zero random bytes
                            fillRandomBytes(emBlock.data() + 2, padLen);
                            for (size_t i = 0; i < padLen; ++i) {
                                while (emBlock[2 + i] == 0x00)
                                    fillRandomBytes(&emBlock[2 + i], 1);
                            }
                        emBlock[2 + padLen] = 0x00; // separator
                            emBlock[2 + padLen] = 0x00;
                            std::memcpy(emBlock.data() + 2 + padLen + 1, preMaster.data(), 48);

                        // 3) RSA encrypt
                            rsa::bigInt plainBI = rsa::bigIntFromBytesBE(emBlock.data(), kBytes);
                            rsa::bigInt cipherBI = peer_rsa.encrypt(plainBI);
                            std::vector<uint8_t> ciphertext = rsa::bigIntToBytesBE(cipherBI, kBytes);

                        // 4) Build CKE body: uint16 encLen + ciphertext
                        std::vector<uint8_t> ckeBody;
                            ckeBody.reserve(2 + ciphertext.size());
                            ckeBody.push_back(uint8_t(ciphertext.size() >> 8));
                            ckeBody.push_back(uint8_t(ciphertext.size() & 0xFF));
                            ckeBody.insert(ckeBody.end(), ciphertext.begin(), ciphertext.end());
                        }

                        // 5) Send ClientKeyExchange handshake (adds to transcript)
                        sendHandshake(0x10, ckeBody);

                        // 6) Derive master_secret
                        // Derive master_secret
                        std::vector<uint8_t> msSeed = clientRandom;
                        msSeed.insert(msSeed.end(), serverRandom.begin(), serverRandom.end());
                        masterSecret = prf(preMaster, "master secret", msSeed, 48);

                        // 7) Derive key_block: client_mac(20) + server_mac(20) + client_key(16) + server_key(16) = 72
                        // Derive key_block
                        std::vector<uint8_t> kbSeed = serverRandom;
                        kbSeed.insert(kbSeed.end(), clientRandom.begin(), clientRandom.end());
                        std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, 72);

                        bool isGcm = (chosenSuite == 0xC02B || chosenSuite == 0xC02C);
                        if (isGcm) {
                            size_t keyLen = (chosenSuite == 0xC02C) ? 32 : 16;
                            std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, keyLen + keyLen + 4 + 4);
                            size_t k = 0;
                            std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen;
                            std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen;
                            std::memcpy(tls12_client_write_iv, keyBlock.data() + k, 4); k += 4;
                            std::memcpy(tls12_server_write_iv, keyBlock.data() + k, 4);
                            if (keyLen == 32) {
                                aes      = std::make_unique<aes256>(clientKey);
                                aes_recv = std::make_unique<aes256>(serverKey);
                            } else {
                                aes      = std::make_unique<aes128>(clientKey);
                                aes_recv = std::make_unique<aes128>(serverKey);
                            }
                            tls12_is_gcm = true;
                        } else {
                            // CBC: client_mac(20) + server_mac(20) + client_key + server_key
                            size_t keyLen = (chosenSuite == 0xC00A || chosenSuite == 0x0035) ? 32 : 16;
                            std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, 20 + 20 + keyLen + keyLen);
                            size_t k = 0;
                            client_mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20); k += 20;
                            mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20);        k += 20;

                        std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + 16); k += 16;
                        std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + 16);

                        // Client sends with clientKey, receives with serverKey
                            std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen;
                            std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen);
                            if (keyLen == 32) {
                                aes      = std::make_unique<aes256>(clientKey);
                                aes_recv = std::make_unique<aes256>(serverKey);
                            } else {
                                aes      = std::make_unique<aes128>(clientKey);
                                aes_recv = std::make_unique<aes128>(serverKey);
                            }
                        }

                        cli_cke_sent = true;
                    }
@@ -3813,8 +3923,11 @@ namespace netplus {
            }

            if (recType == 0x15) {
                // Alert
                // Alert: level + description
                NetException e;
                if (rec.size() >= 7)
                    e[NetException::Error] << "tls: received alert level=" << int(rec[5]) << " desc=" << int(rec[6]);
                else
                    e[NetException::Error] << "tls: received alert";
                throw e;
            }
@@ -4482,8 +4595,11 @@ namespace netplus {
        ch.push_back(uint8_t(client_session_id.size()));
        ch.insert(ch.end(), client_session_id.begin(), client_session_id.end());

        // Cipher suites: TLS 1.3 + TLS 1.2
        std::vector<uint16_t> suites = {0x1301, 0x1302, 0x002F, 0x0035};
        // Cipher suites: TLS 1.3 + ECDHE-ECDSA (GCM + CBC) + RSA fallback
        // In TLS 1.2-only mode skip the TLS 1.3 ciphers
        std::vector<uint16_t> suites = client_tls12_only
            ? std::vector<uint16_t>{0xC02B, 0xC02C, 0xC009, 0xC00A, 0x002F, 0x0035}
            : std::vector<uint16_t>{0x1301, 0x1302, 0xC02B, 0xC02C, 0xC009, 0xC00A, 0x002F, 0x0035};
        ch.push_back(uint8_t((suites.size() * 2) >> 8));
        ch.push_back(uint8_t((suites.size() * 2) & 0xFF));
        for (uint16_t s : suites) {
@@ -4529,6 +4645,7 @@ namespace netplus {
            }
        }

        if (!client_tls12_only) {
            // 2) supported_versions extension (TLS 1.3 + TLS 1.2)
            {
                exts.push_back(0x00); exts.push_back(0x2b); // type
@@ -4537,14 +4654,21 @@ namespace netplus {
                exts.push_back(0x03); exts.push_back(0x04);  // TLS 1.3
                exts.push_back(0x03); exts.push_back(0x03);  // TLS 1.2
            }
        }

        // 3) supported_groups extension (X25519 + secp256r1)
        // 3) supported_groups extension (secp256r1; also x25519 when not TLS 1.2-only)
        {
            exts.push_back(0x00); exts.push_back(0x0a); // type
            if (!client_tls12_only) {
                exts.push_back(0x00); exts.push_back(0x06); // length = 6
                exts.push_back(0x00); exts.push_back(0x04); // named_group_list length = 4
                exts.push_back(0x00); exts.push_back(0x1d); // x25519
                exts.push_back(0x00); exts.push_back(0x17); // secp256r1
            } else {
                exts.push_back(0x00); exts.push_back(0x04); // length = 4
                exts.push_back(0x00); exts.push_back(0x02); // named_group_list length = 2
                exts.push_back(0x00); exts.push_back(0x17); // secp256r1 only
            }
        }

        // 4) signature_algorithms extension
@@ -4560,7 +4684,8 @@ namespace netplus {
            exts.push_back(0x08); exts.push_back(0x05); // rsa_pss_rsae_sha384
        }

        // 5) key_share extension (X25519 + P-256)
        if (!client_tls12_only) {
            // 5) key_share extension (X25519 + P-256) — TLS 1.3 only
            {
                // Generate X25519 ephemeral key pair
                client_priv_x25519.resize(32);
@@ -4580,20 +4705,17 @@ namespace netplus {
                    if (u256_cmp(k, P256_N) < 0) break;
                }
                netplus::P256Point cli_pub_pt = netplus::scalar_mul_G(client_priv_ecdhe);
            std::vector<uint8_t> p256_pub = netplus::encode_tls_point(cli_pub_pt); // 65 bytes (0x04 + x + y)
                std::vector<uint8_t> p256_pub = netplus::encode_tls_point(cli_pub_pt); // 65 bytes

                // Build key_share entries
                std::vector<uint8_t> shares;
            // X25519 entry: group(2) + key_len(2) + key(32)
                shares.push_back(0x00); shares.push_back(0x1d); // x25519
                shares.push_back(0x00); shares.push_back(0x20); // 32 bytes
                shares.insert(shares.end(), x25519_pub.begin(), x25519_pub.end());
            // P-256 entry: group(2) + key_len(2) + key(65)
                shares.push_back(0x00); shares.push_back(0x17); // secp256r1
                shares.push_back(0x00); shares.push_back(uint8_t(p256_pub.size()));
                shares.insert(shares.end(), p256_pub.begin(), p256_pub.end());

            // key_share extension: type(2) + ext_len(2) + client_shares_len(2) + shares
                uint16_t shares_len = uint16_t(shares.size());
                uint16_t ext_data_len = 2 + shares_len;
                exts.push_back(0x00); exts.push_back(0x33); // type = key_share
@@ -4601,9 +4723,10 @@ namespace netplus {
                exts.push_back(uint8_t(shares_len >> 8)); exts.push_back(uint8_t(shares_len & 0xFF));
                exts.insert(exts.end(), shares.begin(), shares.end());
            }
        }

        // Mark that we offered TLS 1.3
        client_offered_tls13 = true;
        // Mark that we offered TLS 1.3 (only when not in TLS 1.2-only mode)
        client_offered_tls13 = !client_tls12_only;

        // Extensions length
        ch.push_back(uint8_t(exts.size() >> 8));
+5 −0
Original line number Diff line number Diff line
@@ -151,6 +151,7 @@ namespace netplus {
            CLI_TLS13_SEND_FINISHED,
            CLI_TLS13_HANDSHAKE_DONE,
            CLI_WAIT_CERT,
            CLI_TLS12_WAIT_SKE,   // TLS 1.2 ECDHE: read optional ServerKeyExchange
            CLI_WAIT_SHD,
            CLI_SEND_CKE,
            CLI_SEND_CCS_FIN,
@@ -572,6 +573,10 @@ namespace netplus {
        // Use for legacy embedded devices that choke on key_share / supported_versions.
        bool client_tls12_only = false;

        // --- Client TLS 1.2 ECDHE state ---
        bool server_cert_is_ecdsa = false;
        std::vector<uint8_t> tls12_srv_ecdhe_pub;  // server ECDHE public key from ServerKeyExchange

        // --- Hostname (deprecated/legacy) ---
        mutable std::recursive_mutex mutex_;
        std::string hostname;
+48 −0
Original line number Diff line number Diff line
@@ -475,6 +475,54 @@ bool netplus::x509cert::extractPublicKey(netplus::rsa& outRsa) {
    return false;
}

// ------------------------------------------------------------
// isECKey -- returns true if SubjectPublicKeyInfo uses ecPublicKey OID (1.2.840.10045.2.1)
// ------------------------------------------------------------
bool netplus::x509cert::isECKey() const {
    auto& r = root();
    if (r.children.empty()) return false;
    auto& tbs = r.children[0];
    for (auto& child : tbs.children) {
        if (child.tag == 0x30 && child.children.size() >= 2) {
            if (child.children[0].tag == 0x30 && !child.children[0].children.empty()) {
                const ASN1Node& oidNode = child.children[0].children[0];
                if (oidNode.tag == 0x06) {
                    std::string oid = const_cast<x509cert*>(this)->decodeOID(oidNode.data, oidNode.len);
                    if (oid == "1.2.840.10045.2.1") return true;
                }
            }
        }
    }
    return false;
}

// ------------------------------------------------------------
// extractECPublicKey -- extracts raw EC point from SubjectPublicKeyInfo BIT STRING
// ------------------------------------------------------------
bool netplus::x509cert::extractECPublicKey(std::vector<uint8_t>& outPubRaw) {
    auto& r = root();
    if (r.children.empty()) return false;
    auto& tbs = r.children[0];
    for (auto& child : tbs.children) {
        if (child.tag == 0x30 && child.children.size() >= 2) {
            if (child.children[0].tag == 0x30 && !child.children[0].children.empty()) {
                const ASN1Node& oidNode = child.children[0].children[0];
                if (oidNode.tag != 0x06) continue;
                std::string oid = decodeOID(oidNode.data, oidNode.len);
                if (oid != "1.2.840.10045.2.1") continue;
                // Found ecPublicKey — extract the BIT STRING (skip first byte = unused bits count)
                for (auto& node : child.children) {
                    if (node.tag == 0x03 && node.data && node.len >= 2) {
                        outPubRaw.assign(node.data + 1, node.data + node.len);
                        return true;
                    }
                }
            }
        }
    }
    return false;
}

// ------------------------------------------------------------
// getSubjectCN
// ------------------------------------------------------------
+2 −0
Original line number Diff line number Diff line
@@ -70,6 +70,8 @@ namespace netplus {
        bool checkValidity();
        time_t getNotAfter();
        bool extractPublicKey(netplus::rsa& outRsa);
        bool isECKey() const;
        bool extractECPublicKey(std::vector<uint8_t>& outPubRaw);
        std::string getSubjectCN();
        std::vector<std::string> getSubjectAltNames();
        std::string decodeOID(const uint8_t* data, size_t len);