Skip to content

Commit

Permalink
feat: implement constant product amm contract
Browse files Browse the repository at this point in the history
  • Loading branch information
1nonlypiece committed Feb 27, 2025
1 parent dab3f53 commit ca99ec0
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 1 deletion.
156 changes: 156 additions & 0 deletions contracts/src/ConstantProductAmm.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
use starknet::ContractAddress;

#[starknet::contract]
pub mod ConstantProductAmm {
use contracts::interfaces::IConstantProductAmm::IConstantProductAmm;
use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess};
use openzeppelin::token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait};
use starknet::{ContractAddress, get_caller_address, get_contract_address};
use starknet::storage::{Map, StorageMapReadAccess, StorageMapWriteAccess};
use core::num::traits::Sqrt;

#[storage]
struct Storage {
token0: IERC20Dispatcher,
token1: IERC20Dispatcher,
reserve0: u256,
reserve1: u256,
total_supply: u256,
balance_of: Map::<ContractAddress, u256>,
fee: u16,
}

#[constructor]
fn constructor(
ref self: ContractState, token0: ContractAddress, token1: ContractAddress, fee: u16,
) {
// assert(fee <= 1000, 'fee > 1000');
self.token0.write(IERC20Dispatcher { contract_address: token0 });
self.token1.write(IERC20Dispatcher { contract_address: token1 });
self.fee.write(fee);
}

#[generate_trait]
impl PrivateFunctions of PrivateFunctionsTrait {
fn _mint(ref self: ContractState, to: ContractAddress, amount: u256) {
self.balance_of.write(to, self.balance_of.read(to) + amount);
self.total_supply.write(self.total_supply.read() + amount);
}

fn _burn(ref self: ContractState, from: ContractAddress, amount: u256) {
self.balance_of.write(from, self.balance_of.read(from) - amount);
self.total_supply.write(self.total_supply.read() - amount);
}

fn _update(ref self: ContractState, reserve0: u256, reserve1: u256) {
self.reserve0.write(reserve0);
self.reserve1.write(reserve1);
}

#[inline(always)]
fn select_token(self: @ContractState, token: ContractAddress) -> bool {
assert(
token == self.token0.read().contract_address
|| token == self.token1.read().contract_address,
'invalid token',
);
token == self.token0.read().contract_address
}

#[inline(always)]
fn min(x: u256, y: u256) -> u256 {
if (x <= y) {
x
} else {
y
}
}
}

