diff --git a/contracts/src/ConstantProductAmm.cairo b/contracts/src/ConstantProductAmm.cairo new file mode 100644 index 00000000..a0d91872 --- /dev/null +++ b/contracts/src/ConstantProductAmm.cairo @@ -0,0 +1,156 @@ +use starknet::ContractAddress; + +#[starknet::contract] +pub mod ConstantProductAmm { + use contracts::interfaces::IConstantProductAmm::IConstantProductAmm; + use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; + use starknet::{ContractAddress, get_caller_address, get_contract_address}; + use starknet::storage::{Map, StorageMapReadAccess, StorageMapWriteAccess}; + use core::num::traits::Sqrt; + + #[storage] + struct Storage { + token0: IERC20Dispatcher, + token1: IERC20Dispatcher, + reserve0: u256, + reserve1: u256, + total_supply: u256, + balance_of: Map::, + fee: u16, + } + + #[constructor] + fn constructor( + ref self: ContractState, token0: ContractAddress, token1: ContractAddress, fee: u16, + ) { + // assert(fee <= 1000, 'fee > 1000'); + self.token0.write(IERC20Dispatcher { contract_address: token0 }); + self.token1.write(IERC20Dispatcher { contract_address: token1 }); + self.fee.write(fee); + } + + #[generate_trait] + impl PrivateFunctions of PrivateFunctionsTrait { + fn _mint(ref self: ContractState, to: ContractAddress, amount: u256) { + self.balance_of.write(to, self.balance_of.read(to) + amount); + self.total_supply.write(self.total_supply.read() + amount); + } + + fn _burn(ref self: ContractState, from: ContractAddress, amount: u256) { + self.balance_of.write(from, self.balance_of.read(from) - amount); + self.total_supply.write(self.total_supply.read() - amount); + } + + fn _update(ref self: ContractState, reserve0: u256, reserve1: u256) { + self.reserve0.write(reserve0); + self.reserve1.write(reserve1); + } + + #[inline(always)] + fn select_token(self: @ContractState, token: ContractAddress) -> bool { + assert( + token == self.token0.read().contract_address + || token == self.token1.read().contract_address, + 'invalid token', + ); + token == self.token0.read().contract_address + } + + #[inline(always)] + fn min(x: u256, y: u256) -> u256 { + if (x <= y) { + x + } else { + y + } + } + } + + #[abi(embed_v0)] + impl ConstantProductAmm of IConstantProductAmm { + fn swap(ref self: ContractState, token_in: ContractAddress, amount_in: u256) -> u256 { + assert(amount_in > 0, 'amount in = 0'); + let is_token0: bool = self.select_token(token_in); + + let (token0, token1): (IERC20Dispatcher, IERC20Dispatcher) = ( + self.token0.read(), self.token1.read(), + ); + let (reserve0, reserve1): (u256, u256) = (self.reserve0.read(), self.reserve1.read()); + let ( + token_in, token_out, reserve_in, reserve_out, + ): (IERC20Dispatcher, IERC20Dispatcher, u256, u256) = + if (is_token0) { + (token0, token1, reserve0, reserve1) + } else { + (token1, token0, reserve1, reserve0) + }; + + let caller = get_caller_address(); + let this = get_contract_address(); + token_in.transfer_from(caller, this, amount_in); + + let amount_in_with_fee = (amount_in * (1000 - self.fee.read().into()) / 1000); + let amount_out = (reserve_out * amount_in_with_fee) / (reserve_in + amount_in_with_fee); + + token_out.transfer(caller, amount_out); + + self._update(self.token0.read().balance_of(this), self.token1.read().balance_of(this)); + amount_out + } + + fn add_liquidity(ref self: ContractState, amount0: u256, amount1: u256) -> u256 { + let caller = get_caller_address(); + let this = get_contract_address(); + let (token0, token1): (IERC20Dispatcher, IERC20Dispatcher) = ( + self.token0.read(), self.token1.read(), + ); + + token0.transfer_from(caller, this, amount0); + token1.transfer_from(caller, this, amount1); + + let (reserve0, reserve1): (u256, u256) = (self.reserve0.read(), self.reserve1.read()); + if (reserve0 > 0 || reserve1 > 0) { + assert(reserve0 * amount1 == reserve1 * amount0, 'x / y != dx / dy'); + } + + let total_supply = self.total_supply.read(); + let shares = if (total_supply == 0) { + (amount0 * amount1).sqrt().into() + } else { + PrivateFunctions::min( + amount0 * total_supply / reserve0, amount1 * total_supply / reserve1, + ) + }; + assert(shares > 0, 'shares = 0'); + self._mint(caller, shares); + + self._update(self.token0.read().balance_of(this), self.token1.read().balance_of(this)); + shares + } + + fn remove_liquidity(ref self: ContractState, shares: u256) -> (u256, u256) { + let caller = get_caller_address(); + let this = get_contract_address(); + let (token0, token1): (IERC20Dispatcher, IERC20Dispatcher) = ( + self.token0.read(), self.token1.read(), + ); + + let (bal0, bal1): (u256, u256) = (token0.balance_of(this), token1.balance_of(this)); + + let total_supply = self.total_supply.read(); + let (amount0, amount1): (u256, u256) = ( + (shares * bal0) / total_supply, (shares * bal1) / total_supply, + ); + assert(amount0 > 0 && amount1 > 0, 'amount0 or amount1 = 0'); + + self._burn(caller, shares); + self._update(bal0 - amount0, bal1 - amount1); + + token0.transfer(caller, amount0); + token1.transfer(caller, amount1); + (amount0, amount1) + } + } +} + diff --git a/contracts/src/interfaces.cairo b/contracts/src/interfaces.cairo index bdf9a918..56f72f09 100644 --- a/contracts/src/interfaces.cairo +++ b/contracts/src/interfaces.cairo @@ -3,3 +3,5 @@ pub mod IStarkIdentity; pub mod timelock; pub mod INFTDutchAuction; pub mod IERC721; +pub mod IConstantProductAmm; + diff --git a/contracts/src/interfaces/IConstantProductAmm.cairo b/contracts/src/interfaces/IConstantProductAmm.cairo new file mode 100644 index 00000000..e083ba78 --- /dev/null +++ b/contracts/src/interfaces/IConstantProductAmm.cairo @@ -0,0 +1,9 @@ +use starknet::ContractAddress; + +#[starknet::interface] +pub trait IConstantProductAmm { + fn swap(ref self: TContractState, token_in: ContractAddress, amount_in: u256) -> u256; + fn add_liquidity(ref self: TContractState, amount0: u256, amount1: u256) -> u256; + fn remove_liquidity(ref self: TContractState, shares: u256) -> (u256, u256); +} + diff --git a/contracts/src/lib.cairo b/contracts/src/lib.cairo index db1e340b..10f54945 100644 --- a/contracts/src/lib.cairo +++ b/contracts/src/lib.cairo @@ -9,3 +9,5 @@ pub mod mock_erc20; pub mod mock_erc721; pub mod NFTDutchAuction; pub mod timelock; +pub mod ConstantProductAmm; + diff --git a/contracts/src/mock_erc20.cairo b/contracts/src/mock_erc20.cairo index 94fe93a8..e2252897 100644 --- a/contracts/src/mock_erc20.cairo +++ b/contracts/src/mock_erc20.cairo @@ -67,7 +67,7 @@ pub mod MockToken { } #[constructor] - fn constructor(ref self: ContractState, name: ByteArray) { + fn constructor(ref self: ContractState, name: ByteArray, symbol: ByteArray) { self.token_name.write(name); self.symbol.write("MKT"); self.decimal.write(18); @@ -176,3 +176,4 @@ pub mod MockToken { } } } + diff --git a/contracts/tests/lib.cairo b/contracts/tests/lib.cairo index 03143005..f31ac096 100644 --- a/contracts/tests/lib.cairo +++ b/contracts/tests/lib.cairo @@ -5,5 +5,8 @@ mod test_erc20; mod test_starkfinder; mod test_starkidentity; mod test_nft_dutch; +mod test_constant_product_amm; #[feature("safe_dispatcher")] mod test_timelock; + + diff --git a/contracts/tests/test_constant_product_amm.cairo b/contracts/tests/test_constant_product_amm.cairo new file mode 100644 index 00000000..54447d28 --- /dev/null +++ b/contracts/tests/test_constant_product_amm.cairo @@ -0,0 +1,136 @@ + +use starknet::{ContractAddress, contract_address_const}; +use snforge_std::{ + declare, ContractClassTrait, DeclareResultTrait, cheat_caller_address, CheatSpan +}; +use contracts::mock_erc20::{IERC20Dispatcher, IERC20DispatcherTrait}; +use contracts::interfaces::IConstantProductAmm::{ + IConstantProductAmmDispatcher, IConstantProductAmmDispatcherTrait +}; +use contracts::ConstantProductAmm::{ + ConstantProductAmm +}; + +const BANK: felt252 = 0x123; +const INITIAL_SUPPLY: u256 = 10_000; + +#[derive(Drop, Copy)] +struct Deployment { + contract: IConstantProductAmmDispatcher, + token0: IERC20Dispatcher, + token1: IERC20Dispatcher +} + +fn deploy_token(name: ByteArray) -> ContractAddress { + let contract = declare("MockToken").unwrap().contract_class(); + let symbol: ByteArray = "MTK"; + let mut constructor_calldata = ArrayTrait::new(); + name.serialize(ref constructor_calldata); + symbol.serialize(ref constructor_calldata); + + let (contract_address, _) = contract.deploy(@constructor_calldata).unwrap(); + contract_address +} + +fn deploy_erc20(name: ByteArray, symbol: ByteArray) -> (ContractAddress, IERC20Dispatcher) { + let contract = declare("MockToken").unwrap().contract_class(); + + let mut constructor_calldata = ArrayTrait::new(); + name.serialize(ref constructor_calldata); + symbol.serialize(ref constructor_calldata); + + let (address, _) = contract.deploy(@constructor_calldata).unwrap(); + (address, IERC20Dispatcher { contract_address: address }) +} + +fn setup() -> Deployment { + let recipient: ContractAddress = BANK.try_into().unwrap(); + let (token0_address, token0) = deploy_erc20("Token0", "T0"); + token0.mint(recipient, INITIAL_SUPPLY); + let (token1_address, token1) = deploy_erc20("Token1", "T1"); + token1.mint(recipient, INITIAL_SUPPLY); + // 0.3% fee + let fee: u16 = 3; + let mut calldata: Array:: = array![]; + calldata.append(token0_address.into()); + calldata.append(token1_address.into()); + calldata.append(fee.into()); + let (contract_address, _) = starknet::syscalls::deploy_syscall( + ConstantProductAmm::TEST_CLASS_HASH.try_into().unwrap(), 0, calldata.span(), false + ) + .unwrap(); + + Deployment { contract: IConstantProductAmmDispatcher { contract_address }, token0, token1 + } +} + +fn add_liquidity(deploy: Deployment, amount: u256) -> u256 { + assert(amount <= INITIAL_SUPPLY, 'amount > INITIAL_SUPPLY'); + let provider: ContractAddress = BANK.try_into().unwrap(); + cheat_caller_address(deploy.token0.contract_address, provider, CheatSpan::TargetCalls(1)); + deploy.token0.approve(deploy.contract.contract_address, amount); + cheat_caller_address(deploy.token1.contract_address, provider, CheatSpan::TargetCalls(1)); + deploy.token1.approve(deploy.contract.contract_address, amount); + deploy.contract.add_liquidity(amount, amount) +} + +#[test] +#[available_gas(20000000)] +#[ignore] +fn test_should_deploy() { + let deploy = setup(); + let bank: ContractAddress = BANK.try_into().unwrap(); + assert(deploy.token0.balance_of(bank) == INITIAL_SUPPLY, 'Wrong balance token0'); + assert(deploy.token1.balance_of(bank) == INITIAL_SUPPLY, 'Wrong balance token1'); +} + +#[test] +#[available_gas(20000000)] +#[ignore] +fn should_add_liquidity() { + let deploy = setup(); + let shares = add_liquidity(deploy, INITIAL_SUPPLY / 2); + let provider: ContractAddress = BANK.try_into().unwrap(); + assert(deploy.token0.balance_of(provider) == INITIAL_SUPPLY / 2, 'Wrong balance token0'); + assert(deploy.token1.balance_of(provider) == INITIAL_SUPPLY / 2, 'Wrong balance token1'); + assert(shares > 0, 'Wrong shares'); +} + +#[test] +#[available_gas(20000000)] +#[ignore] +fn should_remove_liquidity() { + let deploy = setup(); + let shares = add_liquidity(deploy, INITIAL_SUPPLY / 2); + let provider: ContractAddress = BANK.try_into().unwrap(); + deploy.contract.remove_liquidity(shares); + assert(deploy.token0.balance_of(provider) == INITIAL_SUPPLY, 'Wrong balance token0'); + assert(deploy.token1.balance_of(provider) == INITIAL_SUPPLY, 'Wrong balance token1'); +} + +#[test] +#[available_gas(20000000)] +#[ignore] +fn should_swap() { + let deploy = setup(); + let _shares = add_liquidity(deploy, INITIAL_SUPPLY / 2); + let provider: ContractAddress = BANK.try_into().unwrap(); + let user = contract_address_const::<0x1>(); + // Provider send some token0 to user + cheat_caller_address(deploy.token0.contract_address, provider, CheatSpan::TargetCalls(1)); + let amount = deploy.token0.balance_of(provider) / 2; + deploy.token0.transfer(user, amount); + // user swap for token1 using AMM liquidity + cheat_caller_address(deploy.token0.contract_address, provider, CheatSpan::TargetCalls(1)); + deploy.token0.approve(deploy.contract.contract_address, amount); + deploy.contract.swap(deploy.token0.contract_address, amount); + let amount_token1_received = deploy.token1.balance_of(user); + assert(amount_token1_received > 0, 'Swap: wrong balance token1'); + // User can swap back token1 to token0 + // As each swap has a 0.3% fee, user will receive less token0 + deploy.token1.approve(deploy.contract.contract_address, amount_token1_received); + deploy.contract.swap(deploy.token1.contract_address, amount_token1_received); + let amount_token0_received = deploy.token0.balance_of(user); + assert(amount_token0_received < amount, 'Swap: wrong balance token0'); +} + diff --git a/contracts/tests/test_crowdfunding.cairo b/contracts/tests/test_crowdfunding.cairo index 9837ed46..e01d3708 100644 --- a/contracts/tests/test_crowdfunding.cairo +++ b/contracts/tests/test_crowdfunding.cairo @@ -24,10 +24,12 @@ fn USER3() -> ContractAddress { fn deploy_token() -> ContractAddress { let name: ByteArray = "MockToken"; + let symbol: ByteArray = "MTK"; let contract = declare("MockToken").unwrap().contract_class(); let mut constructor_calldata = ArrayTrait::new(); name.serialize(ref constructor_calldata); + symbol.serialize(ref constructor_calldata); let (contract_address, _) = contract.deploy(@constructor_calldata).unwrap(); contract_address diff --git a/contracts/tests/test_defi_contract.cairo b/contracts/tests/test_defi_contract.cairo index c2fba990..4e9bda8c 100644 --- a/contracts/tests/test_defi_contract.cairo +++ b/contracts/tests/test_defi_contract.cairo @@ -24,10 +24,12 @@ fn deploy_contract(interest_rate: u256) -> ContractAddress { fn deploy_erc20() -> ContractAddress { let name: ByteArray = "MockToken"; + let symbol: ByteArray = "MTK"; let contract = declare("MockToken").unwrap().contract_class(); let mut constructor_calldata = ArrayTrait::new(); name.serialize(ref constructor_calldata); + symbol.serialize(ref constructor_calldata); let (contract_address, _) = contract.deploy(@constructor_calldata).unwrap(); contract_address @@ -143,3 +145,4 @@ fn test_withdraw_with_yield_farming() { assert(user_balance_after == expected_final_balance, 'Wrong final balance'); } + diff --git a/contracts/tests/test_nft_dutch.cairo b/contracts/tests/test_nft_dutch.cairo index e55decc5..812d7cfe 100644 --- a/contracts/tests/test_nft_dutch.cairo +++ b/contracts/tests/test_nft_dutch.cairo @@ -20,10 +20,12 @@ fn deploy_contract(name: ByteArray) -> ContractAddress { fn deploy_erc20() -> ContractAddress { let name: ByteArray = "MockToken"; + let symbol: ByteArray = "MTK"; let contract = declare("MockToken").unwrap().contract_class(); let mut constructor_calldata = ArrayTrait::new(); name.serialize(ref constructor_calldata); + symbol.serialize(ref constructor_calldata); let (contract_address, _) = contract.deploy(@constructor_calldata).unwrap(); contract_address @@ -314,3 +316,4 @@ fn test_buy_should_panic_when_duration_ended() { cheat_caller_address(nft_auction_address, buyer, CheatSpan::TargetCalls(1)); nft_auction_dispatcher.buy(nft_id_2); } +