diff --git a/examples/bigint.hk b/examples/bigint.hk index c568950..5206045 100644 --- a/examples/bigint.hk +++ b/examples/bigint.hk @@ -4,9 +4,15 @@ import bigint; -let x = bigint.new("2222222222222222222222222222222222222222"); -let y = bigint.new("3333333333333333333333333333333333333333"); +let a = bigint.new("2222222222222222222222222222222222222222"); +let b = bigint.new(3); +let c = bigint.from_string("f", 16); +let d = bigint.new(7); -let z = bigint.add(x, y); +let x = bigint.add(a, b); +let y = bigint.sub(x, c); +let z = bigint.mod(y, d); +println(bigint.to_string(x)); +println(bigint.to_string(y)); println(bigint.to_string(z)); diff --git a/examples/ecdsa.hk b/examples/ecdsa.hk new file mode 100644 index 0000000..830160c --- /dev/null +++ b/examples/ecdsa.hk @@ -0,0 +1,246 @@ +// +// ecdsa.hk +// + +import bigint; +import crypto; +import hashing; + +// Structs + +struct Point { + x, y +} + +struct Curve { + keySize, + A, B, P, N, G +} + +// Functions + +fn new_point(x, y) { + return Point { + bigint.from_string(x, 16), + bigint.from_string(y, 16) + }; +} + +fn new_point_at_infinity() { + return Point { nil, nil }; +} + +fn is_at_infinity(p) { + return p.x == nil && p.y == nil; +} + +fn new_secp256k1_curve() { + let G = new_point( + "79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798", + "483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8" + ); + return Curve { + 32, + bigint.new(0), + bigint.new(7), + bigint.from_string( + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F", + 16), + bigint.from_string( + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", + 16), + G + }; +} + +fn scalar_is_valid(c, scalar) { + return bigint.compare(scalar, 0) > 0 && bigint.compare(scalar, c.N) < 0; +} + +fn random_scalar(c) { + let size = c.keySize; + var bytes = crypto.random_bytes(size); + var scalar = bigint.from_bytes(bytes); + while (!scalar_is_valid(c, scalar)) + { + bytes = crypto.random_bytes(size); + scalar = bigint.from_bytes(bytes); + } + return scalar; +} + +fn is_on_curve(c, p) { + // p.y ^ 2 mod c.P = p.x ^ 3 + c.A * p.x + c.B mod c.P + let t0 = bigint.mod(bigint.pow(p.y, 2), c.P); + let t1 = bigint.pow(p.x, 3); + let t2 = bigint.mul(c.A, p.x); + let t3 = bigint.add(t1, bigint.add(t2, c.B)); + let t4 = bigint.mod(t3, c.P); + return bigint.compare(t0, t4) == 0; +} + +fn compute_y(c, x) { + // y0 = sqrtm_prime(x ^ 3 + c.A * x + c.B, c.P) + // y1 = -y0 mod c.P + let t0 = bigint.pow(x, 3); + let t1 = bigint.mul(c.A, x); + let t2 = bigint.add(t0, bigint.add(t1, c.B)); + let t3 = bigint.sqrtm_prime(t2, c.P); + let t4 = bigint.mod(bigint.neg(t3), c.P); + return [t3, t4]; +} + +fn add_points(c, p, q) { + if (is_at_infinity(p)) { + return q; + } + if (is_at_infinity(q)) { + return p; + } + // lambda = mod((q.y − p.y) * invertm(q.x − p.x, c.P), c.P) + var t0 = bigint.sub(q.y, p.y); + var t1 = bigint.sub(q.x, p.x); + let t2 = bigint.invertm(t1, c.P); + assert(t2, "invertm() failed"); + let t3 = bigint.mul(t0, t2); + let lambda = bigint.mod(t3, c.P); + // x = mod(lambda ^ 2 − p.x − q.x, c.P) + t0 = bigint.sub(bigint.pow(lambda, 2), p.x); + let x = bigint.mod(bigint.sub(t0, q.x), c.P); + // y = mod(lambda * (p.x − x) − p.y, c.P) + t0 = bigint.sub(p.x, x); + t1 = bigint.mul(lambda, t0); + let y = bigint.mod(bigint.sub(t1, p.y), c.P); + return Point { x, y }; +} + +fn double_point(c, p) { + if (is_at_infinity(p) || bigint.compare(p.y, 0) == 0) { + return new_point_at_infinity(); + } + // lambda = mod((3 * p.x ^ 2 + c.A) * invertm(2 * p.y, c.P), c.P) + var t0 = bigint.mul(bigint.pow(p.x, 2), 3); + var t1 = bigint.add(t0, c.A); + let t2 = bigint.mul(p.y, 2); + let t3 = bigint.invertm(t2, c.P); + assert(t3, "invertm() failed"); + let lambda = bigint.mod(bigint.mul(t1, t3), c.P); + // x = mod(lambda ^ 2 − 2 * p.x, c.P) + t0 = bigint.pow(lambda, 2); + t1 = bigint.mul(p.x, 2); + let x = bigint.mod(bigint.sub(t0, t1), c.P); + // y = mod(lambda * (p.x − x) − p.y), c.P) + t0 = bigint.sub(p.x, x); + t1 = bigint.mul(lambda, t0); + let y = bigint.mod(bigint.sub(t1, p.y), c.P); + return Point { x, y }; +} + +fn multiply_point_by_scalar(c, p, scalar) +{ + var q = new_point_at_infinity(); + let n = bigint.size(scalar, 2); + for (var i = n - 1; i >= 0; i--) + { + q = double_point(c, q); + if (bigint.testbit(scalar, i) == 1) { + q = add_points(c, q, p); + } + } + return q; +} + +fn multiply_base_point_by_scalar(c, scalar) { + return multiply_point_by_scalar(c, c.G, scalar); +} + +fn sign(c, digest, privKey) { + if (!scalar_is_valid(c, privKey)) { + return false; + } + loop { + let k = random_scalar(c); + let p = multiply_base_point_by_scalar(c, k); + // r = mod(p.x, c.N) + let r = bigint.mod(p.x, c.N); + if (bigint.compare(r, 0) == 0) { + continue; + } + // s = mod(invertm(k, c.N) * (digest + r * privKey), c.N) + let t0 = bigint.mul(r, privKey); + let t1 = bigint.add(digest, t0); + let t2 = bigint.invertm(k, c.N); + assert(t2, "invertm() failed"); + let s = bigint.mod(bigint.mul(t2, t1), c.N); + if (bigint.compare(s, 0) == 0) { + continue; + } + return [r, s]; + } +} + +fn verify_signature(c, digest, pubKey, r, s) { + if (!scalar_is_valid(c, r) || !scalar_is_valid(c, s)) { + return false; + } + // w = invertm(s, c.N) + let w = bigint.invertm(s, c.N); + assert(w, "invertm() failed"); + // u1 = mod(digest * w, c.N) + let u1 = bigint.mod(bigint.mul(digest, w), c.N); + // u2 = mod(r * w, c.N) + let u2 = bigint.mod(bigint.mul(r, w), c.N); + // p = u1 * G + u2 * pubKey + let p = add_points(c, + multiply_base_point_by_scalar(c, u1), + multiply_point_by_scalar(c, pubKey, u2) + ); + if (is_at_infinity(p)) { + return false; + } + // v = mod(p.x, c.N) + let v = bigint.mod(p.x, c.N); + return bigint.compare(v, r) == 0; +} + +// Main + +let c = new_secp256k1_curve(); + +// Test compute_y() and is_on_curve() +let p = new_point( + "D440BDBA94C11761FD4FC419B5BF3F111B8F193A5168ACD33AA5525DC50B2F18", + "7E38B3F29FDF12904486A0BCCE8F5018B7B96B60661DAB6DC2CF73E843BEF6CE" +); +let isOnCurve = is_on_curve(c, p); +assert(isOnCurve, "Point must be on curve"); +let [ y0, y1 ] = compute_y(c, p.x); +let y0IsY = bigint.compare(y0, p.y) == 0; +let y1IsY = bigint.compare(y1, p.y) == 0; +assert(y0IsY || y1IsY, "Point must have valid y coordinate"); + +// Generate a random private key +let privKey = random_scalar(c); +println("Private key: " + bigint.to_string(privKey, 16)); + +// Get the public key +let pubKey = multiply_base_point_by_scalar(c, privKey); +println("Public key:"); +println(" x: " + bigint.to_string(pubKey.x, 16)); +println(" y: " + bigint.to_string(pubKey.y, 16)); + +// Hash the message +let message = "Hello, world!"; +var digest = hashing.sha256(message); +digest = bigint.from_bytes(digest); +println("Digest: " + bigint.to_string(digest, 16)); + +// Sign the digest +let [ r, s ] = sign(c, digest, privKey); +println("Signature:"); +println(" r: " + bigint.to_string(r, 16)); +println(" s: " + bigint.to_string(s, 16)); + +// Verify the signature +let isValid = verify_signature(c, digest, pubKey, r, s); +assert(isValid, "Signature must be valid"); diff --git a/extensions/bigint.c b/extensions/bigint.c index 2f15fdc..bece2f5 100644 --- a/extensions/bigint.c +++ b/extensions/bigint.c @@ -13,27 +13,46 @@ typedef struct mpz_t num; } BigInt; -static inline BigInt *bigint_new(mpz_t num); +static inline void sqrtm_prime(mpz_t r, mpz_t a, mpz_t p); +static inline BigInt *bigint_new(void); static void bigint_deinit(HkUserdata *udata); static void new_call(HkState *state, HkValue *args); -static void from_hex_call(HkState *state, HkValue *args); +static void from_string_call(HkState *state, HkValue *args); static void to_string_call(HkState *state, HkValue *args); +static void from_bytes_call(HkState *state, HkValue *args); +static void to_bytes_call(HkState *state, HkValue *args); +static void sign_call(HkState *state, HkValue *args); static void add_call(HkState *state, HkValue *args); static void sub_call(HkState *state, HkValue *args); static void mul_call(HkState *state, HkValue *args); static void div_call(HkState *state, HkValue *args); static void mod_call(HkState *state, HkValue *args); static void pow_call(HkState *state, HkValue *args); +static void powm_call(HkState *state, HkValue *args); static void sqrt_call(HkState *state, HkValue *args); +static void sqrtm_prime_call(HkState *state, HkValue *args); static void neg_call(HkState *state, HkValue *args); static void abs_call(HkState *state, HkValue *args); static void compare_call(HkState *state, HkValue *args); +static void invertm_call(HkState *state, HkValue *args); +static void size_call(HkState *state, HkValue *args); +static void testbit_call(HkState *state, HkValue *args); -static inline BigInt *bigint_new(mpz_t num) +static inline void sqrtm_prime(mpz_t r, mpz_t a, mpz_t p) +{ + mpz_t t; + mpz_init(t); + mpz_add_ui(t, p, 1); + mpz_div_ui(t, t, 4); + mpz_powm(r, a, t, p); + mpz_clear(t); +} + +static inline BigInt *bigint_new(void) { BigInt *bigint = (BigInt *) hk_allocate(sizeof(*bigint)); hk_userdata_init((HkUserdata *) bigint, bigint_deinit); - mpz_init_set(bigint->num, num); + mpz_init(bigint->num); return bigint; } @@ -44,28 +63,54 @@ static void bigint_deinit(HkUserdata *udata) static void new_call(HkState *state, HkValue *args) { - HkType types[] = { HK_TYPE_NUMBER, HK_TYPE_STRING }; - hk_state_check_argument_types(state, args, 1, 2, types); + HkType types[] = { HK_TYPE_NIL, HK_TYPE_NUMBER, HK_TYPE_STRING }; + hk_state_check_argument_types(state, args, 1, 3, types); hk_return_if_not_ok(state); HkValue val = args[1]; - mpz_t num; + BigInt *result = bigint_new(); if (hk_is_number(val)) - mpz_init_set_d(num, hk_as_number(val)); - else - mpz_init_set_str(num, hk_as_string(val)->chars, 10); - BigInt *bigint = bigint_new(num); - hk_state_push_userdata(state, (HkUserdata *) bigint); + { + hk_state_check_argument_int(state, args, 1); + hk_return_if_not_ok(state); + mpz_set_ui(result->num, (int64_t) hk_as_number(val)); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } + if (hk_is_string(val)) + { + int rc = mpz_set_str(result->num, hk_as_string(val)->chars, 10); + if (rc) + { + mpz_clear(result->num); + hk_state_push_nil(state); + return; + } + } + hk_state_push_userdata(state, (HkUserdata *) result); } -static void from_hex_call(HkState *state, HkValue *args) +static void from_string_call(HkState *state, HkValue *args) { hk_state_check_argument_string(state, args, 1); hk_return_if_not_ok(state); HkString *str = hk_as_string(args[1]); - mpz_t num; - mpz_init_set_str(num, str->chars, 16); - BigInt *bigint = bigint_new(num); - hk_state_push_userdata(state, (HkUserdata *) bigint); + HkValue val = args[2]; + int base = 10; + if (!hk_is_nil(val)) + { + hk_state_check_argument_int(state, args, 2); + hk_return_if_not_ok(state); + base = (int) hk_as_number(val); + } + BigInt *result = bigint_new(); + int rc = mpz_set_str(result->num, str->chars, base); + if (rc) + { + mpz_clear(result->num); + hk_state_push_nil(state); + return; + } + hk_state_push_userdata(state, (HkUserdata *) result); } static void to_string_call(HkState *state, HkValue *args) @@ -73,100 +118,235 @@ static void to_string_call(HkState *state, HkValue *args) hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); BigInt *bigint = (BigInt *) hk_as_userdata(args[1]); - char *chars = mpz_get_str(NULL, 10, bigint->num); + HkValue val = args[2]; + int base = 10; + if (!hk_is_nil(val)) + { + hk_state_check_argument_int(state, args, 2); + hk_return_if_not_ok(state); + base = (int) hk_as_number(val); + } + char *chars = mpz_get_str(NULL, base, bigint->num); HkString *str = hk_string_from_chars(-1, chars); free(chars); hk_state_push_string(state, str); } +static void from_bytes_call(HkState *state, HkValue *args) +{ + hk_state_check_argument_string(state, args, 1); + hk_return_if_not_ok(state); + HkString *str = hk_as_string(args[1]); + BigInt *result = bigint_new(); + mpz_import(result->num, str->length, 1, 1, 0, 0, str->chars); + hk_state_push_userdata(state, (HkUserdata *) result); +} + +static void to_bytes_call(HkState *state, HkValue *args) +{ + hk_state_check_argument_userdata(state, args, 1); + hk_return_if_not_ok(state); + BigInt *bigint = (BigInt *) hk_as_userdata(args[1]); + size_t length; + char *chars = mpz_export(NULL, &length, 1, 1, 0, 0, bigint->num); + HkString *str = hk_string_from_chars((int) length, chars); + free(chars); + hk_state_push_string(state, str); +} + +static void sign_call(HkState *state, HkValue *args) +{ + hk_state_check_argument_userdata(state, args, 1); + hk_return_if_not_ok(state); + BigInt *bigint = (BigInt *) hk_as_userdata(args[1]); + int sign = mpz_sgn(bigint->num); + hk_state_push_number(state, sign); +} + static void add_call(HkState *state, HkValue *args) { hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2)) + { + BigInt *result = bigint_new(); + mpz_add_ui(result->num, bigint1->num, (int64_t) hk_as_number(val2)); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } hk_state_check_argument_userdata(state, args, 2); hk_return_if_not_ok(state); - BigInt *bigint1 = (BigInt *) hk_as_userdata(args[1]); - BigInt *bigint2 = (BigInt *) hk_as_userdata(args[2]); - mpz_t result; - mpz_init(result); - mpz_add(result, bigint1->num, bigint2->num); - BigInt *bigint = bigint_new(result); - hk_state_push_userdata(state, (HkUserdata *) bigint); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + BigInt *result = bigint_new(); + mpz_add(result->num, bigint1->num, bigint2->num); + hk_state_push_userdata(state, (HkUserdata *) result); } static void sub_call(HkState *state, HkValue *args) { hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2)) + { + BigInt *result = bigint_new(); + mpz_sub_ui(result->num, bigint1->num, (int64_t) hk_as_number(val2)); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } hk_state_check_argument_userdata(state, args, 2); hk_return_if_not_ok(state); - BigInt *bigint1 = (BigInt *) hk_as_userdata(args[1]); - BigInt *bigint2 = (BigInt *) hk_as_userdata(args[2]); - mpz_t result; - mpz_init(result); - mpz_sub(result, bigint1->num, bigint2->num); - BigInt *bigint = bigint_new(result); - hk_state_push_userdata(state, (HkUserdata *) bigint); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + BigInt *result = bigint_new(); + mpz_sub(result->num, bigint1->num, bigint2->num); + hk_state_push_userdata(state, (HkUserdata *) result); } static void mul_call(HkState *state, HkValue *args) { hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2)) + { + BigInt *result = bigint_new(); + mpz_mul_ui(result->num, bigint1->num, (int64_t) hk_as_number(val2)); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } hk_state_check_argument_userdata(state, args, 2); hk_return_if_not_ok(state); - BigInt *bigint1 = (BigInt *) hk_as_userdata(args[1]); - BigInt *bigint2 = (BigInt *) hk_as_userdata(args[2]); - mpz_t result; - mpz_init(result); - mpz_mul(result, bigint1->num, bigint2->num); - BigInt *bigint = bigint_new(result); - hk_state_push_userdata(state, (HkUserdata *) bigint); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + BigInt *result = bigint_new(); + mpz_mul(result->num, bigint1->num, bigint2->num); + hk_state_push_userdata(state, (HkUserdata *) result); } static void div_call(HkState *state, HkValue *args) { hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2)) + { + BigInt *result = bigint_new(); + mpz_div_ui(result->num, bigint1->num, (int64_t) hk_as_number(val2)); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } hk_state_check_argument_userdata(state, args, 2); hk_return_if_not_ok(state); - BigInt *bigint1 = (BigInt *) hk_as_userdata(args[1]); - BigInt *bigint2 = (BigInt *) hk_as_userdata(args[2]); - mpz_t result; - mpz_init(result); - mpz_tdiv_q(result, bigint1->num, bigint2->num); - BigInt *bigint = bigint_new(result); - hk_state_push_userdata(state, (HkUserdata *) bigint); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + BigInt *result = bigint_new(); + mpz_div(result->num, bigint1->num, bigint2->num); + hk_state_push_userdata(state, (HkUserdata *) result); } static void mod_call(HkState *state, HkValue *args) { hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2)) + { + BigInt *result = bigint_new(); + mpz_mod_ui(result->num, bigint1->num, (int64_t) hk_as_number(val2)); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } hk_state_check_argument_userdata(state, args, 2); hk_return_if_not_ok(state); - BigInt *bigint1 = (BigInt *) hk_as_userdata(args[1]); - BigInt *bigint2 = (BigInt *) hk_as_userdata(args[2]); - mpz_t result; - mpz_init(result); - mpz_tdiv_r(result, bigint1->num, bigint2->num); - BigInt *bigint = bigint_new(result); - hk_state_push_userdata(state, (HkUserdata *) bigint); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + BigInt *result = bigint_new(); + mpz_mod(result->num, bigint1->num, bigint2->num); + hk_state_push_userdata(state, (HkUserdata *) result); } static void pow_call(HkState *state, HkValue *args) { hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2)) + { + BigInt *result = bigint_new(); + mpz_pow_ui(result->num, bigint1->num, (int64_t) hk_as_number(val2)); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } hk_state_check_argument_userdata(state, args, 2); hk_return_if_not_ok(state); - BigInt *bigint1 = (BigInt *) hk_as_userdata(args[1]); - BigInt *bigint2 = (BigInt *) hk_as_userdata(args[2]); - mpz_t result; - mpz_init(result); - mpz_pow_ui(result, bigint1->num, mpz_get_ui(bigint2->num)); - BigInt *bigint = bigint_new(result); - hk_state_push_userdata(state, (HkUserdata *) bigint); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + BigInt *result = bigint_new(); + mpz_powm(result->num, bigint1->num, bigint2->num, bigint2->num); + hk_state_push_userdata(state, (HkUserdata *) result); +} + +static void powm_call(HkState *state, HkValue *args) +{ + hk_state_check_argument_userdata(state, args, 1); + hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + HkValue val3 = args[3]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2) && hk_is_int(val3)) + { + mpz_t num2; + mpz_init_set_ui(num2, (int64_t) hk_as_number(val2)); + mpz_t num3; + mpz_init_set_ui(num3, (int64_t) hk_as_number(val3)); + BigInt *result = bigint_new(); + mpz_powm(result->num, bigint1->num, num2, num3); + mpz_clear(num2); + mpz_clear(num3); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } + if (hk_is_int(val2) && hk_is_userdata(val3)) + { + mpz_t num2; + mpz_init_set_ui(num2, (int64_t) hk_as_number(val2)); + BigInt *bigint3 = (BigInt *) hk_as_userdata(val3); + BigInt *result = bigint_new(); + mpz_powm(result->num, bigint1->num, num2, bigint3->num); + mpz_clear(num2); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } + if (hk_is_userdata(val2) && hk_is_int(val3)) + { + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + mpz_t num3; + mpz_init_set_ui(num3, (int64_t) hk_as_number(val3)); + BigInt *result = bigint_new(); + mpz_powm(result->num, bigint1->num, bigint2->num, num3); + mpz_clear(num3); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } + hk_state_check_argument_userdata(state, args, 2); + hk_return_if_not_ok(state); + hk_state_check_argument_userdata(state, args, 3); + hk_return_if_not_ok(state); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + BigInt *bigint3 = (BigInt *) hk_as_userdata(val3); + BigInt *result = bigint_new(); + mpz_powm(result->num, bigint1->num, bigint2->num, bigint3->num); + hk_state_push_userdata(state, (HkUserdata *) result); } static void sqrt_call(HkState *state, HkValue *args) @@ -174,11 +354,34 @@ static void sqrt_call(HkState *state, HkValue *args) hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); BigInt *bigint = (BigInt *) hk_as_userdata(args[1]); - mpz_t result; - mpz_init(result); - mpz_sqrt(result, bigint->num); - BigInt *bigint2 = bigint_new(result); - hk_state_push_userdata(state, (HkUserdata *) bigint2); + BigInt *result = bigint_new(); + mpz_sqrt(result->num, bigint->num); + hk_state_push_userdata(state, (HkUserdata *) result); +} + +static void sqrtm_prime_call(HkState *state, HkValue *args) +{ + hk_state_check_argument_userdata(state, args, 1); + hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2)) + { + mpz_t num2; + mpz_init_set_ui(num2, (int64_t) hk_as_number(val2)); + BigInt *result = bigint_new(); + sqrtm_prime(result->num, bigint1->num, num2); + mpz_clear(num2); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } + hk_state_check_argument_userdata(state, args, 2); + hk_return_if_not_ok(state); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + BigInt *result = bigint_new(); + sqrtm_prime(result->num, bigint1->num, bigint2->num); + hk_state_push_userdata(state, (HkUserdata *) result); } static void neg_call(HkState *state, HkValue *args) @@ -186,11 +389,9 @@ static void neg_call(HkState *state, HkValue *args) hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); BigInt *bigint = (BigInt *) hk_as_userdata(args[1]); - mpz_t result; - mpz_init(result); - mpz_neg(result, bigint->num); - BigInt *bigint2 = bigint_new(result); - hk_state_push_userdata(state, (HkUserdata *) bigint2); + BigInt *result = bigint_new(); + mpz_neg(result->num, bigint->num); + hk_state_push_userdata(state, (HkUserdata *) result); } static void abs_call(HkState *state, HkValue *args) @@ -198,40 +399,123 @@ static void abs_call(HkState *state, HkValue *args) hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); BigInt *bigint = (BigInt *) hk_as_userdata(args[1]); - mpz_t result; - mpz_init(result); - mpz_abs(result, bigint->num); - BigInt *bigint2 = bigint_new(result); - hk_state_push_userdata(state, (HkUserdata *) bigint2); + BigInt *result = bigint_new(); + mpz_abs(result->num, bigint->num); + hk_state_push_userdata(state, (HkUserdata *) result); } static void compare_call(HkState *state, HkValue *args) { hk_state_check_argument_userdata(state, args, 1); hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2)) + { + int result = mpz_cmp_ui(bigint1->num, (int64_t) hk_as_number(val2)); + hk_state_push_number(state, result); + return; + } hk_state_check_argument_userdata(state, args, 2); hk_return_if_not_ok(state); - BigInt *bigint1 = (BigInt *) hk_as_userdata(args[1]); - BigInt *bigint2 = (BigInt *) hk_as_userdata(args[2]); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); int result = mpz_cmp(bigint1->num, bigint2->num); hk_state_push_number(state, result); } +static void invertm_call(HkState *state, HkValue *args) +{ + hk_state_check_argument_userdata(state, args, 1); + hk_return_if_not_ok(state); + HkValue val1 = args[1]; + HkValue val2 = args[2]; + BigInt *bigint1 = (BigInt *) hk_as_userdata(val1); + if (hk_is_int(val2)) + { + mpz_t num2; + mpz_init_set_ui(num2, (int64_t) hk_as_number(val2)); + BigInt *result = bigint_new(); + int rc = mpz_invert(result->num, bigint1->num, num2); + if (!rc) + { + mpz_clear(num2); + hk_state_push_nil(state); + return; + } + mpz_clear(num2); + hk_state_push_userdata(state, (HkUserdata *) result); + return; + } + hk_state_check_argument_userdata(state, args, 2); + hk_return_if_not_ok(state); + BigInt *bigint2 = (BigInt *) hk_as_userdata(val2); + BigInt *result = bigint_new(); + int rc = mpz_invert(result->num, bigint1->num, bigint2->num); + if (!rc) + { + hk_state_push_nil(state); + return; + } + hk_state_push_userdata(state, (HkUserdata *) result); +} + +static void size_call(HkState *state, HkValue *args) +{ + hk_state_check_argument_userdata(state, args, 1); + hk_return_if_not_ok(state); + BigInt *bigint = (BigInt *) hk_as_userdata(args[1]); + HkValue val = args[2]; + int base = 10; + if (!hk_is_nil(val)) + { + hk_state_check_argument_int(state, args, 2); + hk_return_if_not_ok(state); + base = (int) hk_as_number(val); + } + int result = mpz_sizeinbase(bigint->num, base); + hk_state_push_number(state, result); +} + +static void testbit_call(HkState *state, HkValue *args) +{ + hk_state_check_argument_userdata(state, args, 1); + hk_return_if_not_ok(state); + hk_state_check_argument_int(state, args, 2); + hk_return_if_not_ok(state); + BigInt *bigint = (BigInt *) hk_as_userdata(args[1]); + int index = (int) hk_as_number(args[2]); + int result = mpz_tstbit(bigint->num, index); + hk_state_push_number(state, result); +} + HK_LOAD_MODULE_HANDLER(bigint) { hk_state_push_string_from_chars(state, -1, "bigint"); hk_return_if_not_ok(state); hk_state_push_string_from_chars(state, -1, "new"); hk_return_if_not_ok(state); - hk_state_push_new_native(state, "new", 1, new_call); + hk_state_push_new_native(state, "new", 2, new_call); hk_return_if_not_ok(state); - hk_state_push_string_from_chars(state, -1, "from_hex"); + hk_state_push_string_from_chars(state, -1, "from_string"); hk_return_if_not_ok(state); - hk_state_push_new_native(state, "from_hex", 1, from_hex_call); + hk_state_push_new_native(state, "from_string", 2, from_string_call); hk_return_if_not_ok(state); hk_state_push_string_from_chars(state, -1, "to_string"); hk_return_if_not_ok(state); - hk_state_push_new_native(state, "to_string", 1, to_string_call); + hk_state_push_new_native(state, "to_string", 2, to_string_call); + hk_return_if_not_ok(state); + hk_state_push_string_from_chars(state, -1, "from_bytes"); + hk_return_if_not_ok(state); + hk_state_push_new_native(state, "from_bytes", 1, from_bytes_call); + hk_return_if_not_ok(state); + hk_state_push_string_from_chars(state, -1, "to_bytes"); + hk_return_if_not_ok(state); + hk_state_push_new_native(state, "to_bytes", 1, to_bytes_call); + hk_return_if_not_ok(state); + hk_state_push_string_from_chars(state, -1, "sign"); + hk_return_if_not_ok(state); + hk_state_push_new_native(state, "sign", 1, sign_call); hk_return_if_not_ok(state); hk_state_push_string_from_chars(state, -1, "add"); hk_return_if_not_ok(state); @@ -257,10 +541,18 @@ HK_LOAD_MODULE_HANDLER(bigint) hk_return_if_not_ok(state); hk_state_push_new_native(state, "pow", 2, pow_call); hk_return_if_not_ok(state); + hk_state_push_string_from_chars(state, -1, "powm"); + hk_return_if_not_ok(state); + hk_state_push_new_native(state, "powm", 3, powm_call); + hk_return_if_not_ok(state); hk_state_push_string_from_chars(state, -1, "sqrt"); hk_return_if_not_ok(state); hk_state_push_new_native(state, "sqrt", 1, sqrt_call); hk_return_if_not_ok(state); + hk_state_push_string_from_chars(state, -1, "sqrtm_prime"); + hk_return_if_not_ok(state); + hk_state_push_new_native(state, "sqrtm_prime", 2, sqrtm_prime_call); + hk_return_if_not_ok(state); hk_state_push_string_from_chars(state, -1, "neg"); hk_return_if_not_ok(state); hk_state_push_new_native(state, "neg", 1, neg_call); @@ -273,5 +565,17 @@ HK_LOAD_MODULE_HANDLER(bigint) hk_return_if_not_ok(state); hk_state_push_new_native(state, "compare", 2, compare_call); hk_return_if_not_ok(state); - hk_state_construct(state, 13); + hk_state_push_string_from_chars(state, -1, "invertm"); + hk_return_if_not_ok(state); + hk_state_push_new_native(state, "invertm", 2, invertm_call); + hk_return_if_not_ok(state); + hk_state_push_string_from_chars(state, -1, "size"); + hk_return_if_not_ok(state); + hk_state_push_new_native(state, "size", 2, size_call); + hk_return_if_not_ok(state); + hk_state_push_string_from_chars(state, -1, "testbit"); + hk_return_if_not_ok(state); + hk_state_push_new_native(state, "testbit", 2, testbit_call); + hk_return_if_not_ok(state); + hk_state_construct(state, 21); }