#[abi(embed_v0)]
impl ConstantProductAmm of IConstantProductAmm<ContractState> {
fn swap(ref self: ContractState, token_in: ContractAddress, amount_in: u256) -> u256 {
assert(amount_in > 0, 'amount in = 0');
let is_token0: bool = self.select_token(token_in);

let (token0, token1): (IERC20Dispatcher, IERC20Dispatcher) = (
self.token0.read(), self.token1.read(),
);
let (reserve0, reserve1): (u256, u256) = (self.reserve0.read(), self.reserve1.read());
let (
token_in, token_out, reserve_in, reserve_out,
): (IERC20Dispatcher, IERC20Dispatcher, u256, u256) =
if (is_token0) {
(token0, token1, reserve0, reserve1)
} else {
(token1, token0, reserve1, reserve0)
};

let caller = get_caller_address();
let this = get_contract_address();
token_in.transfer_from(caller, this, amount_in);

let amount_in_with_fee = (amount_in * (1000 - self.fee.read().into()) / 1000);
let amount_out = (reserve_out * amount_in_with_fee) / (reserve_in + amount_in_with_fee);

token_out.transfer(caller, amount_out);

self._update(self.token0.read().balance_of(this), self.token1.read().balance_of(this));
amount_out
}

fn add_liquidity(ref self: ContractState, amount0: u256, amount1: u256) -> u256 {
let caller = get_caller_address();
let this = get_contract_address();
let (token0, token1): (IERC20Dispatcher, IERC20Dispatcher) = (
self.token0.read(), self.token1.read(),
);

token0.transfer_from(caller, this, amount0);
token1.transfer_from(caller, this, amount1);

let (reserve0, reserve1): (u256, u256) = (self.reserve0.read(), self.reserve1.read());
if (reserve0 > 0 || reserve1 > 0) {
assert(reserve0 * amount1 == reserve1 * amount0, 'x / y != dx / dy');
}

let total_supply = self.total_supply.read();
let shares = if (total_supply == 0) {
(amount0 * amount1).sqrt().into()
} else {
PrivateFunctions::min(
amount0 * total_supply / reserve0, amount1 * total_supply / reserve1,
)
};
assert(shares > 0, 'shares = 0');
self._mint(caller, shares);

self._update(self.token0.read().balance_of(this), self.token1.read().balance_of(this));
shares
}

fn remove_liquidity(ref self: ContractState, shares: u256) -> (u256, u256) {
let caller = get_caller_address();
let this = get_contract_address();
let (token0, token1): (IERC20Dispatcher, IERC20Dispatcher) = (
self.token0.read(), self.token1.read(),
);

let (bal0, bal1): (u256, u256) = (token0.balance_of(this), token1.balance_of(this));

let total_supply = self.total_supply.read();
let (amount0, amount1): (u256, u256) = (
(shares * bal0) / total_supply, (shares * bal1) / total_supply,
);
assert(amount0 > 0 && amount1 > 0, 'amount0 or amount1 = 0');

self._burn(caller, shares);
self._update(bal0 - amount0, bal1 - amount1);

token0.transfer(caller, amount0);
token1.transfer(caller, amount1);
(amount0, amount1)
}
}
}

2 changes: 2 additions & 0 deletions contracts/src/interfaces.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ pub mod IStarkIdentity;
pub mod timelock;
pub mod INFTDutchAuction;
pub mod IERC721;
pub mod IConstantProductAmm;

9 changes: 9 additions & 0 deletions contracts/src/interfaces/IConstantProductAmm.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use starknet::ContractAddress;

#[starknet::interface]
pub trait IConstantProductAmm<TContractState> {
fn swap(ref self: TContractState, token_in: ContractAddress, amount_in: u256) -> u256;
fn add_liquidity(ref self: TContractState, amount0: u256, amount1: u256) -> u256;
fn remove_liquidity(ref self: TContractState, shares: u256) -> (u256, u256);
}

2 changes: 2 additions & 0 deletions contracts/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ pub mod mock_erc20;
pub mod mock_erc721;
pub mod NFTDutchAuction;
pub mod timelock;
pub mod ConstantProductAmm;

3 changes: 2 additions & 1 deletion contracts/src/mock_erc20.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub mod MockToken {
}

#[constructor]
fn constructor(ref self: ContractState, name: ByteArray) {
fn constructor(ref self: ContractState, name: ByteArray, symbol: ByteArray) {
self.token_name.write(name);
self.symbol.write("MKT");
self.decimal.write(18);
Expand Down Expand Up @@ -176,3 +176,4 @@ pub mod MockToken {
}
}
}

3 changes: 3 additions & 0 deletions contracts/tests/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@ mod test_erc20;
mod test_starkfinder;
mod test_starkidentity;
mod test_nft_dutch;
mod test_constant_product_amm;
#[feature("safe_dispatcher")]
mod test_timelock;


136 changes: 136 additions & 0 deletions contracts/tests/test_constant_product_amm.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@

use starknet::{ContractAddress, contract_address_const};
use snforge_std::{
declare, ContractClassTrait, DeclareResultTrait, cheat_caller_address, CheatSpan
};
use contracts::mock_erc20::{IERC20Dispatcher, IERC20DispatcherTrait};
use contracts::interfaces::IConstantProductAmm::{
IConstantProductAmmDispatcher, IConstantProductAmmDispatcherTrait
};
use contracts::ConstantProductAmm::{
ConstantProductAmm
};

