Skip to content

Commit

Permalink
♻️ Fix Initializable (#776)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorized authored Jan 4, 2024
1 parent ebdd90f commit 4bed140
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 109 deletions.
21 changes: 8 additions & 13 deletions .gas-snapshot
Original file line number Diff line number Diff line change
Expand Up @@ -498,19 +498,14 @@ FixedPointMathLibTest:testZeroFloorSubCasted(uint32,uint32,uint256) (runs: 256,
FixedPointMathLibTest:test__codesize() (gas: 42329)
GasBurnerLibTest:testBurnGas() (gas: 1700805)
GasBurnerLibTest:test__codesize() (gas: 1435)
InitializableTest:testInit() (gas: 51899)
InitializableTest:testInit(uint256) (runs: 256, μ: 51312, ~: 51971)
InitializableTest:testInitRevertWithInvalidInitialization(uint256) (runs: 256, μ: 54181, ~: 54840)
InitializableTest:testInitializeInititalizerTrick(bool,uint64,uint64) (runs: 256, μ: 816, ~: 817)
InitializableTest:testInitializeIsDisabled() (gas: 171942)
InitializableTest:testInitializeReinititalizerTrick(bool,uint64,uint64) (runs: 256, μ: 662, ~: 675)
InitializableTest:testOnlyInitializing() (gas: 55672)
InitializableTest:testReinitialize() (gas: 1391019)
InitializableTest:testReinitializeRevertWhenContractIsInitializing(uint256,uint64) (runs: 256, μ: 243918, ~: 244733)
InitializableTest:testReinitializeRevertWithInvalidInitialization(uint64,uint64) (runs: 256, μ: 57642, ~: 57697)
InitializableTest:testRevertWhenCalledOnlyInitializingFunctionWithNonInitializer() (gas: 131405)
InitializableTest:testRevertWhenInitializeIsDisabled(uint256) (runs: 256, μ: 174784, ~: 174784)
InitializableTest:test__codesize() (gas: 9133)
InitializableTest:testDisableInitializers() (gas: 41951)
InitializableTest:testInitializableConstructor() (gas: 690140)
InitializableTest:testInitialize() (gas: 54719)
InitializableTest:testInitializeInititalizerTrick(bool,uint64,uint16) (runs: 256, μ: 791, ~: 791)
InitializableTest:testInitializeReinititalize(uint256) (runs: 256, μ: 94373, ~: 92813)
InitializableTest:testInitializeReinititalizerTrick(bool,uint64,uint64) (runs: 256, μ: 674, ~: 687)
InitializableTest:testOnlyInitializing() (gas: 10417)
InitializableTest:test__codesize() (gas: 11968)
JSONParserLibTest:testDecodeEncodedStringDoesNotRevert(string) (runs: 256, μ: 58291, ~: 57099)
JSONParserLibTest:testDecodeInvalidStringReverts() (gas: 174221)
JSONParserLibTest:testDecodeString() (gas: 201120)
Expand Down
10 changes: 5 additions & 5 deletions src/utils/Initializable.sol
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ abstract contract Initializable {
assembly {
let s := _INITIALIZABLE_SLOT
i := sload(s)
// If `!(initializing && initializedVersion == 0)`.
// If `!(initializing == 0 && initializedVersion == 0)`.
if i {
// If `!(codesize == 0 && initializedVersion == 1)`.
if iszero(lt(codesize(), eq(shr(1, i), 1))) {
// If `!(address(this).code.length == 0 && initializedVersion == 1)`.
if iszero(lt(extcodesize(address()), eq(shr(1, i), 1))) {
mstore(0x00, 0xf92ee8a9) // `InvalidInitialization()`.
revert(0x1c, 0x04)
}
Expand All @@ -73,7 +73,7 @@ abstract contract Initializable {
_;
/// @solidity memory-safe-assembly
assembly {
// If `initializing`.
// If `initializing == 0`.
if iszero(and(i, 1)) {
// Set `initializing` to 0, `initializedVersion` to 1.
sstore(_INITIALIZABLE_SLOT, 2)
Expand All @@ -96,7 +96,7 @@ abstract contract Initializable {
version := and(version, 0xffffffffffffffff) // Clean upper bits.
let s := _INITIALIZABLE_SLOT
let i := sload(s)
// If `initializing || initializedVersion >= version`.
// If `initializing == 1 || initializedVersion >= version`.
if iszero(lt(and(i, 1), lt(shr(1, i), version))) {
mstore(0x00, 0xf92ee8a9) // `InvalidInitialization()`.
revert(0x1c, 0x04)
Expand Down
151 changes: 98 additions & 53 deletions test/Initializable.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,96 +2,141 @@
pragma solidity ^0.8.4;

import "./utils/SoladyTest.sol";
import {
MockInitializable,
MockInitializableRevert,
MockInitializableDisabled,
MockInitializableRevert2,
Initializable
} from "./utils/mocks/MockInitializable.sol";
import {MockInitializable, Initializable} from "./utils/mocks/MockInitializable.sol";

contract InitializableTest is SoladyTest {
MockInitializable m1;
event Initialized(uint64 version);

MockInitializable m;

function setUp() public {
m1 = new MockInitializable();
MockInitializable.Args memory a;
m = new MockInitializable(a);
}

function testInit() public {
testInit(123);
function _args() internal returns (MockInitializable.Args memory a) {
a.x = _random();
a.version = uint64(_bound(_random(), 1, type(uint64).max));
a.checkOnlyDuringInitializing = _random() & 1 == 0;
a.recurse = _random() & 1 == 0;
}

function testInit(uint256 x) public {
m1.init(x);
assertEq(m1.x(), x);
function _expectEmitInitialized(uint64 version) internal {
vm.expectEmit(true, true, true, true);
emit Initialized(version);
}

function testOnlyInitializing() public {
testInit(123);
vm.expectRevert(Initializable.NotInitializing.selector);
m1.onlyDuringInitializing();
function testInitialize() public {
MockInitializable.Args memory a;
a.x = 123;
m.initialize(a);
assertEq(m.x(), a.x);
_checkVersion(1);
}

function testInitRevertWithInvalidInitialization(uint256 x) public {
m1.init(x);
vm.expectRevert(Initializable.InvalidInitialization.selector);
m1.init(x);
function _checkVersion(uint64 version) internal {
assertEq(m.version(), version);
assertFalse(m.isInitializing());
}

function testReinitialize() public {
m1.init(5);
assertEq(m1.getVersion(), 1);
for (uint64 i = 2; i < 258; i++) {
m1.reinit(i + 5, i);
assertEq(m1.getVersion(), i);
assertEq(m1.x(), i + 5);
function testInitializeReinititalize(uint256) public {
MockInitializable.Args memory a = _args();

if (a.recurse) {
vm.expectRevert(Initializable.InvalidInitialization.selector);
if (_random() & 1 == 0) {
m.initialize(a);
} else {
m.reinitialize(a);
}
return;
}
}

function testReinitializeRevertWithInvalidInitialization(uint64 x_, uint64 version) public {
m1.init(x_);
m1.reinit(x_, type(uint64).max);
if (_random() & 1 == 0) {
_expectEmitInitialized(1);
m.initialize(a);
a.version = 1;
} else {
_expectEmitInitialized(a.version);
m.reinitialize(a);
}
assertEq(m.x(), a.x);
_checkVersion(a.version);

vm.expectRevert(Initializable.InvalidInitialization.selector);
m1.reinit(x_, version);
if (_random() & 1 == 0) {
vm.expectRevert(Initializable.InvalidInitialization.selector);
m.initialize(a);
}
if (_random() & 1 == 0) {
vm.expectRevert(Initializable.InvalidInitialization.selector);
m.reinitialize(a);
}
if (_random() & 1 == 0) {
a.version = m.version();
uint64 newVersion = uint64(_random());
if (newVersion > a.version) {
a.version = newVersion;
m.reinitialize(a);
_checkVersion(a.version);
}
}
}

function testReinitializeRevertWhenContractIsInitializing(uint256 x_, uint64 version) public {
MockInitializableRevert m2 = new MockInitializableRevert();
vm.expectRevert(Initializable.InvalidInitialization.selector);
m2.init1(x_, version);
function testOnlyInitializing() public {
vm.expectRevert(Initializable.NotInitializing.selector);
m.onlyDuringInitializing();
}

function testRevertWhenInitializeIsDisabled(uint256 x_) public {
MockInitializableDisabled m = new MockInitializableDisabled();
function testDisableInitializers() public {
_expectEmitInitialized(type(uint64).max);
m.disableInitializers();
_checkVersion(type(uint64).max);
m.disableInitializers();
_checkVersion(type(uint64).max);

MockInitializable.Args memory a;
vm.expectRevert(Initializable.InvalidInitialization.selector);
m.init(x_);
m.initialize(a);
vm.expectRevert(Initializable.InvalidInitialization.selector);
m.reinitialize(a);
}

function testInitializeIsDisabled() public {
MockInitializableDisabled m = new MockInitializableDisabled();
assertEq(m.getVersion(), type(uint64).max);
}
function testInitializableConstructor() public {
MockInitializable.Args memory a;
a.initializeMulti = true;
m = new MockInitializable(a);
_checkVersion(1);

function testRevertWhenCalledOnlyInitializingFunctionWithNonInitializer() public {
MockInitializableRevert2 m = new MockInitializableRevert2();
vm.expectRevert(Initializable.NotInitializing.selector);
m.init(5);
vm.expectRevert(Initializable.InvalidInitialization.selector);
m.initialize(a);
a.version = 2;
m.reinitialize(a);
_checkVersion(2);

a.disableInitializers = true;
_expectEmitInitialized(type(uint64).max);
m = new MockInitializable(a);
_checkVersion(type(uint64).max);
vm.expectRevert(Initializable.InvalidInitialization.selector);
m.initialize(a);
vm.expectRevert(Initializable.InvalidInitialization.selector);
m.reinitialize(a);
}

function testInitializeInititalizerTrick(
bool initializing,
uint64 initializedVersion,
uint64 codeSize
uint16 codeSize
) public {
bool isTopLevelCall = !initializing;
bool initialSetup = initializedVersion == 0 && isTopLevelCall;
bool construction = initializedVersion == 1 && codeSize == 0;
bool expected = !initialSetup && !construction;
bool computed;
uint256 i;
/// @solidity memory-safe-assembly
assembly {
let i := or(initializing, shl(1, initializedVersion))
i := or(initializing, shl(1, initializedVersion))
if i { if iszero(lt(codeSize, eq(shr(1, i), 1))) { computed := 1 } }
}
assertEq(computed, expected);
Expand All @@ -102,7 +147,7 @@ contract InitializableTest is SoladyTest {
uint64 initializedVersion,
uint64 version
) public {
bool expected = initializing || initializedVersion >= version;
bool expected = initializing == true || initializedVersion >= version;
bool computed;
/// @solidity memory-safe-assembly
assembly {
Expand Down
92 changes: 54 additions & 38 deletions test/utils/mocks/MockInitializable.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,60 +5,76 @@ import {Initializable} from "../../../src/utils/Initializable.sol";

/// @dev WARNING! This mock is strictly intended for testing purposes only.
/// Do NOT copy anything here into production code unless you really know what you are doing.
contract MockInitializableParent is Initializable {
contract MockInitializable is Initializable {
uint256 public x;
uint256 public y;

event Yo();

function _initialize(uint256 x_) internal onlyInitializing {
x = x_;
if (x_ & 8 == 0) onlyDuringInitializing();
struct Args {
uint256 x;
uint64 version;
bool disableInitializers;
bool initializeMulti;
bool checkOnlyDuringInitializing;
bool recurse;
}

function getVersion() external view returns (uint64) {
return _getInitializedVersion();
constructor(Args memory a) {
if (a.initializeMulti) {
require(_getInitializedVersion() == 0, "The version should be zero.");
require(!_isInitializing(), "Initializing should be false.");
initialize(a);
require(_getInitializedVersion() == 1, "The version should be one.");
require(!_isInitializing(), "Initializing should be false.");
initialize(a);
require(_getInitializedVersion() == 1, "The version should be one.");
require(!_isInitializing(), "Initializing should be false.");
}
if (a.disableInitializers) {
_disableInitializers();
}
}

function isInitializing() external view returns (bool) {
return _isInitializing();
function initialize(Args memory a) public initializer {
x = a.x;
if (a.checkOnlyDuringInitializing) {
onlyDuringInitializing();
}
if (a.recurse) {
a.recurse = false;
if (a.x & 1 == 0) initialize(a);
else reinitialize(a);
}
}

function onlyDuringInitializing() public onlyInitializing {
emit Yo();
function reinitialize(Args memory a) public reinitializer(a.version) {
x = a.x;
if (a.checkOnlyDuringInitializing) {
onlyDuringInitializing();
}
if (a.recurse) {
a.recurse = false;
if (a.x & 1 == 0) initialize(a);
else reinitialize(a);
}
}
}

contract MockInitializable is MockInitializableParent {
function init(uint256 x_) public initializer {
_initialize(x_);
function version() external view returns (uint64) {
return _getInitializedVersion();
}

function reinit(uint256 x_, uint64 version) public reinitializer(version) {
_initialize(x_);
function isInitializing() external view returns (bool) {
return _isInitializing();
}
}

contract MockInitializableRevert is MockInitializableParent {
function init1(uint256 x_, uint64 version) public initializer {
_initialize(x_);
reinit(version);
function onlyDuringInitializing() public onlyInitializing {
require(_getInitializedVersion() != 0, "The version should not be zero.");
require(_isInitializing(), "Initializing should be true.");
unchecked {
++y;
}
}

function reinit(uint64 version) public reinitializer(version) {}
}

contract MockInitializableDisabled is MockInitializableParent {
constructor() {
function disableInitializers() public {
_disableInitializers();
}

function init(uint256 x_) public initializer {
_initialize(x_);
}
}

contract MockInitializableRevert2 is MockInitializableParent {
function init(uint256 x_) public {
_initialize(x_);
}
}

0 comments on commit 4bed140

Please sign in to comment.