Commit a66e3cae authored by jan.koester's avatar jan.koester
Browse files

test

parent 6fee4429
Loading
Loading
Loading
Loading
+137 −118
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@
#include <algorithm>
#include <stdexcept>
#include <sstream>
#include <fstream>

#include "random.h"
#include "rsa.h"
@@ -261,6 +262,141 @@ namespace netplus {
        return *this;
    }

    // Helper function to parse DER data and load RSA key
    static bool parseDERAndLoadRSA(const std::vector<uint8_t>& derData, netplus::rsa& out_rsa) {
        if (derData.empty()) {
            return false;
        }
        
        const uint8_t* data = derData.data();
        size_t size = derData.size();
        size_t cursor = 0;
        
        // Must start with SEQUENCE (0x30)
        if (size < 2 || data[cursor] != 0x30) {
            return false;
        }
        cursor++;
        
        // Parse SEQUENCE length
        size_t seqLength = 0;
        size_t lenBytesRead = 0;
        if (!parseDERLength(data + cursor, size - cursor, seqLength, lenBytesRead)) {
            return false;
        }
        cursor += lenBytesRead;
        
        if (cursor + seqLength > size) {
            return false;
        }
        
        std::vector<uint8_t> n, e, d;
        
        // Try to parse as PKCS#1 RSAPrivateKey first
        size_t tempCursor = cursor;
        bool isPKCS1Private = false;
        
        // Check for version (INTEGER 0)
        if (tempCursor < cursor + seqLength && data[tempCursor] == 0x02) {
            std::vector<uint8_t> version;
            size_t intBytesRead = 0;
            if (parseDERInteger(data + tempCursor, seqLength - (tempCursor - cursor), version, intBytesRead)) {
                tempCursor += intBytesRead;
                
                // Try to parse n, e, d
                if (parseDERInteger(data + tempCursor, seqLength - (tempCursor - cursor), n, intBytesRead)) {
                    tempCursor += intBytesRead;
                    
                    if (parseDERInteger(data + tempCursor, seqLength - (tempCursor - cursor), e, intBytesRead)) {
                        tempCursor += intBytesRead;
                        
                        if (parseDERInteger(data + tempCursor, seqLength - (tempCursor - cursor), d, intBytesRead)) {
                            isPKCS1Private = true;
                        }
                    }
                }
            }
        }
        
        if (isPKCS1Private && !n.empty() && !e.empty() && !d.empty()) {
            // Successfully parsed PKCS#1 private key
            out_rsa.setRsaKeyFromRaw(n, e, d);
            return true;
        }
        
        // Try to parse as public key (might be SubjectPublicKeyInfo with BIT STRING, or just RSAPublicKey SEQUENCE)
        if (data[cursor] == 0x30) {
            // Could be either RSAPublicKey or algorithm identifier
            cursor++;
            size_t innerSeqLength = 0;
            if (!parseDERLength(data + cursor, size - cursor, innerSeqLength, lenBytesRead)) {
                return false;
            }
            cursor += lenBytesRead;
            
            if (cursor + innerSeqLength > size) {
                return false;
            }
            
            // Try to parse modulus and exponent
            size_t pubCursor = cursor;
            size_t intBytesRead = 0;
            if (parseDERInteger(data + pubCursor, innerSeqLength - (pubCursor - cursor), n, intBytesRead)) {
                pubCursor += intBytesRead;
                
                if (parseDERInteger(data + pubCursor, innerSeqLength - (pubCursor - cursor), e, intBytesRead)) {
                    if (!n.empty() && !e.empty()) {
                        // Create a public-only key
                        d.clear();  // No private exponent
                        out_rsa.setRsaKeyFromRaw(n, e, d);
                        return true;
                    }
                }
            }
        }
        
        return false;
    }

    // Assignment from buffer
    rsa& rsa::operator=(const std::vector<uint8_t>& buffer) {
        if (buffer.empty()) {
            throw std::runtime_error("Buffer is empty");
        }

        if (!parseDERAndLoadRSA(buffer, *this)) {
            throw std::runtime_error("Failed to parse RSA key from DER buffer");
        }

        return *this;
    }

    // Assignment from filepath
    rsa& rsa::operator=(const std::string& filepath) {
        std::ifstream file(filepath, std::ios::binary | std::ios::ate);
        if (!file.is_open()) {
            throw std::runtime_error("Cannot open file: " + filepath);
        }

        std::streamsize fileSize = file.tellg();
        if (fileSize <= 0) {
            throw std::runtime_error("File is empty: " + filepath);
        }

        file.seekg(0, std::ios::beg);
        std::vector<uint8_t> derData(static_cast<size_t>(fileSize));
        
        if (!file.read(reinterpret_cast<char*>(derData.data()), fileSize)) {
            throw std::runtime_error("Failed to read file: " + filepath);
        }

        if (!parseDERAndLoadRSA(derData, *this)) {
            throw std::runtime_error("Failed to parse RSA key from file: " + filepath);
        }

        return *this;
    }

    // R = 2^(32*nwords), so R2 = R*R mod N
    rsa::bigInt rsa::calculateR2Mod(const bigInt& mod) {
        bigInt r(mod.used * 4 + 2);
@@ -973,123 +1109,6 @@ namespace netplus {
    }

    // Static function to load RSA key from DER-encoded data
    bool rsa::loadRsaFromDerFile(const std::vector<uint8_t>& derData, netplus::rsa& out_rsa) {
        if (derData.empty()) return false;
        
        // DER format for PKCS#1 RSAPrivateKey:
        // RSAPrivateKey ::= SEQUENCE {
        //   version           INTEGER,
        //   modulus           INTEGER,     -- n
        //   publicExponent    INTEGER,     -- e
        //   privateExponent   INTEGER,     -- d
        //   prime1            INTEGER,     -- p
        //   prime2            INTEGER,     -- q
        //   exponent1         INTEGER,     -- d mod (p-1)
        //   exponent2         INTEGER,     -- d mod (q-1)
        //   coefficient       INTEGER      -- (inverse of q) mod p
        // }
        //
        // DER format for PKCS#8 PrivateKeyInfo:
        // PrivateKeyInfo ::= SEQUENCE {
        //   version             INTEGER,
        //   algorithm           AlgorithmIdentifier,
        //   PrivateKey          OCTET STRING
        // }
        
        // For public key (SubjectPublicKeyInfo):
        // SubjectPublicKeyInfo ::= SEQUENCE {
        //   algorithm           AlgorithmIdentifier,
        //   subjectPublicKey    BIT STRING
        // }
        
        const uint8_t* data = derData.data();
        size_t size = derData.size();
        size_t cursor = 0;
        
        // Must start with SEQUENCE (0x30)
        if (size < 2 || data[cursor] != 0x30) {
            return false;
        }
        cursor++;
        
        // Parse SEQUENCE length
        size_t seqLength = 0;
        size_t lenBytesRead = 0;
        if (!parseDERLength(data + cursor, size - cursor, seqLength, lenBytesRead)) {
            return false;
        }
        cursor += lenBytesRead;
        
        if (cursor + seqLength > size) {
            return false;
        }
        
        std::vector<uint8_t> n, e, d;
        
        // Try to parse as PKCS#1 RSAPrivateKey first
        size_t tempCursor = cursor;
        bool isPKCS1Private = false;
        
        // Check for version (INTEGER 0)
        if (tempCursor < cursor + seqLength && data[tempCursor] == 0x02) {
            std::vector<uint8_t> version;
            size_t intBytesRead = 0;
            if (parseDERInteger(data + tempCursor, seqLength - (tempCursor - cursor), version, intBytesRead)) {
                tempCursor += intBytesRead;
                
                // Try to parse n, e, d
                if (parseDERInteger(data + tempCursor, seqLength - (tempCursor - cursor), n, intBytesRead)) {
                    tempCursor += intBytesRead;

                    if (parseDERInteger(data + tempCursor, seqLength - (tempCursor - cursor), e, intBytesRead)) {
                        tempCursor += intBytesRead;
                        
                        if (parseDERInteger(data + tempCursor, seqLength - (tempCursor - cursor), d, intBytesRead)) {
                            isPKCS1Private = true;
                        }
                    }
                }
            }
        }
        
        if (isPKCS1Private && !n.empty() && !e.empty() && !d.empty()) {
            // Successfully parsed PKCS#1 private key
            out_rsa.setRsaKeyFromRaw(n, e, d);
            return true;
        }
        
        // Try to parse as public key (might be SubjectPublicKeyInfo with BIT STRING, or just RSAPublicKey SEQUENCE)
        if (data[cursor] == 0x30) {
            // Could be either RSAPublicKey or algorithm identifier
            cursor++;
            size_t innerSeqLength = 0;
            if (!parseDERLength(data + cursor, size - cursor, innerSeqLength, lenBytesRead)) {
                return false;
            }
            cursor += lenBytesRead;
            
            if (cursor + innerSeqLength > size) {
                return false;
            }
            
            // Try to parse modulus and exponent
            size_t pubCursor = cursor;
            size_t intBytesRead = 0;
            if (parseDERInteger(data + pubCursor, innerSeqLength - (pubCursor - cursor), n, intBytesRead)) {
                pubCursor += intBytesRead;
                
                if (parseDERInteger(data + pubCursor, innerSeqLength - (pubCursor - cursor), e, intBytesRead)) {
                    if (!n.empty() && !e.empty()) {
                        // Create a public-only key
                        d.clear();  // No private exponent
                        out_rsa.setRsaKeyFromRaw(n, e, d);
                        return true;
                    }
                }
            }
        }
        
        return false;
    }

};
+2 −2
Original line number Diff line number Diff line
@@ -86,6 +86,8 @@ namespace netplus {
        rsa(rsa&& src) noexcept;
        rsa& operator=(const rsa& src);
        rsa& operator=(rsa&& src) noexcept;
        rsa& operator=(const std::string& filepath);
        rsa& operator=(const std::vector<uint8_t>& buffer);

        explicit operator bool() const {
            return !(n.isZero() || d.isZero());
@@ -120,8 +122,6 @@ namespace netplus {
        static bigInt bigIntFromBytesBE (const uint8_t* bytes, size_t len);
        static std::vector<uint8_t> bigIntToBytesBE(const bigInt& x, size_t outLen);
        
        // Load RSA key from DER-encoded data
        static bool loadRsaFromDerFile(const std::vector<uint8_t>& derData, netplus::rsa& out_rsa);
    private:
        bigInt n, e, d;
        bool isProbablyPrime(const bigInt& n, int k);
+22 −6
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@
#include <cstring>
#include <memory>
#include <vector>
#include <fstream>
#include <cstdio>

#include "connection.h"
#include "eventapi.h"
@@ -109,9 +111,16 @@ int main() {
        netplus::ssl::CertificateBundle bundle1;
        bundle1.cert = cert;
        bundle1.privateKeyDer = std::vector<uint8_t>(test_key_der.begin(), test_key_der.end());
        // Pre-load RSA key into bundle using the rsa::loadRsaFromDerFile static function
        if (!netplus::rsa::loadRsaFromDerFile(bundle1.privateKeyDer, bundle1.rsa_key)) {
             std::cerr << "Failed to load RSA key 1!" << std::endl;
        // Pre-load RSA key into bundle using temporary file
        try {
            std::string temp_key1 = "/tmp/test_key1.der";
            std::ofstream temp_file1(temp_key1, std::ios::binary);
            temp_file1.write(reinterpret_cast<const char*>(bundle1.privateKeyDer.data()), bundle1.privateKeyDer.size());
            temp_file1.close();
            bundle1.rsa_key = temp_key1;
            std::remove(temp_key1.c_str());
        } catch (const std::exception& e) {
             std::cerr << "Failed to load RSA key 1: " << e.what() << std::endl;
             std::cerr.flush();
             return 1;
        }
@@ -122,9 +131,16 @@ int main() {
        netplus::ssl::CertificateBundle bundle2;
        bundle2.cert = cert2;
        bundle2.privateKeyDer = std::vector<uint8_t>(test2_key_der.begin(), test2_key_der.end());
        // Pre-load RSA key into bundle using the rsa::loadRsaFromDerFile static function
        if (!netplus::rsa::loadRsaFromDerFile(bundle2.privateKeyDer, bundle2.rsa_key)) {
             std::cerr << "Failed to load RSA key 2!" << std::endl;
        // Pre-load RSA key into bundle using temporary file
        try {
            std::string temp_key2 = "/tmp/test_key2.der";
            std::ofstream temp_file2(temp_key2, std::ios::binary);
            temp_file2.write(reinterpret_cast<const char*>(bundle2.privateKeyDer.data()), bundle2.privateKeyDer.size());
            temp_file2.close();
            bundle2.rsa_key = temp_key2;
            std::remove(temp_key2.c_str());
        } catch (const std::exception& e) {
             std::cerr << "Failed to load RSA key 2: " << e.what() << std::endl;
             std::cerr.flush();
             return 1;
        }