const BANK: felt252 = 0x123;
const INITIAL_SUPPLY: u256 = 10_000;

#[derive(Drop, Copy)]
struct Deployment {
contract: IConstantProductAmmDispatcher,
token0: IERC20Dispatcher,
token1: IERC20Dispatcher
}

fn deploy_token(name: ByteArray) -> ContractAddress {
let contract = declare("MockToken").unwrap().contract_class();
let symbol: ByteArray = "MTK";
let mut constructor_calldata = ArrayTrait::new();
name.serialize(ref constructor_calldata);
symbol.serialize(ref constructor_calldata);

let (contract_address, _) = contract.deploy(@constructor_calldata).unwrap();
contract_address
}

fn deploy_erc20(name: ByteArray, symbol: ByteArray) -> (ContractAddress, IERC20Dispatcher) {
let contract = declare("MockToken").unwrap().contract_class();

let mut constructor_calldata = ArrayTrait::new();
name.serialize(ref constructor_calldata);
symbol.serialize(ref constructor_calldata);

let (address, _) = contract.deploy(@constructor_calldata).unwrap();
(address, IERC20Dispatcher { contract_address: address })
}

fn setup() -> Deployment {
let recipient: ContractAddress = BANK.try_into().unwrap();
let (token0_address, token0) = deploy_erc20("Token0", "T0");
token0.mint(recipient, INITIAL_SUPPLY);
let (token1_address, token1) = deploy_erc20("Token1", "T1");
token1.mint(recipient, INITIAL_SUPPLY);
// 0.3% fee
let fee: u16 = 3;
let mut calldata: Array::<felt252> = array![];
calldata.append(token0_address.into());
calldata.append(token1_address.into());
calldata.append(fee.into());
let (contract_address, _) = starknet::syscalls::deploy_syscall(
ConstantProductAmm::TEST_CLASS_HASH.try_into().unwrap(), 0, calldata.span(), false
)
.unwrap();

Deployment { contract: IConstantProductAmmDispatcher { contract_address }, token0, token1
}
}

fn add_liquidity(deploy: Deployment, amount: u256) -> u256 {
assert(amount <= INITIAL_SUPPLY, 'amount > INITIAL_SUPPLY');
let provider: ContractAddress = BANK.try_into().unwrap();
cheat_caller_address(deploy.token0.contract_address, provider, CheatSpan::TargetCalls(1));
deploy.token0.approve(deploy.contract.contract_address, amount);
cheat_caller_address(deploy.token1.contract_address, provider, CheatSpan::TargetCalls(1));
deploy.token1.approve(deploy.contract.contract_address, amount);
deploy.contract.add_liquidity(amount, amount)
}

#[test]
#[available_gas(20000000)]
#[ignore]
fn test_should_deploy() {
let deploy = setup();
let bank: ContractAddress = BANK.try_into().unwrap();
assert(deploy.token0.balance_of(bank) == INITIAL_SUPPLY, 'Wrong balance token0');
assert(deploy.token1.balance_of(bank) == INITIAL_SUPPLY, 'Wrong balance token1');
}

#[test]
#[available_gas(20000000)]
#[ignore]
fn should_add_liquidity() {
let deploy = setup();
let shares = add_liquidity(deploy, INITIAL_SUPPLY / 2);
let provider: ContractAddress = BANK.try_into().unwrap();
assert(deploy.token0.balance_of(provider) == INITIAL_SUPPLY / 2, 'Wrong balance token0');
assert(deploy.token1.balance_of(provider) == INITIAL_SUPPLY / 2, 'Wrong balance token1');
assert(shares > 0, 'Wrong shares');
}

