Skip to content

Commit

Permalink
fix: Use memory instead of storage for returning leftovers in Zappers
Browse files Browse the repository at this point in the history
Plus move those functions to a common base contract.
  • Loading branch information
bingen committed Sep 9, 2024
1 parent 7eb4a51 commit 3f190ec
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 70 deletions.
33 changes: 33 additions & 0 deletions contracts/src/Zappers/LeftoversSweep.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT

pragma solidity ^0.8.18;

import "openzeppelin-contracts/contracts/token/ERC20/IERC20.sol";

import "../Interfaces/IBoldToken.sol";

contract LeftoversSweep {
struct InitialBalances {
uint256 boldBalance;
uint256 collBalance;
address sender;
}

function _setInitialBalances(IERC20 _collToken, IBoldToken _boldToken, InitialBalances memory initialBalances) internal view {
initialBalances.boldBalance = _boldToken.balanceOf(address(this));
initialBalances.collBalance = _collToken.balanceOf(address(this));
initialBalances.sender = msg.sender;
}

function _returnLeftovers(IERC20 _collToken, IBoldToken _boldToken, InitialBalances memory initialBalances) internal {
uint256 currentCollBalance = _collToken.balanceOf(address(this));
if (currentCollBalance > initialBalances.collBalance) {
_collToken.transfer(initialBalances.sender, currentCollBalance - initialBalances.collBalance);
}
uint256 currentBoldBalance = _boldToken.balanceOf(address(this));
if (currentBoldBalance > initialBalances.boldBalance) {
_boldToken.transfer(initialBalances.sender, currentBoldBalance - initialBalances.boldBalance);
}
initialBalances.sender = address(0);
}
}
55 changes: 20 additions & 35 deletions contracts/src/Zappers/LeverageLSTZapper.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import "../Interfaces/IBorrowerOperations.sol";
import "../Interfaces/IWETH.sol";
import "./GasCompZapper.sol";
import "../Dependencies/AddRemoveManagers.sol";
import "./LeftoversSweep.sol";
import "../Dependencies/Constants.sol";
import "./Interfaces/IFlashLoanProvider.sol";
import "./Interfaces/IFlashLoanReceiver.sol";
Expand All @@ -16,17 +17,13 @@ import "./Interfaces/ILeverageZapper.sol";

// import "forge-std/console2.sol";

contract LeverageLSTZapper is GasCompZapper, IFlashLoanReceiver, ILeverageZapper {
contract LeverageLSTZapper is GasCompZapper, LeftoversSweep, IFlashLoanReceiver, ILeverageZapper {
using SafeERC20 for IERC20;

IPriceFeed public immutable priceFeed;
IFlashLoanProvider public immutable flashLoanProvider;
IExchange public immutable exchange;

uint256 private initialBoldBalance;
uint256 private initialCollBalance;
address private initialSender;

constructor(IAddressesRegistry _addressesRegistry, IFlashLoanProvider _flashLoanProvider, IExchange _exchange)
GasCompZapper(_addressesRegistry)
{
Expand Down Expand Up @@ -60,9 +57,11 @@ contract LeverageLSTZapper is GasCompZapper, IFlashLoanReceiver, ILeverageZapper
require(msg.value == ETH_GAS_COMPENSATION, "LZ: Wrong ETH");

IERC20 collTokenCached = collToken;
IBoldToken boldTokenCached = boldToken;

// Set initial balances to make sure there are not lefovers
_setInitialBalances(collTokenCached);
InitialBalances memory initialBalances;
_setInitialBalances(collTokenCached, boldTokenCached, initialBalances);

// Convert ETH to WETH
WETH.deposit{value: msg.value}();
Expand All @@ -74,6 +73,9 @@ contract LeverageLSTZapper is GasCompZapper, IFlashLoanReceiver, ILeverageZapper
flashLoanProvider.makeFlashLoan(
collTokenCached, _params.flashLoanAmount, IFlashLoanProvider.Operation.OpenTrove, abi.encode(_params)
);

// return leftovers to user
_returnLeftovers(collTokenCached, boldTokenCached, initialBalances);
}

// Callback from the flash loan provider
Expand Down Expand Up @@ -116,24 +118,26 @@ contract LeverageLSTZapper is GasCompZapper, IFlashLoanReceiver, ILeverageZapper

// Send coll back to return flash loan
vars.collToken.safeTransfer(address(flashLoanProvider), _params.flashLoanAmount);

// return leftovers to user
_returnLeftovers(vars.collToken, boldToken);
}

function leverUpTrove(LeverUpTroveParams calldata _params) external {
address owner = troveNFT.ownerOf(_params.troveId);
_requireSenderIsOwnerOrRemoveManagerAndGetReceiver(_params.troveId, owner);

IERC20 collTokenCached = collToken;
IBoldToken boldTokenCached = boldToken;

// Set initial balances to make sure there are not lefovers
_setInitialBalances(collTokenCached);
InitialBalances memory initialBalances;
_setInitialBalances(collTokenCached, boldTokenCached, initialBalances);

// Flash loan coll
flashLoanProvider.makeFlashLoan(
collTokenCached, _params.flashLoanAmount, IFlashLoanProvider.Operation.LeverUpTrove, abi.encode(_params)
);

// return leftovers to user
_returnLeftovers(collTokenCached, boldTokenCached, initialBalances);
}

// Callback from the flash loan provider
Expand Down Expand Up @@ -163,24 +167,26 @@ contract LeverageLSTZapper is GasCompZapper, IFlashLoanReceiver, ILeverageZapper

// Send coll back to return flash loan
collTokenCached.safeTransfer(address(flashLoanProvider), _params.flashLoanAmount);

// return leftovers to user
_returnLeftovers(collTokenCached, boldToken);
}

