Skip to content

Commit a3b58d5

Browse files
clean up
1 parent 88eb46b commit a3b58d5

File tree

2 files changed

+115
-171
lines changed

2 files changed

+115
-171
lines changed

Diff for: src/keypair.rs

+48-155
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use argon2::{
55
};
66
use bip39::{Language, Mnemonic};
77
use rand::RngCore;
8-
use ring::aead::{self, Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
8+
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
99
use schnorrkel::{
1010
derive::{ChainCode, Derivation},
1111
ExpansionMode, MiniSecretKey,
@@ -35,6 +35,8 @@ pub enum KeyFileError {
3535
Io(#[from] std::io::Error),
3636
#[error("Keyfile at: {0} not found")]
3737
NotFound(String),
38+
#[error("Invalid key type: {0}")]
39+
InvalidKeyType(String),
3840
}
3941
#[derive(Debug)]
4042
pub struct Keypair {
@@ -189,6 +191,10 @@ fn coldkey_pub_file(path: &str, name: &str) -> PathBuf {
189191
wallet_path.join("coldkeypub.txt")
190192
}
191193

194+
fn coldkey_file(path: &str, name: &str) -> PathBuf {
195+
let wallet_path = PathBuf::from(shellexpand::tilde(path).into_owned()).join(name);
196+
wallet_path.join("coldkey")
197+
}
192198
/// Writes keyfile data to a file with specific permissions.
193199
///
194200
/// This function writes the provided keyfile data to a file at the specified path.
@@ -248,6 +254,22 @@ pub fn write_keyfile_data_to_file(
248254
Ok(())
249255
}
250256

257+
pub fn load_keypair_dict(
258+
name: &str,
259+
key_type: &str,
260+
password: Option<&str>,
261+
) -> Result<Keypair, Box<dyn std::error::Error>> {
262+
// Load and deserialize the keyfile data
263+
let keyfile_data = get_keypair_from_file(name, key_type)?;
264+
let keypair = if let Some(pass) = password {
265+
let decrypted_data = decrypt_keyfile_data(&keyfile_data, pass)?;
266+
deserialize_keyfile_data_to_keypair(&decrypted_data)?
267+
} else {
268+
deserialize_keyfile_data_to_keypair(&keyfile_data)?
269+
};
270+
271+
Ok(keypair)
272+
}
251273
/// Loads a hotkey pair from a keyfile.
252274
///
253275
/// This function retrieves the private key data from a keyfile, processes it,
@@ -269,18 +291,13 @@ pub fn write_keyfile_data_to_file(
269291
/// - The private key is missing from the keyfile data.
270292
/// - The decoded private key has an invalid length.
271293
/// - The sr25519::Pair cannot be created from the seed.
272-
pub fn load_hotkey_pair(
273-
hotkey_name: &str,
294+
pub fn load_keypair(
295+
name: &str,
296+
key_type: &str,
274297
password: Option<&str>,
275298
) -> Result<sr25519::Pair, Box<dyn std::error::Error>> {
276299
// Load and deserialize the keyfile data
277-
let keyfile_data = load_keyfile_data_from_file(hotkey_name)?;
278-
let keypair = if let Some(pass) = password {
279-
let decrypted_data = decrypt_keyfile_data(&keyfile_data, pass)?;
280-
deserialize_keyfile_data_to_keypair(&decrypted_data)?
281-
} else {
282-
deserialize_keyfile_data_to_keypair(&keyfile_data)?
283-
};
300+
let keypair = load_keypair_dict(name, key_type, password)?;
284301

285302
// Extract the private key
286303
let private_key = keypair
@@ -304,7 +321,6 @@ pub fn load_hotkey_pair(
304321
return Err("Invalid seed length".into());
305322
}
306323

307-
// Create and return the sr25519::Pair
308324
let pair = sr25519::Pair::from_seed_slice(&seed)?;
309325
Ok(pair)
310326
}
@@ -328,9 +344,14 @@ pub fn load_hotkey_pair(
328344
/// This function will return an error if:
329345
/// - The keyfile does not exist at the expected location.
330346
/// - There are issues opening or reading the file.
331-
pub fn load_keyfile_data_from_file(name: &str) -> Result<Vec<u8>, KeyFileError> {
347+
pub fn get_keypair_from_file(name: &str, key_type: &str) -> Result<Vec<u8>, KeyFileError> {
332348
let default_path = BT_WALLET_PATH;
333-
let path = hotkey_file(default_path, name);
349+
let path = match key_type {
350+
"hotkey" => hotkey_file(default_path, name),
351+
"coldkeypub" => coldkey_pub_file(default_path, name),
352+
"coldkey" => coldkey_file(default_path, name),
353+
_ => return Err(KeyFileError::InvalidKeyType(key_type.to_string())),
354+
};
334355

335356
if !exists_on_device(&path) {
336357
return Err(KeyFileError::NotFound(path.to_string_lossy().into_owned()));
@@ -487,8 +508,8 @@ pub fn save_keypair(
487508
mnemonic: Mnemonic,
488509
seed: [u8; 32],
489510
name: &str,
490-
encrypt: bool,
491511
key_type: &str,
512+
password: Option<String>,
492513
) -> Keypair {
493514
let keypair = Keypair {
494515
public_key: Some(key_pair.public().to_vec()),
@@ -501,17 +522,21 @@ pub fn save_keypair(
501522
let key_path;
502523
if key_type == "hotkey" {
503524
key_path = hotkey_file(path, name);
504-
} else {
525+
} else if key_type == "coldkeypub" {
505526
key_path = coldkey_pub_file(path, name);
527+
} else if key_type == "coldkey" {
528+
key_path = coldkey_file(path, name);
529+
} else {
530+
panic!("Invalid key type: {}", key_type);
506531
}
507532
// Ensure the directory exists before writing the file
508533
if let Some(parent) = key_path.parent() {
509534
std::fs::create_dir_all(parent).expect("Failed to create directory");
510535
}
511-
let password = "ben+is+a+css+pro";
512-
if encrypt {
536+
let password = password.unwrap_or_else(|| "".to_string());
537+
if !password.is_empty() {
513538
let encrypted_data =
514-
encrypt_keyfile_data(serialized_keypair_to_keyfile_data(&keypair), password)
539+
encrypt_keyfile_data(serialized_keypair_to_keyfile_data(&keypair), &password)
515540
.expect("Failed to encrypt keyfile");
516541
write_keyfile_data_to_file(&key_path, encrypted_data, false)
517542
.expect("Failed to write encrypted keyfile");
@@ -551,53 +576,17 @@ pub fn save_keypair(
551576
/// - It fails to derive the sr25519 key.
552577
/// - It fails to create the directory for the keyfile.
553578
/// - It fails to write the keyfile.
554-
pub fn create_hotkey(mnemonic: Mnemonic, name: &str) -> (sr25519::Pair, [u8; 32]) {
555-
let seed: [u8; 32] = mnemonic.to_seed("")[..32]
556-
.try_into()
557-
.expect("Failed to create seed");
558-
559-
let derivation_path: Vec<u8> = format!("//{}", name).into_bytes();
560-
561-
let hotkey_pair: sr25519::Pair =
562-
derive_sr25519_key(&seed, &derivation_path).expect("Failed to derive sr25519 key");
563-
564-
(hotkey_pair, seed) //hack to demo hotkey_pair sign
565-
}
566-
567-
/// Creates a new coldkey pair and returns it along with its seed.
568-
///
569-
/// This function performs the following steps:
570-
/// 1. Generates a seed from the provided mnemonic.
571-
/// 2. Creates a derivation path using the provided name.
572-
/// 3. Derives an sr25519 key pair using the seed and derivation path.
573-
///
574-
/// # Arguments
575-
///
576-
/// * `mnemonic` - A `Mnemonic` object representing the seed phrase.
577-
/// * `name` - A string slice containing the name for the coldkey.
578-
///
579-
/// # Returns
580-
///
581-
/// Returns a tuple containing:
582-
/// - An `sr25519::Pair` representing the generated coldkey pair.
583-
/// - A 32-byte array containing the seed used to generate the key pair.
584-
///
585-
/// # Panics
586-
///
587-
/// This function will panic if:
588-
/// - It fails to create a seed from the mnemonic.
589-
/// - It fails to derive the sr25519 key.
590-
pub fn create_coldkey(mnemonic: Mnemonic, name: &str) -> (sr25519::Pair, [u8; 32]) {
579+
pub fn create_keypair(mnemonic: Mnemonic, name: &str) -> (sr25519::Pair, [u8; 32]) {
591580
let seed: [u8; 32] = mnemonic.to_seed("")[..32]
592581
.try_into()
593582
.expect("Failed to create seed");
594583

595584
let derivation_path: Vec<u8> = format!("//{}", name).into_bytes();
596585

597-
let coldkey_pair: sr25519::Pair =
586+
let keypair: sr25519::Pair =
598587
derive_sr25519_key(&seed, &derivation_path).expect("Failed to derive sr25519 key");
599588

600-
(coldkey_pair, seed)
589+
(keypair, seed)
601590
}
602591

603592
fn generate_nonce() -> [u8; NONCE_SIZE] {
@@ -638,7 +627,7 @@ fn decrypt_keyfile_data(encrypted_data: &[u8], password: &str) -> Result<Vec<u8>
638627
let argon2 = Argon2::default();
639628

640629
// Create a SaltString from our constant salt
641-
let salt = SaltString::b64_encode(NACL_SALT)?;
630+
let salt = SaltString::encode_b64(NACL_SALT)?;
642631

643632
// Hash the password to derive the key
644633
let password_hash = argon2.hash_password(password, &salt)?;
@@ -707,7 +696,7 @@ fn encrypt_keyfile_data(keyfile_data: Vec<u8>, password: &str) -> Result<Vec<u8>
707696
let argon2 = Argon2::default();
708697

709698
// Create a SaltString from our constant salt
710-
let salt = SaltString::b64_encode(NACL_SALT)?;
699+
let salt = SaltString::encode_b64(NACL_SALT)?;
711700

712701
// Hash the password to derive the key
713702
let password_hash = argon2.hash_password(password, &salt)?;
@@ -740,99 +729,3 @@ fn encrypt_keyfile_data(keyfile_data: Vec<u8>, password: &str) -> Result<Vec<u8>
740729

741730
Ok(result)
742731
}
743-
744-
#[cfg(test)]
745-
mod tests {
746-
use super::*;
747-
use bip39::Language;
748-
use rand::Rng;
749-
750-
#[test]
751-
fn test_create_mnemonic_valid_word_counts() {
752-
let valid_word_counts = [12, 15, 18, 21, 24];
753-
for &word_count in &valid_word_counts {
754-
let result = create_mnemonic(word_count);
755-
assert!(
756-
result.is_ok(),
757-
"Failed to create mnemonic with {} words",
758-
word_count
759-
);
760-
let mnemonic = result.unwrap();
761-
assert_eq!(
762-
mnemonic.word_count(),
763-
word_count as usize,
764-
"Mnemonic word count doesn't match expected"
765-
);
766-
}
767-
}
768-
769-
#[test]
770-
fn test_mnemonic_uniqueness() {
771-
let mnemonic1 = create_mnemonic(12).unwrap();
772-
let mnemonic2 = create_mnemonic(12).unwrap();
773-
assert_ne!(
774-
mnemonic1.to_string(),
775-
mnemonic2.to_string(),
776-
"Two generated mnemonics should not be identical"
777-
);
778-
}
779-
780-
#[test]
781-
fn test_mnemonic_language() {
782-
let mnemonic = create_mnemonic(12).unwrap();
783-
assert_eq!(
784-
mnemonic.language(),
785-
Language::English,
786-
"Mnemonic should be in English"
787-
);
788-
}
789-
790-
#[test]
791-
fn test_derive_sr25519_key_valid_input() {
792-
let mut rng = rand::thread_rng();
793-
let seed: [u8; 32] = rng.gen();
794-
let path = b"/some/path";
795-
796-
let result = derive_sr25519_key(&seed, path);
797-
assert!(result.is_ok());
798-
}
799-
800-
#[test]
801-
fn test_derive_sr25519_key_invalid_seed_length() {
802-
let seed = [0u8; 16]; // Invalid length
803-
let path = b"/some/path";
804-
805-
let result = derive_sr25519_key(&seed, path);
806-
assert!(result.is_err());
807-
let err = result.err().unwrap();
808-
assert!(err
809-
.to_string()
810-
.contains("Invalid seed length: expected 32, got 16"));
811-
}
812-
813-
#[test]
814-
fn test_derive_sr25519_key_empty_path() {
815-
let mut rng = rand::thread_rng();
816-
let seed: [u8; 32] = rng.gen();
817-
let path = b"";
818-
819-
let result = derive_sr25519_key(&seed, path);
820-
assert!(result.is_ok());
821-
}
822-
823-
#[test]
824-
fn test_derive_sr25519_key_deterministic() {
825-
let seed: [u8; 32] = [42u8; 32];
826-
let path = b"/test/path";
827-
828-
let result1 = derive_sr25519_key(&seed, path);
829-
let result2 = derive_sr25519_key(&seed, path);
830-
831-
assert!(result1.is_ok() && result2.is_ok());
832-
assert_eq!(
833-
result1.unwrap().public(),
834-
result2.unwrap().public(),
835-
"Derived keys should be identical for the same seed and path"
836-
);
837-
}
838-
}

0 commit comments

Comments
 (0)