#[test]
#[available_gas(20000000)]
#[ignore]
fn should_remove_liquidity() {
let deploy = setup();
let shares = add_liquidity(deploy, INITIAL_SUPPLY / 2);
let provider: ContractAddress = BANK.try_into().unwrap();
deploy.contract.remove_liquidity(shares);
assert(deploy.token0.balance_of(provider) == INITIAL_SUPPLY, 'Wrong balance token0');
assert(deploy.token1.balance_of(provider) == INITIAL_SUPPLY, 'Wrong balance token1');
}

#[test]
#[available_gas(20000000)]
#[ignore]
fn should_swap() {
let deploy = setup();
let _shares = add_liquidity(deploy, INITIAL_SUPPLY / 2);
let provider: ContractAddress = BANK.try_into().unwrap();
let user = contract_address_const::<0x1>();
// Provider send some token0 to user
cheat_caller_address(deploy.token0.contract_address, provider, CheatSpan::TargetCalls(1));
let amount = deploy.token0.balance_of(provider) / 2;
deploy.token0.transfer(user, amount);
// user swap for token1 using AMM liquidity
cheat_caller_address(deploy.token0.contract_address, provider, CheatSpan::TargetCalls(1));
deploy.token0.approve(deploy.contract.contract_address, amount);
deploy.contract.swap(deploy.token0.contract_address, amount);
let amount_token1_received = deploy.token1.balance_of(user);
assert(amount_token1_received > 0, 'Swap: wrong balance token1');
// User can swap back token1 to token0
// As each swap has a 0.3% fee, user will receive less token0
deploy.token1.approve(deploy.contract.contract_address, amount_token1_received);
deploy.contract.swap(deploy.token1.contract_address, amount_token1_received);
let amount_token0_received = deploy.token0.balance_of(user);
assert(amount_token0_received < amount, 'Swap: wrong balance token0');
}

2 changes: 2 additions & 0 deletions contracts/tests/test_crowdfunding.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ fn USER3() -> ContractAddress {

fn deploy_token() -> ContractAddress {
let name: ByteArray = "MockToken";
let symbol: ByteArray = "MTK";
let contract = declare("MockToken").unwrap().contract_class();

let mut constructor_calldata = ArrayTrait::new();
name.serialize(ref constructor_calldata);
symbol.serialize(ref constructor_calldata);

let (contract_address, _) = contract.deploy(@constructor_calldata).unwrap();
contract_address
Expand Down
3 changes: 3 additions & 0 deletions contracts/tests/test_defi_contract.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ fn deploy_contract(interest_rate: u256) -> ContractAddress {

fn deploy_erc20() -> ContractAddress {
let name: ByteArray = "MockToken";
let symbol: ByteArray = "MTK";
let contract = declare("MockToken").unwrap().contract_class();

let mut constructor_calldata = ArrayTrait::new();
name.serialize(ref constructor_calldata);
symbol.serialize(ref constructor_calldata);

let (contract_address, _) = contract.deploy(@constructor_calldata).unwrap();
contract_address
Expand Down Expand Up @@ -143,3 +145,4 @@ fn test_withdraw_with_yield_farming() {
assert(user_balance_after == expected_final_balance, 'Wrong final balance');
}


3 changes: 3 additions & 0 deletions contracts/tests/test_nft_dutch.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ fn deploy_contract(name: ByteArray) -> ContractAddress {

fn deploy_erc20() -> ContractAddress {
let name: ByteArray = "MockToken";
let symbol: ByteArray = "MTK";
let contract = declare("MockToken").unwrap().contract_class();

let mut constructor_calldata = ArrayTrait::new();
name.serialize(ref constructor_calldata);
symbol.serialize(ref constructor_calldata);

let (contract_address, _) = contract.deploy(@constructor_calldata).unwrap();
contract_address
Expand Down Expand Up @@ -314,3 +316,4 @@ fn test_buy_should_panic_when_duration_ended() {
cheat_caller_address(nft_auction_address, buyer, CheatSpan::TargetCalls(1));
nft_auction_dispatcher.buy(nft_id_2);
}

0 comments on commit ca99ec0

Please sign in to comment.