function leverDownTrove(LeverDownTroveParams calldata _params) external {
address owner = troveNFT.ownerOf(_params.troveId);
_requireSenderIsOwnerOrRemoveManagerAndGetReceiver(_params.troveId, owner);

IERC20 collTokenCached = collToken;
IBoldToken boldTokenCached = boldToken;

// Set initial balances to make sure there are not lefovers
_setInitialBalances(collTokenCached);
InitialBalances memory initialBalances;
_setInitialBalances(collTokenCached, boldTokenCached, initialBalances);

// Flash loan coll
flashLoanProvider.makeFlashLoan(
collTokenCached, _params.flashLoanAmount, IFlashLoanProvider.Operation.LeverDownTrove, abi.encode(_params)
);

// return leftovers to user
_returnLeftovers(collTokenCached, boldTokenCached, initialBalances);
}

// Callback from the flash loan provider
Expand Down Expand Up @@ -210,27 +216,6 @@ contract LeverageLSTZapper is GasCompZapper, IFlashLoanReceiver, ILeverageZapper

// Send coll back to return flash loan
collTokenCached.safeTransfer(address(flashLoanProvider), _params.flashLoanAmount);

// return leftovers to user
_returnLeftovers(collTokenCached, boldToken);
}

function _setInitialBalances(IERC20 _collToken) internal {
initialBoldBalance = boldToken.balanceOf(address(this));
initialCollBalance = _collToken.balanceOf(address(this));
initialSender = msg.sender;
}

function _returnLeftovers(IERC20 _collToken, IBoldToken _boldToken) internal {
uint256 currentCollBalance = _collToken.balanceOf(address(this));
if (currentCollBalance > initialCollBalance) {
_collToken.transfer(initialSender, currentCollBalance - initialCollBalance);
}
uint256 currentBoldBalance = _boldToken.balanceOf(address(this));
if (currentBoldBalance > initialBoldBalance) {
_boldToken.transfer(initialSender, currentBoldBalance - initialBoldBalance);
}
initialSender = address(0);
}

// As formulas are symmetrical, it can be used in both ways
Expand Down
55 changes: 20 additions & 35 deletions contracts/src/Zappers/LeverageWETHZapper.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import "../Interfaces/IBorrowerOperations.sol";
import "../Interfaces/IWETH.sol";
import "./WETHZapper.sol";
import "../Dependencies/AddRemoveManagers.sol";
import "./LeftoversSweep.sol";
import "../Dependencies/Constants.sol";
import "./Interfaces/IFlashLoanProvider.sol";
import "./Interfaces/IFlashLoanReceiver.sol";
Expand All @@ -14,15 +15,11 @@ import "./Interfaces/ILeverageZapper.sol";

// import "forge-std/console2.sol";

