Skip to content

Commit

Permalink
feat: add upgradable contract
Browse files Browse the repository at this point in the history
  • Loading branch information
greatest0fallt1me authored and PoulavBhowmick03 committed Feb 27, 2025
1 parent 014b84a commit c888e09
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 1 deletion.
1 change: 1 addition & 0 deletions contracts/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ pub mod mock_erc721;
pub mod NFTDutchAuction;
pub mod timelock;
pub mod ConstantProductAmm;
pub mod upgradable;

134 changes: 134 additions & 0 deletions contracts/src/upgradable.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use starknet::class_hash::ClassHash;
use starknet::{ContractAddress};

#[starknet::interface]
pub trait IUpgradeableContract<TContractState> {
fn upgrade(ref self: TContractState, impl_hash: ClassHash);
fn version(self: @TContractState) -> u8;
fn upgradeWithAuth(ref self: TContractState, impl_hash: ClassHash);
fn setAdmin(ref self: TContractState, new_admin: ContractAddress);
fn getAdmin(self: @TContractState) -> ContractAddress;
}

#[starknet::contract]
pub mod UpgradeableContract_V0 {
use super::IUpgradeableContract;
use starknet::class_hash::ClassHash;
use core::num::traits::Zero;
use starknet::{ContractAddress, get_caller_address};
use core::starknet::storage::{
StoragePointerReadAccess, StoragePointerWriteAccess,
};

#[storage]
struct Storage {
admin: ContractAddress,
}

#[event]
#[derive(Drop, starknet::Event)]
pub enum Event {
Upgraded: Upgraded,
AdminChanged: AdminChanged,
}

#[derive(Drop, starknet::Event)]
pub struct Upgraded {
pub implementation: ClassHash,
}

#[derive(Drop, starknet::Event)]
pub struct AdminChanged {
pub new_admin: ContractAddress,
}

#[abi(embed_v0)]
impl UpgradeableContract of IUpgradeableContract<ContractState> {
fn upgrade(ref self: ContractState, impl_hash: ClassHash) {
assert(impl_hash.is_non_zero(), 'Class hash cannot be zero');
starknet::syscalls::replace_class_syscall(impl_hash).unwrap();
self.emit(Event::Upgraded(Upgraded { implementation: impl_hash }));
}

fn version(self: @ContractState) -> u8 {
0
}

fn upgradeWithAuth(ref self: ContractState, impl_hash: ClassHash) {
let caller = get_caller_address();
assert(caller == self.admin.read(), 'Only admin can upgrade');
starknet::syscalls::replace_class_syscall(impl_hash).unwrap();
}

fn setAdmin(ref self: ContractState, new_admin: ContractAddress) {
self.admin.write(new_admin);
self.emit(Event::AdminChanged(AdminChanged { new_admin }));
}

fn getAdmin(self: @ContractState) -> ContractAddress {
self.admin.read()
}
}
}

#[starknet::contract]
pub mod UpgradeableContract_V1 {
use super::IUpgradeableContract;
use starknet::class_hash::ClassHash;
use core::num::traits::Zero;
use starknet::{ContractAddress, get_caller_address};
use core::starknet::storage::{
StoragePointerReadAccess, StoragePointerWriteAccess,
};

#[storage]
struct Storage {
admin: ContractAddress,
}

#[event]
#[derive(Drop, starknet::Event)]
enum Event {
Upgraded: Upgraded,
AdminChanged: AdminChanged,
}

#[derive(Drop, starknet::Event)]
struct Upgraded {
implementation: ClassHash,
}

#[derive(Drop, starknet::Event)]
struct AdminChanged {
new_admin: ContractAddress,
}

#[abi(embed_v0)]
impl UpgradeableContract of IUpgradeableContract<ContractState> {
fn upgrade(ref self: ContractState, impl_hash: ClassHash) {
assert(impl_hash.is_non_zero(), 'Class hash cannot be zero');
starknet::syscalls::replace_class_syscall(impl_hash).unwrap();
self.emit(Event::Upgraded(Upgraded { implementation: impl_hash }))
}

fn version(self: @ContractState) -> u8 {
1
}

fn upgradeWithAuth(ref self: ContractState, impl_hash: ClassHash) {
let caller = get_caller_address();
assert(caller == self.admin.read(), 'Only admin can upgrade');
starknet::syscalls::replace_class_syscall(impl_hash).unwrap();
}

fn setAdmin(ref self: ContractState, new_admin: ContractAddress) {
self.admin.write(new_admin);
self.emit(Event::AdminChanged(AdminChanged { new_admin }));
}

fn getAdmin(self: @ContractState) -> ContractAddress {
self.admin.read()
}
}
}

2 changes: 1 addition & 1 deletion contracts/tests/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod test_starkfinder;
mod test_starkidentity;
mod test_nft_dutch;
mod test_constant_product_amm;
mod test_upgradable_contract;
#[feature("safe_dispatcher")]
mod test_timelock;


69 changes: 69 additions & 0 deletions contracts/tests/test_upgradable_contract.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
mod tests {
use starknet::class_hash::ClassHash;

use starknet::ContractAddress;
use snforge_std::{
declare, ContractClassTrait, DeclareResultTrait
};

use contracts::upgradable::{
UpgradeableContract_V0, IUpgradeableContractDispatcher,
IUpgradeableContractDispatcherTrait, UpgradeableContract_V1,
};

use core::num::traits::Zero;


fn deploy_v0() -> (IUpgradeableContractDispatcher, ContractAddress, ClassHash) {
// First declare the contract
let contract = declare("UpgradeableContract_V0").unwrap();
let contract_class = contract.contract_class();

let (contract_address, _) = contract_class.deploy(@array![]).unwrap();

(
IUpgradeableContractDispatcher { contract_address },
contract_address,
UpgradeableContract_V0::TEST_CLASS_HASH.try_into().unwrap(),
)
}

// deploy v1 contract
fn deploy_v1() -> (IUpgradeableContractDispatcher, ContractAddress, ClassHash) {
// First declare the contract
let contract = declare("UpgradeableContract_V1").unwrap();
let contract_class = contract.contract_class();

let (contract_address, _) = contract_class.deploy(@array![]).unwrap();
(
IUpgradeableContractDispatcher { contract_address },
contract_address,
UpgradeableContract_V1::TEST_CLASS_HASH.try_into().unwrap()
)
}


#[test]
fn test_deploy_v0() {
deploy_v0();
}

#[test]
fn test_deploy_v1() {
deploy_v1();
}

#[test]
fn test_version_from_v0() {
let (dispatcher, _, _) = deploy_v0();
assert(dispatcher.version() == 0, 'incorrect version');
}

#[test]
#[should_panic(expected: 'Class hash cannot be zero')]
fn test_upgrade_when_classhash_is_zero() {
let (dispatcher_v0, _, _) = deploy_v0();
dispatcher_v0.upgrade(Zero::zero());
}
}

0 comments on commit c888e09

Please sign in to comment.