@@ -5,7 +5,7 @@ use argon2::{
5
5
} ;
6
6
use bip39:: { Language , Mnemonic } ;
7
7
use rand:: RngCore ;
8
- use ring:: aead:: { self , Aad , LessSafeKey , Nonce , UnboundKey , CHACHA20_POLY1305 } ;
8
+ use ring:: aead:: { Aad , LessSafeKey , Nonce , UnboundKey , CHACHA20_POLY1305 } ;
9
9
use schnorrkel:: {
10
10
derive:: { ChainCode , Derivation } ,
11
11
ExpansionMode , MiniSecretKey ,
@@ -35,6 +35,8 @@ pub enum KeyFileError {
35
35
Io ( #[ from] std:: io:: Error ) ,
36
36
#[ error( "Keyfile at: {0} not found" ) ]
37
37
NotFound ( String ) ,
38
+ #[ error( "Invalid key type: {0}" ) ]
39
+ InvalidKeyType ( String ) ,
38
40
}
39
41
#[ derive( Debug ) ]
40
42
pub struct Keypair {
@@ -189,6 +191,10 @@ fn coldkey_pub_file(path: &str, name: &str) -> PathBuf {
189
191
wallet_path. join ( "coldkeypub.txt" )
190
192
}
191
193
194
+ fn coldkey_file ( path : & str , name : & str ) -> PathBuf {
195
+ let wallet_path = PathBuf :: from ( shellexpand:: tilde ( path) . into_owned ( ) ) . join ( name) ;
196
+ wallet_path. join ( "coldkey" )
197
+ }
192
198
/// Writes keyfile data to a file with specific permissions.
193
199
///
194
200
/// This function writes the provided keyfile data to a file at the specified path.
@@ -248,6 +254,22 @@ pub fn write_keyfile_data_to_file(
248
254
Ok ( ( ) )
249
255
}
250
256
257
+ pub fn load_keypair_dict (
258
+ name : & str ,
259
+ key_type : & str ,
260
+ password : Option < & str > ,
261
+ ) -> Result < Keypair , Box < dyn std:: error:: Error > > {
262
+ // Load and deserialize the keyfile data
263
+ let keyfile_data = get_keypair_from_file ( name, key_type) ?;
264
+ let keypair = if let Some ( pass) = password {
265
+ let decrypted_data = decrypt_keyfile_data ( & keyfile_data, pass) ?;
266
+ deserialize_keyfile_data_to_keypair ( & decrypted_data) ?
267
+ } else {
268
+ deserialize_keyfile_data_to_keypair ( & keyfile_data) ?
269
+ } ;
270
+
271
+ Ok ( keypair)
272
+ }
251
273
/// Loads a hotkey pair from a keyfile.
252
274
///
253
275
/// This function retrieves the private key data from a keyfile, processes it,
@@ -269,18 +291,13 @@ pub fn write_keyfile_data_to_file(
269
291
/// - The private key is missing from the keyfile data.
270
292
/// - The decoded private key has an invalid length.
271
293
/// - The sr25519::Pair cannot be created from the seed.
272
- pub fn load_hotkey_pair (
273
- hotkey_name : & str ,
294
+ pub fn load_keypair (
295
+ name : & str ,
296
+ key_type : & str ,
274
297
password : Option < & str > ,
275
298
) -> Result < sr25519:: Pair , Box < dyn std:: error:: Error > > {
276
299
// Load and deserialize the keyfile data
277
- let keyfile_data = load_keyfile_data_from_file ( hotkey_name) ?;
278
- let keypair = if let Some ( pass) = password {
279
- let decrypted_data = decrypt_keyfile_data ( & keyfile_data, pass) ?;
280
- deserialize_keyfile_data_to_keypair ( & decrypted_data) ?
281
- } else {
282
- deserialize_keyfile_data_to_keypair ( & keyfile_data) ?
283
- } ;
300
+ let keypair = load_keypair_dict ( name, key_type, password) ?;
284
301
285
302
// Extract the private key
286
303
let private_key = keypair
@@ -304,7 +321,6 @@ pub fn load_hotkey_pair(
304
321
return Err ( "Invalid seed length" . into ( ) ) ;
305
322
}
306
323
307
- // Create and return the sr25519::Pair
308
324
let pair = sr25519:: Pair :: from_seed_slice ( & seed) ?;
309
325
Ok ( pair)
310
326
}
@@ -328,9 +344,14 @@ pub fn load_hotkey_pair(
328
344
/// This function will return an error if:
329
345
/// - The keyfile does not exist at the expected location.
330
346
/// - There are issues opening or reading the file.
331
- pub fn load_keyfile_data_from_file ( name : & str ) -> Result < Vec < u8 > , KeyFileError > {
347
+ pub fn get_keypair_from_file ( name : & str , key_type : & str ) -> Result < Vec < u8 > , KeyFileError > {
332
348
let default_path = BT_WALLET_PATH ;
333
- let path = hotkey_file ( default_path, name) ;
349
+ let path = match key_type {
350
+ "hotkey" => hotkey_file ( default_path, name) ,
351
+ "coldkeypub" => coldkey_pub_file ( default_path, name) ,
352
+ "coldkey" => coldkey_file ( default_path, name) ,
353
+ _ => return Err ( KeyFileError :: InvalidKeyType ( key_type. to_string ( ) ) ) ,
354
+ } ;
334
355
335
356
if !exists_on_device ( & path) {
336
357
return Err ( KeyFileError :: NotFound ( path. to_string_lossy ( ) . into_owned ( ) ) ) ;
@@ -487,8 +508,8 @@ pub fn save_keypair(
487
508
mnemonic : Mnemonic ,
488
509
seed : [ u8 ; 32 ] ,
489
510
name : & str ,
490
- encrypt : bool ,
491
511
key_type : & str ,
512
+ password : Option < String > ,
492
513
) -> Keypair {
493
514
let keypair = Keypair {
494
515
public_key : Some ( key_pair. public ( ) . to_vec ( ) ) ,
@@ -501,17 +522,21 @@ pub fn save_keypair(
501
522
let key_path;
502
523
if key_type == "hotkey" {
503
524
key_path = hotkey_file ( path, name) ;
504
- } else {
525
+ } else if key_type == "coldkeypub" {
505
526
key_path = coldkey_pub_file ( path, name) ;
527
+ } else if key_type == "coldkey" {
528
+ key_path = coldkey_file ( path, name) ;
529
+ } else {
530
+ panic ! ( "Invalid key type: {}" , key_type) ;
506
531
}
507
532
// Ensure the directory exists before writing the file
508
533
if let Some ( parent) = key_path. parent ( ) {
509
534
std:: fs:: create_dir_all ( parent) . expect ( "Failed to create directory" ) ;
510
535
}
511
- let password = "ben+is+a+css+pro" ;
512
- if encrypt {
536
+ let password = password . unwrap_or_else ( || "" . to_string ( ) ) ;
537
+ if !password . is_empty ( ) {
513
538
let encrypted_data =
514
- encrypt_keyfile_data ( serialized_keypair_to_keyfile_data ( & keypair) , password)
539
+ encrypt_keyfile_data ( serialized_keypair_to_keyfile_data ( & keypair) , & password)
515
540
. expect ( "Failed to encrypt keyfile" ) ;
516
541
write_keyfile_data_to_file ( & key_path, encrypted_data, false )
517
542
. expect ( "Failed to write encrypted keyfile" ) ;
@@ -551,53 +576,17 @@ pub fn save_keypair(
551
576
/// - It fails to derive the sr25519 key.
552
577
/// - It fails to create the directory for the keyfile.
553
578
/// - It fails to write the keyfile.
554
- pub fn create_hotkey ( mnemonic : Mnemonic , name : & str ) -> ( sr25519:: Pair , [ u8 ; 32 ] ) {
555
- let seed: [ u8 ; 32 ] = mnemonic. to_seed ( "" ) [ ..32 ]
556
- . try_into ( )
557
- . expect ( "Failed to create seed" ) ;
558
-
559
- let derivation_path: Vec < u8 > = format ! ( "//{}" , name) . into_bytes ( ) ;
560
-
561
- let hotkey_pair: sr25519:: Pair =
562
- derive_sr25519_key ( & seed, & derivation_path) . expect ( "Failed to derive sr25519 key" ) ;
563
-
564
- ( hotkey_pair, seed) //hack to demo hotkey_pair sign
565
- }
566
-
567
- /// Creates a new coldkey pair and returns it along with its seed.
568
- ///
569
- /// This function performs the following steps:
570
- /// 1. Generates a seed from the provided mnemonic.
571
- /// 2. Creates a derivation path using the provided name.
572
- /// 3. Derives an sr25519 key pair using the seed and derivation path.
573
- ///
574
- /// # Arguments
575
- ///
576
- /// * `mnemonic` - A `Mnemonic` object representing the seed phrase.
577
- /// * `name` - A string slice containing the name for the coldkey.
578
- ///
579
- /// # Returns
580
- ///
581
- /// Returns a tuple containing:
582
- /// - An `sr25519::Pair` representing the generated coldkey pair.
583
- /// - A 32-byte array containing the seed used to generate the key pair.
584
- ///
585
- /// # Panics
586
- ///
587
- /// This function will panic if:
588
- /// - It fails to create a seed from the mnemonic.
589
- /// - It fails to derive the sr25519 key.
590
- pub fn create_coldkey ( mnemonic : Mnemonic , name : & str ) -> ( sr25519:: Pair , [ u8 ; 32 ] ) {
579
+ pub fn create_keypair ( mnemonic : Mnemonic , name : & str ) -> ( sr25519:: Pair , [ u8 ; 32 ] ) {
591
580
let seed: [ u8 ; 32 ] = mnemonic. to_seed ( "" ) [ ..32 ]
592
581
. try_into ( )
593
582
. expect ( "Failed to create seed" ) ;
594
583
595
584
let derivation_path: Vec < u8 > = format ! ( "//{}" , name) . into_bytes ( ) ;
596
585
597
- let coldkey_pair : sr25519:: Pair =
586
+ let keypair : sr25519:: Pair =
598
587
derive_sr25519_key ( & seed, & derivation_path) . expect ( "Failed to derive sr25519 key" ) ;
599
588
600
- ( coldkey_pair , seed)
589
+ ( keypair , seed)
601
590
}
602
591
603
592
fn generate_nonce ( ) -> [ u8 ; NONCE_SIZE ] {
@@ -638,7 +627,7 @@ fn decrypt_keyfile_data(encrypted_data: &[u8], password: &str) -> Result<Vec<u8>
638
627
let argon2 = Argon2 :: default ( ) ;
639
628
640
629
// Create a SaltString from our constant salt
641
- let salt = SaltString :: b64_encode ( NACL_SALT ) ?;
630
+ let salt = SaltString :: encode_b64 ( NACL_SALT ) ?;
642
631
643
632
// Hash the password to derive the key
644
633
let password_hash = argon2. hash_password ( password, & salt) ?;
@@ -707,7 +696,7 @@ fn encrypt_keyfile_data(keyfile_data: Vec<u8>, password: &str) -> Result<Vec<u8>
707
696
let argon2 = Argon2 :: default ( ) ;
708
697
709
698
// Create a SaltString from our constant salt
710
- let salt = SaltString :: b64_encode ( NACL_SALT ) ?;
699
+ let salt = SaltString :: encode_b64 ( NACL_SALT ) ?;
711
700
712
701
// Hash the password to derive the key
713
702
let password_hash = argon2. hash_password ( password, & salt) ?;
@@ -740,99 +729,3 @@ fn encrypt_keyfile_data(keyfile_data: Vec<u8>, password: &str) -> Result<Vec<u8>
740
729
741
730
Ok ( result)
742
731
}
743
-
744
- #[ cfg( test) ]
745
- mod tests {
746
- use super :: * ;
747
- use bip39:: Language ;
748
- use rand:: Rng ;
749
-
750
- #[ test]
751
- fn test_create_mnemonic_valid_word_counts ( ) {
752
- let valid_word_counts = [ 12 , 15 , 18 , 21 , 24 ] ;
753
- for & word_count in & valid_word_counts {
754
- let result = create_mnemonic ( word_count) ;
755
- assert ! (
756
- result. is_ok( ) ,
757
- "Failed to create mnemonic with {} words" ,
758
- word_count
759
- ) ;
760
- let mnemonic = result. unwrap ( ) ;
761
- assert_eq ! (
762
- mnemonic. word_count( ) ,
763
- word_count as usize ,
764
- "Mnemonic word count doesn't match expected"
765
- ) ;
766
- }
767
- }
768
-
769
- #[ test]
770
- fn test_mnemonic_uniqueness ( ) {
771
- let mnemonic1 = create_mnemonic ( 12 ) . unwrap ( ) ;
772
- let mnemonic2 = create_mnemonic ( 12 ) . unwrap ( ) ;
773
- assert_ne ! (
774
- mnemonic1. to_string( ) ,
775
- mnemonic2. to_string( ) ,
776
- "Two generated mnemonics should not be identical"
777
- ) ;
778
- }
779
-
780
- #[ test]
781
- fn test_mnemonic_language ( ) {
782
- let mnemonic = create_mnemonic ( 12 ) . unwrap ( ) ;
783
- assert_eq ! (
784
- mnemonic. language( ) ,
785
- Language :: English ,
786
- "Mnemonic should be in English"
787
- ) ;
788
- }
789
-
790
- #[ test]
791
- fn test_derive_sr25519_key_valid_input ( ) {
792
- let mut rng = rand:: thread_rng ( ) ;
793
- let seed: [ u8 ; 32 ] = rng. gen ( ) ;
794
- let path = b"/some/path" ;
795
-
796
- let result = derive_sr25519_key ( & seed, path) ;
797
- assert ! ( result. is_ok( ) ) ;
798
- }
799
-
800
- #[ test]
801
- fn test_derive_sr25519_key_invalid_seed_length ( ) {
802
- let seed = [ 0u8 ; 16 ] ; // Invalid length
803
- let path = b"/some/path" ;
804
-
805
- let result = derive_sr25519_key ( & seed, path) ;
806
- assert ! ( result. is_err( ) ) ;
807
- let err = result. err ( ) . unwrap ( ) ;
808
- assert ! ( err
809
- . to_string( )
810
- . contains( "Invalid seed length: expected 32, got 16" ) ) ;
811
- }
812
-
813
- #[ test]
814
- fn test_derive_sr25519_key_empty_path ( ) {
815
- let mut rng = rand:: thread_rng ( ) ;
816
- let seed: [ u8 ; 32 ] = rng. gen ( ) ;
817
- let path = b"" ;
818
-
819
- let result = derive_sr25519_key ( & seed, path) ;
820
- assert ! ( result. is_ok( ) ) ;
821
- }
822
-
823
- #[ test]
824
- fn test_derive_sr25519_key_deterministic ( ) {
825
- let seed: [ u8 ; 32 ] = [ 42u8 ; 32 ] ;
826
- let path = b"/test/path" ;
827
-
828
- let result1 = derive_sr25519_key ( & seed, path) ;
829
- let result2 = derive_sr25519_key ( & seed, path) ;
830
-
831
- assert ! ( result1. is_ok( ) && result2. is_ok( ) ) ;
832
- assert_eq ! (
833
- result1. unwrap( ) . public( ) ,
834
- result2. unwrap( ) . public( ) ,
835
- "Derived keys should be identical for the same seed and path"
836
- ) ;
837
- }
838
- }
0 commit comments