-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathm52.py
executable file
·160 lines (129 loc) · 5.28 KB
/
m52.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#!/usr/bin/env python3
"""Iterated Hash Function Multicollisions"""
# https://www.iacr.org/archive/crypto2004/31520306/multicollisions.pdf
from collections.abc import Iterator, Callable
from typing import NamedTuple
from functools import cache
from copy import copy
from abc import ABC
from Crypto.Cipher import AES
from m02 import fixed_xor
from m28 import HashBase, merkle_pad
BLOCKSIZE = 16  # bytes per Merkle–Damgård block (one AES block)


class Chain(NamedTuple):
    """One hop of chaining state: running the hash from ``input`` reaches ``out``."""
    input: bytes
    out: bytes


class HashCollision(NamedTuple):
    """Messages that all take the chaining state ``hash.input`` to ``hash.out``."""
    messages: tuple[bytes, ...]
    hash: Chain
class MDHash(HashBase, ABC):
    """Abstract Merkle–Damgård hash built from ``md`` and ``aes_compressor``.

    The whole state is the chaining value ``_h``, which starts at the
    class-level ``register`` (``digest_size`` zero bytes). Subclasses that
    change ``digest_size`` also redefine ``register`` to match it.
    """
    block_size = BLOCKSIZE
    digest_size = 20
    # Initial chaining value; evaluated once, against this class's digest_size.
    register = bytes(digest_size)

    def __init__(self, data: bytes = b"") -> None:
        # Begin at the initial state, then absorb any data given up front.
        self._h = self.register
        self.update(data)

    def copy(self) -> "MDHash":
        """Return an independent copy of this hash object."""
        # A shallow copy is enough: the only state, _h, is immutable bytes.
        return copy(self)

    def update(self, data: bytes) -> None:
        """Absorb ``data`` into the chaining state.

        NOTE: every call runs ``md`` over ``data`` on its own, padding it
        independently, so several ``update()`` calls are not guaranteed to
        equal one call on the concatenated input.
        """
        current = self._h
        self._h = md(data, current, aes_compressor)

    def digest(self) -> bytes:
        """Return the current chaining value as the digest."""
        return self._h

    def hexdigest(self) -> str:
        """Return the digest as a lowercase hex string."""
        state = self._h
        return state.hex()
class CheapHash(MDHash):
    """Cheap MD hash with a 2-byte (16-bit) chaining state."""
    name = "cheaphash"
    digest_size = 2
    # All-zero initial state, sized to this class's digest.
    register = b"\x00" * digest_size
class ExpensiveHash(MDHash):
    """Less cheap MD hash with a 4-byte (32-bit) chaining state."""
    name = "expensivehash"
    digest_size = 4
    # All-zero initial state, sized to this class's digest.
    register = b"\x00" * digest_size
def cascade_hash(message: bytes) -> bytes:
    """Cascade hash: the CheapHash digest followed by the ExpensiveHash digest.

    Meant to look collision-resistant. In practice a collision costs little
    more than one on the stronger half; find_cascading_hash_collision
    demonstrates this.
    """
    cheap = CheapHash(message).digest()
    expensive = ExpensiveHash(message).digest()
    return b"".join((cheap, expensive))
def blocks(m: bytes, blocksize: int = BLOCKSIZE) -> list[bytes]:
    """Cut ``m`` into consecutive ``blocksize``-byte chunks.

    ``len(m)`` must already be a multiple of ``blocksize``; this is asserted,
    not padded.
    """
    assert len(m) % blocksize == 0
    count = len(m) // blocksize
    return [m[k * blocksize:(k + 1) * blocksize] for k in range(count)]
def pad(m: bytes) -> bytes:
    """Merkle-pad ``m`` to the block size, leaving aligned input untouched.

    Input whose length is already a multiple of BLOCKSIZE (including
    ``b""``) is returned as-is, with no padding appended. Everything else is
    passed to ``merkle_pad`` (big-endian, 4-byte field). Assuming merkle_pad
    returns block-aligned output, ``pad(pad(m)) == pad(m)``.
    """
    # Aligned input is skipped on purpose. md() is shared by the MDHash
    # instances, which don't need idempotent padding, and by the attacker,
    # which pre-pads its blocks and must not have md() pad them again.
    aligned = len(m) % BLOCKSIZE == 0
    return m if aligned else merkle_pad(m, BLOCKSIZE, "big", 4)
def aes_compressor(m: bytes, h: bytes) -> bytes:
    """Compression step: AES-ECB encrypt block ``m`` under the key ``pad(h)``.

    ``m`` must be exactly BLOCKSIZE bytes (asserted). The key is the chaining
    value padded with :func:`pad`, so the AES key size follows the padded
    length of ``h``. That is presumably 16 bytes (AES-128) for the 2- and
    4-byte states used here; TODO confirm against merkle_pad.
    """
    assert len(m) == BLOCKSIZE
    key = pad(h)
    cipher = AES.new(key, AES.MODE_ECB)
    return cipher.encrypt(m)
@cache
def md(m: bytes, h: bytes,
       c: Callable[[bytes, bytes], bytes] = aes_compressor) -> bytes:
    """Generic Merkle–Damgård iteration, memoized on ``(m, h, c)``.

    Pads ``m`` with :func:`pad`, then for every block computes
    ``c(block, state)``. The result is truncated to the width of the incoming
    ``h`` and XORed into the state. ``functools.cache`` is unbounded, so every
    distinct call is kept alive and all arguments must be hashable.
    """
    width = len(h)
    state = h
    for chunk in blocks(pad(m)):
        compressed = c(chunk, state)
        state = fixed_xor(compressed[:width], state)
    return state