contract LeverageWETHZapper is WETHZapper, IFlashLoanReceiver, ILeverageZapper {
contract LeverageWETHZapper is WETHZapper, LeftoversSweep, IFlashLoanReceiver, ILeverageZapper {
IPriceFeed public immutable priceFeed;
IFlashLoanProvider public immutable flashLoanProvider;
IExchange public immutable exchange;

uint256 private initialBoldBalance;
uint256 private initialWETHBalance;
address private initialSender;

constructor(IAddressesRegistry _addressesRegistry, IFlashLoanProvider _flashLoanProvider, IExchange _exchange)
WETHZapper(_addressesRegistry)
{
Expand Down Expand Up @@ -54,9 +51,11 @@ contract LeverageWETHZapper is WETHZapper, IFlashLoanReceiver, ILeverageZapper {
require(msg.value == ETH_GAS_COMPENSATION + _params.collAmount, "LZ: Wrong amount of ETH");

IWETH WETHCached = WETH;
IBoldToken boldTokenCached = boldToken;

// Set initial balances to make sure there are not lefovers
_setInitialBalances(WETHCached);
InitialBalances memory initialBalances;
_setInitialBalances(WETHCached, boldTokenCached, initialBalances);

// Convert ETH to WETH
WETH.deposit{value: msg.value}();
Expand All @@ -65,6 +64,9 @@ contract LeverageWETHZapper is WETHZapper, IFlashLoanReceiver, ILeverageZapper {
flashLoanProvider.makeFlashLoan(
WETHCached, _params.flashLoanAmount, IFlashLoanProvider.Operation.OpenTrove, abi.encode(_params)
);

// return leftovers to user
_returnLeftovers(WETHCached, boldTokenCached, initialBalances);
}

// Callback from the flash loan provider
Expand Down Expand Up @@ -109,24 +111,26 @@ contract LeverageWETHZapper is WETHZapper, IFlashLoanReceiver, ILeverageZapper {
// Send coll back to return flash loan
vars.WETH.transfer(address(flashLoanProvider), _params.flashLoanAmount);
// WETH reverts on failure: https://etherscan.io/token/0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2#code

// return leftovers to user
_returnLeftovers(vars.WETH, vars.boldToken);
}

function leverUpTrove(LeverUpTroveParams calldata _params) external {
address owner = troveNFT.ownerOf(_params.troveId);
_requireSenderIsOwnerOrRemoveManagerAndGetReceiver(_params.troveId, owner);

IWETH WETHCached = WETH;
IBoldToken boldTokenCached = boldToken;

// Set initial balances to make sure there are not lefovers
_setInitialBalances(WETHCached);
InitialBalances memory initialBalances;
_setInitialBalances(WETHCached, boldTokenCached, initialBalances);

// Flash loan coll
flashLoanProvider.makeFlashLoan(
WETHCached, _params.flashLoanAmount, IFlashLoanProvider.Operation.LeverUpTrove, abi.encode(_params)
);

// return leftovers to user
_returnLeftovers(WETHCached, boldTokenCached, initialBalances);
}

// Callback from the flash loan provider
Expand Down Expand Up @@ -156,24 +160,26 @@ contract LeverageWETHZapper is WETHZapper, IFlashLoanReceiver, ILeverageZapper {

// Send coll back to return flash loan
WETHCached.transfer(address(flashLoanProvider), _params.flashLoanAmount);

// return leftovers to user
_returnLeftovers(WETHCached, boldToken);
}

function leverDownTrove(LeverDownTroveParams calldata _params) external {
address owner = troveNFT.ownerOf(_params.troveId);
_requireSenderIsOwnerOrRemoveManagerAndGetReceiver(_params.troveId, owner);

IWETH WETHCached = WETH;
IBoldToken boldTokenCached = boldToken;

// Set initial balances to make sure there are not lefovers
_setInitialBalances(WETHCached);
InitialBalances memory initialBalances;
_setInitialBalances(WETHCached, boldTokenCached, initialBalances);

// Flash loan coll
flashLoanProvider.makeFlashLoan(
WETHCached, _params.flashLoanAmount, IFlashLoanProvider.Operation.LeverDownTrove, abi.encode(_params)
);

// return leftovers to user
_returnLeftovers(WETHCached, boldTokenCached, initialBalances);
}

// Callback from the flash loan provider
Expand Down Expand Up @@ -203,27 +209,6 @@ contract LeverageWETHZapper is WETHZapper, IFlashLoanReceiver, ILeverageZapper {

// Send coll back to return flash loan
WETHCached.transfer(address(flashLoanProvider), _params.flashLoanAmount);

// return leftovers to user
_returnLeftovers(WETHCached, boldToken);
}

function _setInitialBalances(IWETH _WETH) internal {
initialBoldBalance = boldToken.balanceOf(address(this));
initialWETHBalance = _WETH.balanceOf(address(this));
initialSender = msg.sender;
}

function _returnLeftovers(IWETH _WETH, IBoldToken _boldToken) internal {
uint256 currentWETHBalance = _WETH.balanceOf(address(this));
if (currentWETHBalance > initialWETHBalance) {
_WETH.transfer(initialSender, currentWETHBalance - initialWETHBalance);
}
uint256 currentBoldBalance = _boldToken.balanceOf(address(this));
if (currentBoldBalance > initialBoldBalance) {
_boldToken.transfer(initialSender, currentBoldBalance - initialBoldBalance);
}
initialSender = address(0);
}

// As formulas are symmetrical, it can be used in both ways
Expand Down

0 comments on commit 3f190ec

Please sign in to comment.