Commit 0febec05 authored by jan.koester's avatar jan.koester
Browse files

test

parent 0d467331
Loading
Loading
Loading
Loading
+107 −42
Original line number Diff line number Diff line
@@ -173,25 +173,22 @@ namespace netplus {
        size_t newUsed = used + wordShift + (bitShift ? 1 : 0);
        reserve(newUsed);

        // temp copy of current limbs
        std::vector<uint32_t> old(used);
        std::copy(data.get(), data.get() + used, old.data());

        // zero destination
        std::fill(data.get(), data.get() + newUsed, 0);

        uint32_t carry = 0;
        for (size_t i = 0; i < old.size(); ++i) {
            uint64_t v = old[i];
            if (bitShift) {
                uint64_t shifted = (v << bitShift) | carry;
                data[i + wordShift] = (uint32_t)shifted;
                carry = (uint32_t)(v >> (32 - bitShift));
        // Shift words and bits in-place, working from high end to low
        if (bitShift == 0) {
            // Pure word shift — move backwards to avoid overwriting
            for (size_t i = used; i-- > 0; )
                data[i + wordShift] = data[i];
        } else {
                data[i + wordShift] = (uint32_t)v;
            // Combined word + bit shift
            for (size_t i = used; i-- > 0; ) {
                uint64_t v = static_cast<uint64_t>(data[i]) << bitShift;
                data[i + wordShift + 1] |= static_cast<uint32_t>(v >> 32);
                data[i + wordShift] = static_cast<uint32_t>(v);
            }
        }
        if (bitShift) data[old.size() + wordShift] = carry;
        // Zero the low words freed by the shift
        for (size_t i = 0; i < wordShift; ++i)
            data[i] = 0;

        used = newUsed;
        while (used > 1 && data[used - 1] == 0) used--;
@@ -746,15 +743,32 @@ namespace netplus {
        // n_prime = -N^{-1} mod 2^32 (uses least significant limb)
        uint32_t n_prime = netplus::rsa::calculateNPrime(mod.data[0]);

        // Helper: compute (2^(64*mod.used)) mod mod == R^2 mod N
        // Helper: compute R^2 mod N where R = 2^(32*mod.used)
        // Uses repeated doubling: start with 1, double 2*32*mod.used times, mod N each time.
        // This avoids creating a huge intermediate integer.
        auto calculateR2Mod = [&](const bigInt& N) -> bigInt {
            bigInt r(N.used * 4 + 2);
            bigInt r(N.used + 2);
            r.fromUint64(1);
            r.shiftLeft(N.used * 64);

            bigInt q(r.used + 1), rem(N.used + 2);
            netplus::rsa::divide(r, N, q, rem);
            return rem;
            size_t totalBits = N.used * 64; // need R^2 = 2^(2*32*N.used)
            for (size_t i = 0; i < totalBits; ++i) {
                // r = r * 2 (shift left by 1)
                uint32_t carry = 0;
                for (size_t w = 0; w < r.used; ++w) {
                    uint32_t newCarry = r.data[w] >> 31;
                    r.data[w] = (r.data[w] << 1) | carry;
                    carry = newCarry;
                }
                if (carry) {
                    r.reserve(r.used + 1);
                    r.data[r.used] = carry;
                    r.used++;
                }
                // r = r mod N
                if (compare(r, N) >= 0) {
                    subtract(r, N, r);
                }
            }
            return r;
        };

        size_t n = mod.used;
@@ -1012,7 +1026,12 @@ namespace netplus {
            bool composite = true;

            for (size_t r = 1; r < s; r++) {
                x = modPow(x, two, n);
                // x = x^2 mod n  (direct squaring, much faster than modPow)
                bigInt x2(x.used * 2 + 2);
                multiply(x, x, x2);
                bigInt q_sq(x2.used + 1), r_sq(n.capacity + 1);
                divide(x2, n, q_sq, r_sq);
                x = std::move(r_sq);

                if (compare(x, nm1) == 0) {
                    composite = false;
@@ -1027,28 +1046,69 @@ namespace netplus {
    }

    void rsa::findPrime(bigInt& p, size_t bits) {
        // Correctly initialize 'two' to avoid ambiguous constructor call
        bigInt two(2U, 1);

        // Small primes for trial division (skip obviously composite candidates)
        static const uint32_t smallPrimes[] = {
            3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,
            79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,
            163,167,173,179,181,191,193,197,199,211,223,227,229,233,239,
            241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,
            337,347,349,353,359,367,373,379,383,389,397,401,409,419,421,
            431,433,439,443,449,457,461,463,467,479,487,491,499,503,509,
            521,523,541,547,557,563,569,571,577,587,593,599,601,607,613,
            617,619,631,641,643,647,653,659,661,673,677,683,691,701,709,
            719,727,733,739,743,751,757,761,769,773,787,797,809,811,821,
            823,827,829,839,853,857,859,863,877,881,883,887,907,911,919,
            929,937,941,947,953,967,971,977,983,991,997
        };
        static const size_t numSmallPrimes = sizeof(smallPrimes) / sizeof(smallPrimes[0]);

        // Lambda: check trial division by small primes
        auto passesTrialDivision = [&](const bigInt& candidate) -> bool {
            for (size_t i = 0; i < numSmallPrimes; ++i) {
                // Compute candidate mod smallPrimes[i] using 64-bit arithmetic
                uint64_t rem = 0;
                uint32_t sp = smallPrimes[i];
                for (size_t w = candidate.used; w-- > 0; ) {
                    rem = (rem << 32) | candidate.data[w];
                    rem %= sp;
                }
                if (rem == 0) return false;
            }
            return true;
        };

        // Initial random start (ensured odd by generateSecureRandom)
        generateSecureRandom(p, bits);

        // Miller-Rabin Primality Test Loop
        while (!isProbablyPrime(p, 10)) {
            // We can add directly to p to avoid creating 'next_p' every time
            // if your add function supports add(a, b, a)
            bigInt next_p(p.capacity);
            add(p, two, next_p);
            p = next_p;
        // Search: skip composites with trial division, then confirm with Miller-Rabin
        for (;;) {
            if (passesTrialDivision(p) && isProbablyPrime(p, 10))
                return;
            // Add 2 in-place
            uint64_t carry = 2;
            for (size_t i = 0; i < p.used && carry; ++i) {
                uint64_t sum = (uint64_t)p.data[i] + carry;
                p.data[i] = (uint32_t)sum;
                carry = sum >> 32;
            }
            if (carry) {
                p.reserve(p.used + 1);
                p.data[p.used] = (uint32_t)carry;
                p.used++;
            }
        }
    }

    void rsa::generateKeys(size_t bit_count) {
        // bit_count for p and q should be half of the target N bit length
        bigInt p(bit_count), q(bit_count);
        // bit_count is the target size of N (e.g. 2048).
        // p and q should each be half of that so n = p*q has bit_count bits.
        size_t half_bits = bit_count / 2;
        bigInt p(half_bits / 32 + 2), q(half_bits / 32 + 2);

        findPrime(p, bit_count);
        findPrime(q, bit_count);
        findPrime(p, half_bits);
        findPrime(q, half_bits);

        // n = p * q
        n.reserve(p.used + q.used);
@@ -1275,11 +1335,16 @@ namespace netplus {
    std::vector<uint8_t> rsa::bigIntToBytesBE(const bigInt& x, size_t outLen){
        std::vector<uint8_t> out(outLen, 0x00);

        bigInt t = x;

        for (size_t i = 0; i < outLen; i++) {
            out[outLen - 1 - i] = uint8_t(t.data[0] & 0xFF);
            t.shiftRight(8);
        // Direct word extraction — O(n) instead of O(n²) shiftRight loop
        for (size_t i = 0; i < x.used && i * 4 < outLen; ++i) {
            uint32_t word = x.data[i];
            for (size_t b = 0; b < 4; ++b) {
                size_t bytePos = i * 4 + b;
                if (bytePos < outLen) {
                    out[outLen - 1 - bytePos] = static_cast<uint8_t>(word & 0xFF);
                    word >>= 8;
                }
            }
        }

        return out;