def all_possible_block_pairs(byte_length: int) -> Iterator[tuple[bytes, bytes]]:
    """Yield every unordered pair of distinct ``byte_length``-byte blocks.

    Each pair is ``(a, b)`` with ``int(a) < int(b)`` (big-endian). Pairs are
    produced in lexicographic order of those integers, so each unordered pair
    appears exactly once.
    """
    count = 2 ** (8 * byte_length)

    def as_block(value: int) -> bytes:
        return value.to_bytes(byte_length, "big")

    yield from (
        (as_block(low), as_block(high))
        for low in range(count)
        for high in range(low + 1, count)
    )
def verify_collision(collision: HashCollision) -> bool:
    """Check that every message takes ``hash.input`` to ``hash.out`` under md.

    Returns True only when all messages reach one and the same state and
    that state equals ``collision.hash.out``. An empty message set is
    rejected.
    """
    start = collision.hash.input
    reached = {md(message, start) for message in collision.messages}
    if len(reached) != 1:
        return False
    (only_state,) = reached
    return only_state == collision.hash.out
def generate_colliding_pairs(n: int, h: bytes) -> Iterator[HashCollision]:
    """Yield ``n`` chained single-block collisions, starting from state ``h``.

    Each round tries every ``len(h)``-byte block (``md`` pads it to a full
    block) from the current state. It remembers each resulting state in a
    dict and stops at the first repeat: a birthday search needing roughly
    2**(4 * len(h)) compressions rather than a scan over all pairs. The
    colliding output becomes the starting state of the next round. Within
    each yielded pair, the earlier-tried block comes first.

    Raises:
        RuntimeError: if no two candidate blocks collide from some state.
    """
    for _ in range(n):
        byte_length = len(h)
        seen: dict[bytes, bytes] = {}
        for value in range(2 ** (8 * byte_length)):
            block = value.to_bytes(byte_length, "big")
            h_next = md(block, h)
            earlier = seen.get(h_next)
            if earlier is not None:
                break
            seen[h_next] = block
        else:
            # Previously this fell through and reused a stale (or unbound)
            # collision, silently corrupting the chain.
            raise RuntimeError(
                f"No single-block collision found from state {h.hex()}")
        collision = HashCollision((earlier, block), Chain(h, h_next))
        h = h_next
        yield collision
def generate_multicollision(n: int, hasher: type[MDHash]) -> HashCollision:
    """Return 2ⁿ messages that all collide under ``hasher``.

    Chains ``n`` single-block collisions starting from ``hasher.register``.
    Picking either block at every stage gives a distinct message, and all of
    them reach the same final state.

    Args:
        n: number of chained collisions. ``n == 0`` gives just ``b""``, which
            leaves the state at the register.
        hasher: MDHash subclass whose ``register`` starts the chain.

    Returns:
        HashCollision of the 2ⁿ messages and ``Chain(register, final state)``.
    """
    h = hasher.register
    h_out = h  # n == 0: the empty message does not move the state
    messages = [b""]
    for collision in generate_colliding_pairs(n, h):
        assert verify_collision(collision)
        h_out = collision.hash.out
        # Pad each block so concatenations stay block-aligned and md()
        # processes them block by block without padding them again.
        suffixes = [pad(block) for block in collision.messages]
        # Same order as before, without the quadratic list.remove() loop.
        messages = [prefix + suffix
                    for prefix in messages
                    for suffix in suffixes]
    return HashCollision(tuple(messages), Chain(h, h_out))
def find_cascading_hash_collision(limit: int = 20) -> HashCollision:
    """Find two distinct messages that collide under :func:`cascade_hash`.

    For ``n = 1 … limit - 1``, build a 2ⁿ-way CheapHash multicollision and
    look for a birthday collision in ExpensiveHash among those messages. Any
    such pair collides under both halves of the cascade.

    Returns:
        HashCollision of the two messages. ``hash.input`` is the cascade's
        initial state (both registers concatenated) and ``hash.out`` is the
        shared cascade digest.

    Raises:
        RuntimeError: if no collision is found for any ``n < limit``.
    """
    # Chain.input is declared as bytes; use the cascade's real start state
    # instead of None.
    initial_state = CheapHash.register + ExpensiveHash.register
    for n in range(1, limit):
        print(f"Adding {2 ** n} hashes to the pool")
        xh_map: dict[bytes, bytes] = {}
        multicollision = generate_multicollision(n, CheapHash)
        for m in multicollision.messages:
            xh = ExpensiveHash(m).digest()
            if xh in xh_map:
                return HashCollision(
                    (m, xh_map[xh]),
                    Chain(initial_state, multicollision.hash.out + xh))
            xh_map[xh] = m
    raise RuntimeError("Failed to find a collision")
def main() -> None:
    """Script entry point: find, sanity-check and print a cascade collision."""
    found = find_cascading_hash_collision(limit=20)
    digests = {cascade_hash(message) for message in found.messages}
    # Both messages must land on a single cascade digest...
    assert len(digests) == 1
    # ...they must genuinely be two different messages...
    assert len(set(found.messages)) == 2
    # ...and that digest must be the one the search reported.
    assert digests.pop() == found.hash.out
    first, second = found.messages[0], found.messages[1]
    print("m₀ =", first.hex())
    print("m₁ =", second.hex())
    print("h =", found.hash.out.hex())


if __name__ == "__main__":
    main()