diff --git a/.cargo/config.toml b/.cargo/config.toml index 2024a54..44936b8 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,5 +1,5 @@ [env] RUST_LOG = "info" CAIRO_PATH = "cairo" -BOOTLOADER_PATH = "bootloader/recursive_with_poseidon/simple_bootloader.cairo" +BOOTLOADER_PATH = "bootloader/starknet/simple_bootloader.cairo" BOOTLOADER_OUT_NAME = "bootloader.json" \ No newline at end of file diff --git a/cairo/bootloader/recursive_with_poseidon/__init__.py b/cairo/bootloader/starknet/__init__.py similarity index 100% rename from cairo/bootloader/recursive_with_poseidon/__init__.py rename to cairo/bootloader/starknet/__init__.py diff --git a/cairo/bootloader/recursive_with_poseidon/builtins.py b/cairo/bootloader/starknet/builtins.py similarity index 83% rename from cairo/bootloader/recursive_with_poseidon/builtins.py rename to cairo/bootloader/starknet/builtins.py index 78f19ac..1bb01a3 100644 --- a/cairo/bootloader/recursive_with_poseidon/builtins.py +++ b/cairo/bootloader/starknet/builtins.py @@ -5,7 +5,9 @@ OUTPUT_BUILTIN, PEDERSEN_BUILTIN, RANGE_CHECK_BUILTIN, + ECDSA_BUILTIN, BITWISE_BUILTIN, + EC_OP_BUILTIN, POSEIDON_BUILTIN, ] ) diff --git a/cairo/bootloader/recursive_with_poseidon/execute_task.cairo b/cairo/bootloader/starknet/execute_task.cairo similarity index 92% rename from cairo/bootloader/recursive_with_poseidon/execute_task.cairo rename to cairo/bootloader/starknet/execute_task.cairo index 06bf2bd..ed3ff33 100644 --- a/cairo/bootloader/recursive_with_poseidon/execute_task.cairo +++ b/cairo/bootloader/starknet/execute_task.cairo @@ -2,9 +2,11 @@ from builtin_selection.inner_select_builtins import inner_select_builtins from builtin_selection.select_input_builtins import select_input_builtins from builtin_selection.validate_builtins import validate_builtins from common.builtin_poseidon.poseidon import PoseidonBuiltin, poseidon_hash_many -from common.cairo_builtins import HashBuiltin +from common.cairo_builtins import HashBuiltin, EcOpBuiltin from common.hash_chain import hash_chain +from common.bool import TRUE from common.registers import get_ap, get_fp_and_pc +from common.signature import check_ecdsa_signature const BOOTLOADER_VERSION = 0; @@ -28,7 +30,9 @@ struct BuiltinData { output: felt, pedersen: felt, range_check: felt, + ecdsa: felt, bitwise: felt, + ec_op: felt, poseidon: felt, } @@ -109,6 +113,21 @@ func execute_task{builtin_ptrs: BuiltinData*, self_range_check_ptr}( use_poseidon=bool(ids.use_poseidon)), 'Computed hash does not match input.' %} + local public_key: felt; + local signature_r: felt; + local signature_s: felt; + %{ + ids.public_key = simple_bootloader_input.job.public_key + ids.signature_r = simple_bootloader_input.job.signature_r + ids.signature_s = simple_bootloader_input.job.signature_s + %} + + let ec_op_ptr = cast(input_builtin_ptrs.ec_op, EcOpBuiltin*); + with ec_op_ptr { + let (res) = check_ecdsa_signature(message=hash, public_key=public_key, signature_r=signature_r, signature_s=signature_s); + assert res = TRUE; + } + // Set the program entry point, so the bootloader can later run the program. local builtin_list: felt* = &program_header.builtin_list; local n_builtins = program_header.n_builtins; @@ -127,7 +146,9 @@ func execute_task{builtin_ptrs: BuiltinData*, self_range_check_ptr}( output=output_ptr + 2, pedersen=cast(pedersen_ptr, felt), range_check=input_builtin_ptrs.range_check, + ecdsa=input_builtin_ptrs.ecdsa, bitwise=input_builtin_ptrs.bitwise, + ec_op=cast(ec_op_ptr, felt), poseidon=cast(poseidon_ptr, felt), ); diff --git a/cairo/bootloader/recursive_with_poseidon/run_simple_bootloader.cairo b/cairo/bootloader/starknet/run_simple_bootloader.cairo similarity index 94% rename from cairo/bootloader/recursive_with_poseidon/run_simple_bootloader.cairo rename to cairo/bootloader/starknet/run_simple_bootloader.cairo index 17ada90..efa2c24 100644 --- a/cairo/bootloader/recursive_with_poseidon/run_simple_bootloader.cairo +++ b/cairo/bootloader/starknet/run_simple_bootloader.cairo @@ -1,4 +1,4 @@ -from bootloader.recursive_with_poseidon.execute_task import BuiltinData, execute_task +from bootloader.starknet.execute_task import BuiltinData, execute_task from common.cairo_builtins import HashBuiltin, PoseidonBuiltin from common.registers import get_fp_and_pc @@ -14,7 +14,9 @@ func run_simple_bootloader{ output_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, + ecdsa_ptr, bitwise_ptr, + ec_op_ptr, poseidon_ptr: PoseidonBuiltin*, }() { alloc_locals; @@ -41,7 +43,9 @@ func run_simple_bootloader{ output=cast(output_ptr, felt), pedersen=cast(pedersen_ptr, felt), range_check=task_range_check_ptr, + ecdsa=ecdsa_ptr, bitwise=bitwise_ptr, + ec_op=ec_op_ptr, poseidon=cast(poseidon_ptr, felt), ); @@ -50,7 +54,9 @@ func run_simple_bootloader{ output='output', pedersen='pedersen', range_check='range_check', + ecdsa='ecdsa', bitwise='bitwise', + ec_op='ec_op', poseidon='poseidon', ); @@ -58,7 +64,9 @@ func run_simple_bootloader{ output=1, pedersen=3, range_check=1, + ecdsa=2, bitwise=5, + ec_op=7, poseidon=6, ); @@ -83,7 +91,9 @@ func run_simple_bootloader{ let output_ptr = cast(builtin_ptrs.output, felt*); let pedersen_ptr = cast(builtin_ptrs.pedersen, HashBuiltin*); let range_check_ptr = builtin_ptrs.range_check; + let ecdsa_ptr = builtin_ptrs.ecdsa; let bitwise_ptr = builtin_ptrs.bitwise; + let ec_op_ptr = builtin_ptrs.ec_op; let poseidon_ptr = cast(builtin_ptrs.poseidon, PoseidonBuiltin*); // 'execute_tasks' runs untrusted code and uses the range_check builtin to verify that diff --git a/cairo/bootloader/recursive_with_poseidon/simple_bootloader.cairo b/cairo/bootloader/starknet/simple_bootloader.cairo similarity index 90% rename from cairo/bootloader/recursive_with_poseidon/simple_bootloader.cairo rename to cairo/bootloader/starknet/simple_bootloader.cairo index efec7e5..f462a16 100644 --- a/cairo/bootloader/recursive_with_poseidon/simple_bootloader.cairo +++ b/cairo/bootloader/starknet/simple_bootloader.cairo @@ -1,6 +1,6 @@ -%builtins output pedersen range_check bitwise poseidon +%builtins output pedersen range_check ecdsa bitwise ec_op poseidon -from bootloader.recursive_with_poseidon.run_simple_bootloader import ( +from bootloader.starknet.run_simple_bootloader import ( run_simple_bootloader, ) from common.cairo_builtins import HashBuiltin, PoseidonBuiltin @@ -10,7 +10,9 @@ func main{ output_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, + ecdsa_ptr, bitwise_ptr, + ec_op_ptr, poseidon_ptr: PoseidonBuiltin*, }() { %{ diff --git a/cairo/common/bool.cairo b/cairo/common/bool.cairo new file mode 100644 index 0000000..c0ae06c --- /dev/null +++ b/cairo/common/bool.cairo @@ -0,0 +1,3 @@ +// Represents boolean values in Cairo. +const FALSE = 0; +const TRUE = 1; diff --git a/cairo/common/cairo_builtins.cairo b/cairo/common/cairo_builtins.cairo index baf3128..c3f47ee 100644 --- a/cairo/common/cairo_builtins.cairo +++ b/cairo/common/cairo_builtins.cairo @@ -1,6 +1,6 @@ -from starkware.cairo.common.ec_point import EcPoint -from starkware.cairo.common.keccak_state import KeccakBuiltinState -from starkware.cairo.common.poseidon_state import PoseidonBuiltinState +from common.ec_point import EcPoint +from common.keccak_state import KeccakBuiltinState +from common.poseidon_state import PoseidonBuiltinState // Specifies the hash builtin memory structure. struct HashBuiltin { diff --git a/cairo/common/ec.cairo b/cairo/common/ec.cairo new file mode 100644 index 0000000..411e972 --- /dev/null +++ b/cairo/common/ec.cairo @@ -0,0 +1,286 @@ +// Functions for various actions on the STARK curve: +// y^2 = x^3 + alpha * x + beta +// where alpha = 1 and beta = 0x6f21413efbe40de150e596d72f7a8c5609ad26c15c915c1f4cdfcb99cee9e89. +// The point at infinity is represented as (0, 0). + +from common.cairo_builtins import EcOpBuiltin +from common.ec_point import EcPoint +from common.math import is_quad_residue + +namespace StarkCurve { + const ALPHA = 1; + const BETA = 0x6f21413efbe40de150e596d72f7a8c5609ad26c15c915c1f4cdfcb99cee9e89; + const ORDER = 0x800000000000010ffffffffffffffffb781126dcae7b2321e66a241adc64d2f; + const GEN_X = 0x1ef15c18599971b7beced415a40f0c7deacfd9b0d1819e03d723d8bc943cfca; + const GEN_Y = 0x5668060aa49730b7be4801df46ec62de53ecd11abe43a32873000c36e8dc1f; +} + +// Asserts that an EC point is on the STARK curve. +// +// Arguments: +// p - an EC point. +func assert_on_curve(p: EcPoint) { + // Because the curve order is odd, there is no point (except (0, 0), which represents the point + // at infinity) with y = 0. + if (p.y == 0) { + assert p.x = 0; + return (); + } + tempvar rhs = (p.x * p.x + StarkCurve.ALPHA) * p.x + StarkCurve.BETA; + assert p.y * p.y = rhs; + return (); +} + +// Doubles a point (computes p + p) on the elliptic curve. +// +// Arguments: +// p - an EC point. +// +// Returns: +// r = p + p. +// +// Assumptions: +// p is a valid point on the curve. +func ec_double(p: EcPoint) -> (r: EcPoint) { + // (0, 0), which represents the point at infinity, is the only point with y = 0. + if (p.y == 0) { + return (r=p); + } + tempvar slope = (3 * p.x * p.x + StarkCurve.ALPHA) / (2 * p.y); + tempvar r_x = slope * slope - p.x - p.x; + return (r=EcPoint(x=r_x, y=slope * (p.x - r_x) - p.y)); +} + +// Adds two points on the EC. +// +// Arguments: +// p - an EC point. +// q - an EC point. +// +// Returns: +// r = p + q. +// +// Assumptions: +// p and q are valid points on the curve. +func ec_add(p: EcPoint, q: EcPoint) -> (r: EcPoint) { + // (0, 0), which represents the point at infinity, is the only point with y = 0. + if (p.y == 0) { + return (r=q); + } + if (q.y == 0) { + return (r=p); + } + if (p.x == q.x) { + if (p.y == q.y) { + return ec_double(p); + } + // In this case, because p and q are on the curve, p.y = -q.y. + return (r=EcPoint(x=0, y=0)); + } + tempvar slope = (p.y - q.y) / (p.x - q.x); + tempvar r_x = slope * slope - p.x - q.x; + return (r=EcPoint(x=r_x, y=slope * (p.x - r_x) - p.y)); +} + +// Subtracts a point from another on the EC. +// +// Arguments: +// p - an EC point. +// q - an EC point. +// +// Returns: +// r = p - q. +// +// Assumptions: +// p and q are valid points on the curve. +func ec_sub(p: EcPoint, q: EcPoint) -> (r: EcPoint) { + return ec_add(p=p, q=EcPoint(x=q.x, y=-q.y)); +} + +// Computes p + m * q on the elliptic curve. +// Because the EC operation builtin cannot handle inputs where additions of points with the same x +// coordinate arise during the computation, this function adds and subtracts a nondeterministic +// point s to the computation, so that regardless of input, the probability that such additions +// arise will become negligibly small. +// The precise computation is therefore: +// ((p + s) + m * q) - s +// so that the inputs to the builtin itself are (p + s), m, and q. +// +// Arguments: +// ec_op_ptr - the ec_op builtin pointer. +// p - an EC point. +// m - the multiplication coefficient of Q. +// q - an EC point. +// +// Returns: +// r = p + m * q. +// +// Assumptions: +// p and q are valid points on the curve. +func ec_op{ec_op_ptr: EcOpBuiltin*}(p: EcPoint, m: felt, q: EcPoint) -> (r: EcPoint) { + alloc_locals; + + // (0, 0), which represents the point at infinity, is the only point with y = 0. + if (q.y == 0) { + return (r=p); + } + + local s: EcPoint; + %{ + from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME + from starkware.python.math_utils import random_ec_point + from starkware.python.utils import to_bytes + + # Define a seed for random_ec_point that's dependent on all the input, so that: + # (1) The added point s is deterministic. + # (2) It's hard to choose inputs for which the builtin will fail. + seed = b"".join(map(to_bytes, [ids.p.x, ids.p.y, ids.m, ids.q.x, ids.q.y])) + ids.s.x, ids.s.y = random_ec_point(FIELD_PRIME, ALPHA, BETA, seed) + %} + let p_plus_s: EcPoint = ec_add(p, s); + + assert ec_op_ptr.p = p_plus_s; + assert ec_op_ptr.m = m; + assert ec_op_ptr.q = q; + let r: EcPoint = ec_add(ec_op_ptr.r, EcPoint(x=s.x, y=-s.y)); + let ec_op_ptr = ec_op_ptr + EcOpBuiltin.SIZE; + return (r=r); +} + +// Computes m * p on the elliptic curve. +// +// Arguments: +// ec_op_ptr - the ec_op builtin pointer. +// m - the multiplication coefficient of p. +// p - an EC point. +// +// Returns: +// r = m * p. +// +// Assumptions: +// p is a valid point on the curve. +func ec_mul{ec_op_ptr: EcOpBuiltin*}(m: felt, p: EcPoint) -> (r: EcPoint) { + return ec_op(p=EcPoint(x=0, y=0), m=m, q=p); +} + +// Computes p + m[0] * q[0] + m[1] * q[1] + ... m[len - 1] * q[len - 1] on the elliptic curve. +// Because the EC operation builtin cannot handle inputs where additions of points with the same x +// coordinate arise during the computation, this function adds and removes a nondeterministic +// point s to the computation, so that regardless of input, the probability that such additions +// arise will become negligibly small. +// The precise computation is therefore: +// ((p + s) + m[0] * q[0] + m[1] + q[1] + ... + m[len - 1] * q[len - 1]) - s. +// +// Arguments: +// ec_op_ptr - the ec_op builtin pointer. +// p - an EC point. +// m - an array of multiplication coefficients. +// q - an array of EC points. +// len - the number of points in q. +// +// Returns: +// r = p + m[0] * q[0] + m[1] * q[1] + ... + m[len - 1] * q[len - 1]. +// Assumptions: +// * All given EC points are on the STARK curve. +// * len <= 1000. +func chained_ec_op{ec_op_ptr: EcOpBuiltin*}(p: EcPoint, m: felt*, q: EcPoint*, len: felt) -> ( + r: EcPoint +) { + alloc_locals; + local s: EcPoint; + %{ + from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME + from starkware.python.math_utils import random_ec_point + from starkware.python.utils import to_bytes + + n_elms = ids.len + assert isinstance(n_elms, int) and n_elms >= 0, \ + f'Invalid value for len. Got: {n_elms}.' + if '__chained_ec_op_max_len' in globals(): + assert n_elms <= __chained_ec_op_max_len, \ + f'chained_ec_op() can only be used with len<={__chained_ec_op_max_len}. ' \ + f'Got: n_elms={n_elms}.' + + # Define a seed for random_ec_point that's dependent on all the input, so that: + # (1) The added point s is deterministic. + # (2) It's hard to choose inputs for which the builtin will fail. + seed = b"".join( + map( + to_bytes, + [ + ids.p.x, + ids.p.y, + *memory.get_range(ids.m, n_elms), + *memory.get_range(ids.q.address_, 2 * n_elms), + ], + ) + ) + ids.s.x, ids.s.y = random_ec_point(FIELD_PRIME, ALPHA, BETA, seed) + %} + let p_plus_s: EcPoint = ec_add(p, s); + let r_plus_s: EcPoint = _chained_ec_op_inner(p=p_plus_s, m=m, q=q, len=len); + let r: EcPoint = ec_add(r_plus_s, EcPoint(x=s.x, y=-s.y)); + return (r=r); +} + +func _chained_ec_op_inner{ec_op_ptr: EcOpBuiltin*}( + p: EcPoint, m: felt*, q: EcPoint*, len: felt +) -> (r: EcPoint) { + if (len == 0) { + return (r=p); + } + // (0, 0), representing the point at infinity, is the only point for which y = 0. + if (q.y == 0) { + return _chained_ec_op_inner(p=p, m=&m[1], q=&q[1], len=len - 1); + } + assert ec_op_ptr.p = p; + assert ec_op_ptr.m = m[0]; + assert ec_op_ptr.q = q[0]; + let r = ec_op_ptr.r; + let ec_op_ptr = &ec_op_ptr[1]; + return _chained_ec_op_inner(p=r, m=&m[1], q=&q[1], len=len - 1); +} + +// Recovers the y coordinate of a point on the EC. +// +// Arguments: +// x - the x coordinate of an EC point. +// +// Returns: +// p - one of the two EC points with the given x coordinate (x, y). +// +// Assumptions: +// There exists y such that (x, y) is on the curve. Otherwise the function's hint will throw a +// python exception. +// +// Note: +// This function will fail on x = 0 because there is no such point on the curve. The point at +// infinity is represented as (0, 0), but this is just a representation, not actual coordinates. +func recover_y(x: felt) -> (p: EcPoint) { + alloc_locals; + local p: EcPoint; + %{ + from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME + from starkware.python.math_utils import recover_y + ids.p.x = ids.x + # This raises an exception if `x` is not on the curve. + ids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME) + %} + assert p.x = x; + assert_on_curve(p); + return (p=p); +} + +// Checks if `x` represents the x coordinate of a point on the curve. +// +// Arguments: +// x - a field element. +// +// Returns: +// res - TRUE if `x` represents the x coordinate of a point on the curve, FALSE otherwise. +// Note: +// Returns FALSE on x = 0 because there is no such point on the curve. The point at +// infinity is represented as (0, 0), but this is just a representation, not actual coordinates. +func is_x_on_curve(x: felt) -> felt { + return is_quad_residue(x=x * x * x + StarkCurve.ALPHA * x + StarkCurve.BETA); +} diff --git a/cairo/common/math.cairo b/cairo/common/math.cairo new file mode 100644 index 0000000..d1b53a4 --- /dev/null +++ b/cairo/common/math.cairo @@ -0,0 +1,488 @@ +from common.bool import FALSE, TRUE + +// Inline functions with no locals. + +// Verifies that value != 0. The proof will fail otherwise. +func assert_not_zero(value) { + %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.value) + assert ids.value % PRIME != 0, f'assert_not_zero failed: {ids.value} = 0.' + %} + if (value == 0) { + // If value == 0, add an unsatisfiable requirement. + value = 1; + } + + return (); +} + +// Verifies that a != b. The proof will fail otherwise. +func assert_not_equal(a, b) { + %{ + from starkware.cairo.lang.vm.relocatable import RelocatableValue + both_ints = isinstance(ids.a, int) and isinstance(ids.b, int) + both_relocatable = ( + isinstance(ids.a, RelocatableValue) and isinstance(ids.b, RelocatableValue) and + ids.a.segment_index == ids.b.segment_index) + assert both_ints or both_relocatable, \ + f'assert_not_equal failed: non-comparable values: {ids.a}, {ids.b}.' + assert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.' + %} + if (a == b) { + // If a == b, add an unsatisfiable requirement. + a = a + 1; + } + + return (); +} + +// Verifies that a >= 0 (or more precisely 0 <= a < RANGE_CHECK_BOUND). +func assert_nn{range_check_ptr}(a) { + %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.a) + assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.' + %} + a = [range_check_ptr]; + let range_check_ptr = range_check_ptr + 1; + return (); +} + +// Verifies that a <= b (or more precisely 0 <= b - a < RANGE_CHECK_BOUND). +func assert_le{range_check_ptr}(a, b) { + assert_nn(b - a); + return (); +} + +// Verifies that a <= b - 1 (or more precisely 0 <= b - 1 - a < RANGE_CHECK_BOUND). +func assert_lt{range_check_ptr}(a, b) { + assert_le(a, b - 1); + return (); +} + +// Verifies that 0 <= a <= b. +// +// Prover assumption: b < RANGE_CHECK_BOUND. +// +// This function is still sound without the prover assumptions. In that case, it is guaranteed +// that a < RANGE_CHECK_BOUND and b < 2 * RANGE_CHECK_BOUND. +func assert_nn_le{range_check_ptr}(a, b) { + assert_nn(a); + assert_le(a, b); + return (); +} + +// Asserts that value is in the range [lower, upper). +// Or more precisely: +// (0 <= value - lower < RANGE_CHECK_BOUND) and (0 <= upper - 1 - value < RANGE_CHECK_BOUND). +// +// Prover assumption: 0 <= upper - lower <= RANGE_CHECK_BOUND. +func assert_in_range{range_check_ptr}(value, lower, upper) { + assert_le(lower, value); + assert_le(value, upper - 1); + return (); +} + +// Asserts that 'value' is in the range [0, 2**250). +@known_ap_change +func assert_250_bit{range_check_ptr}(value) { + const UPPER_BOUND = 2 ** 250; + const SHIFT = 2 ** 128; + const HIGH_BOUND = UPPER_BOUND / SHIFT; + + let low = [range_check_ptr]; + let high = [range_check_ptr + 1]; + + %{ + from starkware.cairo.common.math_utils import as_int + + # Correctness check. + value = as_int(ids.value, PRIME) % PRIME + assert value < ids.UPPER_BOUND, f'{value} is outside of the range [0, 2**250).' + + # Calculation for the assertion. + ids.high, ids.low = divmod(ids.value, ids.SHIFT) + %} + + assert [range_check_ptr + 2] = HIGH_BOUND - 1 - high; + + // The assert below guarantees that + // value = high * SHIFT + low <= (HIGH_BOUND - 1) * SHIFT + 2**128 - 1 = + // HIGH_BOUND * SHIFT - SHIFT + SHIFT - 1 = 2**250 - 1. + assert value = high * SHIFT + low; + + let range_check_ptr = range_check_ptr + 3; + return (); +} + +// Splits the unsigned integer lift of a field element into the higher 128 bit and lower 128 bit. +// The unsigned integer lift is the unique integer in the range [0, PRIME) that represents the field +// element. +// For example, if value=17 * 2^128 + 8, then high=17 and low=8. +@known_ap_change +func split_felt{range_check_ptr}(value) -> (high: felt, low: felt) { + // Note: the following code works because PRIME - 1 is divisible by 2**128. + const MAX_HIGH = (-1) / 2 ** 128; + const MAX_LOW = 0; + + // Guess the low and high parts of the integer. + let low = [range_check_ptr]; + let high = [range_check_ptr + 1]; + let range_check_ptr = range_check_ptr + 2; + + %{ + from starkware.cairo.common.math_utils import assert_integer + assert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128 + assert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW + assert_integer(ids.value) + ids.low = ids.value & ((1 << 128) - 1) + ids.high = ids.value >> 128 + %} + assert value = high * (2 ** 128) + low; + if (high == MAX_HIGH) { + assert_le(low, MAX_LOW); + } else { + assert_le(high, MAX_HIGH - 1); + } + return (high=high, low=low); +} + +// Asserts that the unsigned integer lift (as a number in the range [0, PRIME)) of a is lower than +// or equal to that of b. +@known_ap_change +func assert_le_felt{range_check_ptr}(a, b) { + // ceil(PRIME / 3 / 2 ** 128). + const PRIME_OVER_3_HIGH = 0x2aaaaaaaaaaaab05555555555555556; + // ceil(PRIME / 2 / 2 ** 128). + const PRIME_OVER_2_HIGH = 0x4000000000000088000000000000001; + // The numbers [0, a, b, PRIME - 1] should be ordered. To prove that, we show that two of the + // 3 arcs {0 -> a, a -> b, b -> PRIME - 1} are small: + // One is less than PRIME / 3 + 2 ** 129. + // Another is less than PRIME / 2 + 2 ** 129. + // Since the sum of the lengths of these two arcs is less than PRIME, there is no wrap-around. + %{ + import itertools + + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.a) + assert_integer(ids.b) + a = ids.a % PRIME + b = ids.b % PRIME + assert a <= b, f'a = {a} is not less than or equal to b = {b}.' + + # Find an arc less than PRIME / 3, and another less than PRIME / 2. + lengths_and_indices = [(a, 0), (b - a, 1), (PRIME - 1 - b, 2)] + lengths_and_indices.sort() + assert lengths_and_indices[0][0] <= PRIME // 3 and lengths_and_indices[1][0] <= PRIME // 2 + excluded = lengths_and_indices[2][1] + + memory[ids.range_check_ptr + 1], memory[ids.range_check_ptr + 0] = ( + divmod(lengths_and_indices[0][0], ids.PRIME_OVER_3_HIGH)) + memory[ids.range_check_ptr + 3], memory[ids.range_check_ptr + 2] = ( + divmod(lengths_and_indices[1][0], ids.PRIME_OVER_2_HIGH)) + %} + // Guess two arc lengths. + tempvar arc_short = [range_check_ptr] + [range_check_ptr + 1] * PRIME_OVER_3_HIGH; + tempvar arc_long = [range_check_ptr + 2] + [range_check_ptr + 3] * PRIME_OVER_2_HIGH; + let range_check_ptr = range_check_ptr + 4; + + // First, choose which arc to exclude from {0 -> a, a -> b, b -> PRIME - 1}. + // Then, to compare the set of two arc lengths, compare their sum and product. + let arc_sum = arc_short + arc_long; + let arc_prod = arc_short * arc_long; + + // Exclude "0 -> a". + %{ memory[ap] = 1 if excluded != 0 else 0 %} + jmp skip_exclude_a if [ap] != 0, ap++; + assert arc_sum = (-1) - a; + assert arc_prod = (a - b) * (1 + b); + return (); + + // Exclude "a -> b". + skip_exclude_a: + %{ memory[ap] = 1 if excluded != 1 else 0 %} + jmp skip_exclude_b_minus_a if [ap] != 0, ap++; + tempvar m1mb = (-1) - b; + assert arc_sum = a + m1mb; + assert arc_prod = a * m1mb; + return (); + + // Exclude "b -> PRIME - 1". + skip_exclude_b_minus_a: + %{ assert excluded == 2 %} + assert arc_sum = b; + assert arc_prod = a * (b - a); + ap += 2; + return (); +} + +// Asserts that the unsigned integer lift (as a number in the range [0, PRIME)) of a is lower than +// that of b. +@known_ap_change +func assert_lt_felt{range_check_ptr}(a, b) { + %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.a) + assert_integer(ids.b) + assert (ids.a % PRIME) < (ids.b % PRIME), \ + f'a = {ids.a % PRIME} is not less than b = {ids.b % PRIME}.' + %} + if (a == b) { + // If a == b, add an unsatisfiable requirement. + a = a + 1; + } + assert_le_felt(a, b); + return (); +} + +// Returns the absolute value of value. +// Prover asumption: -rc_bound < value < rc_bound. +@known_ap_change +func abs_value{range_check_ptr}(value) -> felt { + tempvar is_positive: felt; + %{ + from starkware.cairo.common.math_utils import is_positive + ids.is_positive = 1 if is_positive( + value=ids.value, prime=PRIME, rc_bound=range_check_builtin.bound) else 0 + %} + if (is_positive == 0) { + tempvar new_range_check_ptr = range_check_ptr + 1; + tempvar abs_value = value * (-1); + [range_check_ptr] = abs_value; + let range_check_ptr = new_range_check_ptr; + return abs_value; + } else { + [range_check_ptr] = value; + let range_check_ptr = range_check_ptr + 1; + return value; + } +} + +// Returns the sign of value: -1, 0 or 1. +// Prover asumption: -rc_bound < value < rc_bound. +@known_ap_change +func sign{range_check_ptr}(value) -> felt { + if (value == 0) { + ap += 2; + return 0; + } + + tempvar is_positive: felt; + %{ + from starkware.cairo.common.math_utils import is_positive + ids.is_positive = 1 if is_positive( + value=ids.value, prime=PRIME, rc_bound=range_check_builtin.bound) else 0 + %} + if (is_positive == 0) { + assert [range_check_ptr] = value * (-1); + let range_check_ptr = range_check_ptr + 1; + return -1; + } else { + ap += 1; + [range_check_ptr] = value; + let range_check_ptr = range_check_ptr + 1; + return 1; + } +} + +// Returns q and r such that: +// 0 <= q < rc_bound, 0 <= r < div and value = q * div + r. +// +// Assumption: 0 < div <= PRIME / rc_bound. +// Prover assumption: value / div < rc_bound. +// +// The value of div is restricted to make sure there is no overflow. +// q * div + r < (q + 1) * div <= rc_bound * (PRIME / rc_bound) = PRIME. +func unsigned_div_rem{range_check_ptr}(value, div) -> (q: felt, r: felt) { + let r = [range_check_ptr]; + let q = [range_check_ptr + 1]; + let range_check_ptr = range_check_ptr + 2; + %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.div) + assert 0 < ids.div <= PRIME // range_check_builtin.bound, \ + f'div={hex(ids.div)} is out of the valid range.' + ids.q, ids.r = divmod(ids.value, ids.div) + %} + assert_le(r, div - 1); + + assert value = q * div + r; + return (q, r); +} + +// Returns q and r such that. -bound <= q < bound, 0 <= r < div and value = q * div + r. +// value < PRIME / 2 is considered positive and value > PRIME / 2 is considered negative. +// +// Assumptions: +// 0 < div <= PRIME / (rc_bound) +// bound <= rc_bound / 2. +// Prover assumption: -bound <= value / div < bound. +// +// The values of div and bound are restricted to make sure there is no overflow. +// q * div + r < (q + 1) * div <= rc_bound / 2 * (PRIME / rc_bound) +// q * div + r >= q * div >= -rc_bound / 2 * (PRIME / rc_bound). +func signed_div_rem{range_check_ptr}(value, div, bound) -> (q: felt, r: felt) { + let r = [range_check_ptr]; + let biased_q = [range_check_ptr + 1]; // == q + bound. + let range_check_ptr = range_check_ptr + 2; + %{ + from starkware.cairo.common.math_utils import as_int, assert_integer + + assert_integer(ids.div) + assert 0 < ids.div <= PRIME // range_check_builtin.bound, \ + f'div={hex(ids.div)} is out of the valid range.' + + assert_integer(ids.bound) + assert ids.bound <= range_check_builtin.bound // 2, \ + f'bound={hex(ids.bound)} is out of the valid range.' + + int_value = as_int(ids.value, PRIME) + q, ids.r = divmod(int_value, ids.div) + + assert -ids.bound <= q < ids.bound, \ + f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).' + + ids.biased_q = q + ids.bound + %} + let q = biased_q - bound; + assert value = q * div + r; + assert_le(r, div - 1); + assert_le(biased_q, 2 * bound - 1); + return (q, r); +} + +// Computes value / div as integers and fails if value is not divisible by div. +// Namely, verifies that 1 <= div < PRIME / rc_bound +// and returns q such that: +// 0 <= q < rc_bound and q = value / div. +func safe_div{range_check_ptr}(value: felt, div: felt) -> felt { + // floor(PRIME / 2 ** 128). + const PRIME_OVER_RC_BOUND = 0x8000000000000110000000000000000; + assert [range_check_ptr] = div - 1; + assert [range_check_ptr + 1] = div + (2 ** 128 - PRIME_OVER_RC_BOUND); + // Prepare the result at the end of the stack. + let q = [ap + 1]; + q = value / div; + tempvar range_check_ptr = range_check_ptr + 3; + [range_check_ptr - 1] = q, ap++; + static_assert &q + 1 == ap; + return q; +} + +// Computes first * second if there is no overflow. +// Namely, returns the product of first and second if: +// 0 <= first < rc_bound and 0 <= second < PRIME / rc_bound +// and fails otherwise. +func safe_mult{range_check_ptr}(first: felt, second: felt) -> felt { + // floor(PRIME / 2 ** 128). + const PRIME_OVER_RC_BOUND = 0x8000000000000110000000000000000; + assert [range_check_ptr] = first; + assert [range_check_ptr + 1] = second; + assert [range_check_ptr + 2] = second + (2 ** 128 - PRIME_OVER_RC_BOUND); + let range_check_ptr = range_check_ptr + 3; + return first * second; +} + +// Splits the given (unsigned) value into n "limbs", where each limb is in the range [0, bound), +// as follows: +// value = x[0] + x[1] * base + x[2] * base**2 + ... + x[n - 1] * base**(n - 1). +// bound must be less than the range check bound (2**128). +// Note that bound may be smaller than base, in which case the function will fail if there is a +// limb which is >= bound. +// Assumptions: +// 1 < bound <= base +// base**n < field characteristic. +func split_int{range_check_ptr}(value, n, base, bound, output: felt*) { + if (n == 0) { + %{ assert ids.value == 0, 'split_int(): value is out of range.' %} + assert value = 0; + return (); + } + + %{ + memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base + assert res < ids.bound, f'split_int(): Limb {res} is out of range.' + %} + tempvar low_part = [output]; + assert_nn_le(low_part, bound - 1); + + return split_int( + value=(value - low_part) / base, n=n - 1, base=base, bound=bound, output=output + 1 + ); +} + +// Returns the floor value of the square root of the given value. +// Assumptions: 0 <= value < 2**250. +@known_ap_change +func sqrt{range_check_ptr}(value) -> felt { + alloc_locals; + local root: felt; + + %{ + from starkware.python.math_utils import isqrt + value = ids.value % PRIME + assert value < 2 ** 250, f"value={value} is outside of the range [0, 2**250)." + assert 2 ** 250 < PRIME + ids.root = isqrt(value) + %} + + assert_nn_le(root, 2 ** 125 - 1); + tempvar root_plus_one = root + 1; + assert_in_range(value, root * root, root_plus_one * root_plus_one); + + return root; +} + +// Computes the evaluation of a polynomial on the given point. +func horner_eval(n_coefficients: felt, coefficients: felt*, point: felt) -> (res: felt) { + if (n_coefficients == 0) { + return (res=0); + } + + let (n_minus_one_res) = horner_eval( + n_coefficients=n_coefficients - 1, coefficients=&coefficients[1], point=point + ); + return (res=n_minus_one_res * point + coefficients[0]); +} + +// Returns TRUE if `x` is a quadratic residue modulo the STARK prime. Returns FALSE otherwise. +// Returns TRUE on 0. +@known_ap_change +func is_quad_residue(x: felt) -> felt { + alloc_locals; + local y; + %{ + from starkware.crypto.signature.signature import FIELD_PRIME + from starkware.python.math_utils import div_mod, is_quad_residue, sqrt + + x = ids.x + if is_quad_residue(x, FIELD_PRIME): + ids.y = sqrt(x, FIELD_PRIME) + else: + ids.y = sqrt(div_mod(x, 3, FIELD_PRIME), FIELD_PRIME) + %} + // Relies on the fact that 3 is not a quadratic residue modulo the prime, so for every field + // element x, either: + // * x is a quadratic residue and there exists y such that y^2 = x. + // * x is not a quadratic residue and there exists y such that 3 * y^2 = x. + tempvar y_squared = y * y; + if (y_squared == x) { + ap += 1; + return TRUE; + } else { + assert 3 * y_squared = x; + return FALSE; + } +} + +// Asserts that x = 2^n for some 0 <= n <= max_pow. +func assert_is_power_of_2(x: felt, max_pow: felt) { + if (max_pow == 0) { + assert x = 1; + } + if (x == 1) { + return (); + } + return assert_is_power_of_2(x=x / 2, max_pow=max_pow - 1); +} diff --git a/cairo/common/signature.cairo b/cairo/common/signature.cairo new file mode 100644 index 0000000..513f4e9 --- /dev/null +++ b/cairo/common/signature.cairo @@ -0,0 +1,83 @@ +from common.bool import FALSE, TRUE +from common.cairo_builtins import EcOpBuiltin, SignatureBuiltin +from common.ec import StarkCurve, ec_add, ec_mul, ec_sub, is_x_on_curve, recover_y +from common.ec_point import EcPoint + +// Verifies that the prover knows a signature of the given public_key on the given message. +// +// Prover assumption: (signature_r, signature_s) is a valid signature for the given public_key +// on the given message. +func verify_ecdsa_signature{ecdsa_ptr: SignatureBuiltin*}( + message, public_key, signature_r, signature_s +) { + %{ ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s)) %} + assert ecdsa_ptr.message = message; + assert ecdsa_ptr.pub_key = public_key; + + let ecdsa_ptr = ecdsa_ptr + SignatureBuiltin.SIZE; + return (); +} + +// Checks if (signature_r, signature_s) is a valid signature for the given public_key +// on the given message. +// +// Arguments: +// message - the signed message. +// public_key - the public key corresponding to the key with which the message was signed. +// signature_r - the r component of the ECDSA signature. +// signature_s - the s component of the ECDSA signature. +// +// Returns: +// res - TRUE if the signature is valid, FALSE otherwise. +func check_ecdsa_signature{ec_op_ptr: EcOpBuiltin*}( + message, public_key, signature_r, signature_s +) -> (res: felt) { + alloc_locals; + // Check that s != 0 (mod StarkCurve.ORDER). + if (signature_s == 0) { + return (res=FALSE); + } + if (signature_s == StarkCurve.ORDER) { + return (res=FALSE); + } + if (signature_r == StarkCurve.ORDER) { + return (res=FALSE); + } + + // Check that the public key is the x coordinate of a point on the curve. + let on_curve: felt = is_x_on_curve(public_key); + if (on_curve == FALSE) { + return (res=FALSE); + } + // Check that r is the x coordinate of a point on the curve. + // Note that this ensures that r != 0. + let on_curve: felt = is_x_on_curve(signature_r); + if (on_curve == FALSE) { + return (res=FALSE); + } + + // To verify ECDSA, obtain: + // zG = z * G, where z is the message and G is a generator of the EC. + // rQ = r * Q, where Q.x = public_key. + // sR = s * R, where R.x = r. + // and check that: + // zG +/- rQ = +/- sR, or more efficiently that: + // (zG +/- rQ).x = sR.x. + let (zG: EcPoint) = ec_mul(m=message, p=EcPoint(x=StarkCurve.GEN_X, y=StarkCurve.GEN_Y)); + let (public_key_point: EcPoint) = recover_y(public_key); + let (rQ: EcPoint) = ec_mul(signature_r, public_key_point); + let (signature_r_point: EcPoint) = recover_y(signature_r); + let (sR: EcPoint) = ec_mul(signature_s, signature_r_point); + + let (candidate: EcPoint) = ec_add(zG, rQ); + if (candidate.x == sR.x) { + return (res=TRUE); + } + + let (candidate: EcPoint) = ec_sub(zG, rQ); + if (candidate.x == sR.x) { + return (res=TRUE); + } + + return (res=FALSE); +} diff --git a/cairo/setup.py b/cairo/setup.py index e0317bc..376a1e6 100644 --- a/cairo/setup.py +++ b/cairo/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="sharp_p2p_bootloader", - version="0.1", + version="0.3", description="sharp_p2p bootloader", url="#", author="Okm165", @@ -14,6 +14,6 @@ "common.builtin_poseidon": ["*.cairo", "*/*.cairo"], "lang.compiler": ["cairo.ebnf", "lib/*.cairo"], "bootloader": ["*.cairo", "*/*.cairo"], - "bootloader.recursive_with_poseidon": ["*.cairo", "*/*.cairo"], + "bootloader.starknet": ["*.cairo", "*/*.cairo"], } ) diff --git a/compile.py b/compile.py index e9cb403..7e67489 100644 --- a/compile.py +++ b/compile.py @@ -5,5 +5,5 @@ current_dir = os.getcwd() log_and_run([ - f"cairo-compile --cairo_path=. bootloader/recursive_with_poseidon/simple_bootloader.cairo --output {current_dir}/bootloader.json --proof_mode", + f"cairo-compile --cairo_path=. bootloader/starknet/simple_bootloader.cairo --output {current_dir}/bootloader.json --proof_mode", ], "Compile bootloader program", cwd="cairo") \ No newline at end of file diff --git a/crates/common/src/layout.rs b/crates/common/src/layout.rs index 02153f0..be743f7 100644 --- a/crates/common/src/layout.rs +++ b/crates/common/src/layout.rs @@ -4,4 +4,5 @@ use strum::IntoStaticStr; #[strum(serialize_all = "snake_case")] pub enum Layout { RecursiveWithPoseidon, + Starknet, } diff --git a/crates/runner/src/cairo_runner/mod.rs b/crates/runner/src/cairo_runner/mod.rs index 1cbfc5d..86f8cbd 100644 --- a/crates/runner/src/cairo_runner/mod.rs +++ b/crates/runner/src/cairo_runner/mod.rs @@ -36,7 +36,7 @@ impl<'identity> RunnerController for CairoRunner<'identity> { let future: Pin> + '_>> = Box::pin(async move { let job_hash = hash!(job); - let layout: &str = Layout::RecursiveWithPoseidon.into(); + let layout: &str = Layout::Starknet.into(); let mut cairo_pie = NamedTempFile::new()?; cairo_pie.write_all(&job.job_data.cairo_pie_compressed)?; diff --git a/run.py b/run.py index 166fdb2..46d6635 100644 --- a/run.py +++ b/run.py @@ -4,7 +4,7 @@ log_and_run([ "cairo-run \ --program=bootloader.json \ - --layout=recursive_with_poseidon \ + --layout=starknet \ --program_input=bootloader_input.json \ --air_public_input=bootloader_public_input.json \ --air_private_input=bootloader_private_input.json \