From d567663953a0827cf018c316b1a58959e6a695b2 Mon Sep 17 00:00:00 2001 From: William Chong Date: Thu, 7 Dec 2023 17:00:52 +0400 Subject: [PATCH 1/4] Allow specifying certificate file names - Add tests for `create-certs` command - Improve test coverage - Some refactoring - Fix linting issues --- .gitignore | 2 + README.md | 2 + certificates/boring_linux.go | 2 +- certificates/certificates.go | 60 ----- certificates/common.go | 49 +++-- certificates/common_test.go | 66 +++--- certificates/create_ca.go | 95 ++++---- certificates/create_ca_test.go | 199 ++++++++++++----- certificates/create_certs.go | 113 ++++------ certificates/create_certs_test.go | 354 ++++++++++++++++++++++++++++++ certificates/create_node.go | 132 +++++------ certificates/create_node_test.go | 212 +++++++++++------- main.go | 32 +-- references/certs.yml | 23 ++ references/named_certs.yml | 28 +++ 15 files changed, 921 insertions(+), 448 deletions(-) delete mode 100644 certificates/certificates.go create mode 100644 certificates/create_certs_test.go create mode 100644 references/certs.yml create mode 100644 references/named_certs.yml diff --git a/.gitignore b/.gitignore index 2630087..fc4c0b4 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,5 @@ ca/ es-gencert-cli certs.yml .DS_Store +*.crt +*.key diff --git a/README.md b/README.md index fac3410..7202be1 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,8 @@ certificates: dns-names: "localhost,eventstore-node2.localhost.com" ``` +If you want to specify the name of the certificates from the config file, you can add the name field to the certificate definition. You can see an example of this in the [example configuration](references/named_certs.yml). + ## Development Building or working on `es-gencert-cli` requires a Go environment, version 1.14 or higher. diff --git a/certificates/boring_linux.go b/certificates/boring_linux.go index 3f17c1b..ab8cd55 100644 --- a/certificates/boring_linux.go +++ b/certificates/boring_linux.go @@ -8,4 +8,4 @@ import ( func isBoringEnabled() bool { return boring.Enabled() -} \ No newline at end of file +} diff --git a/certificates/certificates.go b/certificates/certificates.go deleted file mode 100644 index c91262e..0000000 --- a/certificates/certificates.go +++ /dev/null @@ -1,60 +0,0 @@ -package certificates - -import ( - "log" - "os" - "strings" - - "github.com/mitchellh/cli" -) - -type Certificates struct { - Ui cli.Ui -} - -func (command *Certificates) Run(args []string) int { - ui := &cli.BasicUi{ - Reader: os.Stdin, - Writer: os.Stdout, - ErrorWriter: os.Stderr, - } - c := cli.NewCLI("Event Store CLI certificates", "") - c.Args = args - c.Commands = map[string]cli.CommandFactory{ - "create-ca": func() (cli.Command, error) { - return &CreateCA{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil - }, - "create-node": func() (cli.Command, error) { - return &CreateNode{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil - }, - } - exitStatus, err := c.Run() - if err != nil { - log.Println(err) - } - return exitStatus -} - -func (c *Certificates) Help() string { - helpText := ` -usage: certificates [--help] [] - -Available commands: -` - helpText += c.Synopsis() - return strings.TrimSpace(helpText) -} - -func (c *Certificates) Synopsis() string { - return "certificates (create_ca, create_node)" -} diff --git a/certificates/common.go b/certificates/common.go index 1c5d498..625f22a 100644 --- a/certificates/common.go +++ b/certificates/common.go @@ -11,16 +11,18 @@ import ( "math/big" "os" "path" - "text/tabwriter" + "path/filepath" ) -const defaultKeySize = 2048 - -const forceOption = "Force overwrite of existing files without prompting" - const ( - ErrFileExists = "Error: Existing files would be overwritten. Use -force to proceed" + ForceFlagUsage = "Force overwrite of existing files without prompting" + NameFlagUsage = "The name of the CA certificate and key file" + OutDirFlagUsage = "The output directory" + DayFlagUsage = "the validity period of the certificate in days" + CaKeyFlagUsage = "the path to the CA key file" + CaPathFlagUsage = "the path to the CA certificate file" ) +const defaultKeySize = 2048 func generateSerialNumber(bits uint) (*big.Int, error) { maxValue := new(big.Int).Lsh(big.NewInt(1), bits) @@ -48,13 +50,9 @@ func writeFileWithDir(filePath string, data []byte, perm os.FileMode) error { return os.WriteFile(filePath, data, perm) } -func writeHelpOption(w *tabwriter.Writer, title string, description string) { - fmt.Fprintf(w, "\t-%s\t%s\n", title, description) -} - func writeCertAndKey(outputDir string, fileName string, certPem, privateKeyPem *bytes.Buffer, force bool) error { - certFile := path.Join(outputDir, fileName+".crt") - keyFile := path.Join(outputDir, fileName+".key") + certFile := filepath.Join(outputDir, fileName+".crt") + keyFile := filepath.Join(outputDir, fileName+".key") if force { if _, err := os.Stat(certFile); err == nil { @@ -71,24 +69,17 @@ func writeCertAndKey(outputDir string, fileName string, certPem, privateKeyPem * err := writeFileWithDir(certFile, certPem.Bytes(), 0444) if err != nil { - return fmt.Errorf("error writing certificate to %s: %s", certFile, err.Error()) + return fmt.Errorf("error writing CA certificate to %s: %s", certFile, err.Error()) } err = writeFileWithDir(keyFile, privateKeyPem.Bytes(), 0400) if err != nil { - return fmt.Errorf("error writing certificate private key to %s: %s", keyFile, err.Error()) + return fmt.Errorf("error writing CA's private key to %s: %s", keyFile, err.Error()) } return nil } -func fileExists(path string, force bool) bool { - if _, err := os.Stat(path); !os.IsNotExist(err) && !force { - return true - } - return false -} - func readCertificateFromFile(path string) (*x509.Certificate, error) { pemBytes, err := os.ReadFile(path) if err != nil { @@ -124,3 +115,19 @@ func readRSAKeyFromFile(path string) (*rsa.PrivateKey, error) { } return key, nil } + +func checkCertificatesLocationWithForce(dir, certificateName string, force bool) error { + // Throw an error if the path for the CA and key certificates already + // exists and the 'force' flag is not set. + + checkFile := func(ext string) bool { + _, err := os.Stat(filepath.Join(dir, certificateName+ext)) + return !os.IsNotExist(err) + } + + if !force && (checkFile(".key") || checkFile(".crt")) { + return fmt.Errorf("existing files would be overwritten. Use -force to proceed") + } + + return nil +} diff --git a/certificates/common_test.go b/certificates/common_test.go index 1c4d9f9..76f21e3 100644 --- a/certificates/common_test.go +++ b/certificates/common_test.go @@ -3,42 +3,54 @@ package certificates import ( "crypto/rsa" "crypto/x509" + "fmt" "github.com/stretchr/testify/assert" - "os" - "path" + "path/filepath" + "regexp" + "strings" "testing" ) -func assertFilesExist(t *testing.T, files ...string) { - for _, file := range files { - _, err := os.Stat(file) - assert.False(t, os.IsNotExist(err)) - } -} +func extractErrors(errorMessage string) []string { + // Sometimes errors are shown in a multi-line format (multierror.Append), so we need to extract them and return them + // as a list. However, this method can be used with single line errors as well and will return a list with a single + // element. Also perform some basic cleanup of the error message (TrimSpace). -func generateAndAssertCACert(t *testing.T, years int, days int, outputDirCa string, force bool) (*x509.Certificate, *rsa.PrivateKey) { - certificateError := generateCACertificate(years, days, outputDirCa, nil, nil, force) - assert.NoError(t, certificateError) + var errors []string - certFilePath := path.Join(outputDirCa, "ca.crt") - keyFilePath := path.Join(outputDirCa, "ca.key") - assertFilesExist(t, certFilePath, keyFilePath) + // Pattern for multi-line errors + multiLinePattern := regexp.MustCompile(`\* (.+)`) + multiLineMatches := multiLinePattern.FindAllStringSubmatch(errorMessage, -1) - caCertificate, err := readCertificateFromFile(certFilePath) - assert.NoError(t, err) - caPrivateKey, err := readRSAKeyFromFile(keyFilePath) - assert.NoError(t, err) + if len(multiLineMatches) > 0 { + for _, match := range multiLineMatches { + if len(match) > 1 { + errors = append(errors, strings.TrimSpace(match[1])) + } + } + } else { + // Single line error + cleanedError := strings.TrimSpace(errorMessage) + errors = append(errors, cleanedError) + } - return caCertificate, caPrivateKey + return errors } -func cleanupDirsForTest(t *testing.T, dirs ...string) { - cleanupDirs := func() { - for _, dir := range dirs { - os.RemoveAll(dir) - } - } +func readAndDecodeCertificateAndKey(t *testing.T, dir, name string) (*x509.Certificate, *rsa.PrivateKey) { + // In the test suite, we often need to verify that a certificate and key pair exist in a given directory. + // This is usually carried out after a call to the create_ca or create_node commands. This method reads the certificate + // and key from the given directory and returns them. It will throw an error if the certificate or key cannot be + // read from the given directory. + + certPath := filepath.Join(dir, fmt.Sprintf("%s.crt", name)) + keyPath := filepath.Join(dir, fmt.Sprintf("%s.key", name)) + + ca, caErr := readCertificateFromFile(certPath) + assert.NoError(t, caErr) + + key, keyErr := readRSAKeyFromFile(keyPath) + assert.NoError(t, keyErr) - cleanupDirs() - t.Cleanup(cleanupDirs) + return ca, key } diff --git a/certificates/create_ca.go b/certificates/create_ca.go index a7d3337..867f410 100644 --- a/certificates/create_ca.go +++ b/certificates/create_ca.go @@ -10,9 +10,6 @@ import ( "errors" "flag" "fmt" - "path" - "strings" - "text/tabwriter" "time" "github.com/hashicorp/go-multierror" @@ -20,7 +17,9 @@ import ( ) type CreateCA struct { - Ui cli.Ui + Ui cli.Ui + Config CreateCAArguments + Flags *flag.FlagSet } type CreateCAArguments struct { @@ -28,33 +27,38 @@ type CreateCAArguments struct { OutputDir string `yaml:"out"` CACertificatePath string `yaml:"ca-certificate"` CAKeyPath string `yaml:"ca-key"` + Name string `yaml:"name"` Force bool `yaml:"force"` } -func (c *CreateCA) Run(args []string) int { - var config CreateCAArguments - - flags := flag.NewFlagSet("create_ca", flag.ContinueOnError) - flags.Usage = func() { c.Ui.Info(c.Help()) } - flags.IntVar(&config.Days, "days", 0, "the validity period of the certificate in days") - flags.StringVar(&config.OutputDir, "out", "./ca", "The output directory") - flags.StringVar(&config.CACertificatePath, "ca-certificate", "", "the path to a CA certificate file") - flags.StringVar(&config.CAKeyPath, "ca-key", "", "the path to a CA key file") - flags.BoolVar(&config.Force, "force", false, forceOption) +func NewCreateCA(ui cli.Ui) *CreateCA { + c := &CreateCA{Ui: ui} + + c.Flags = flag.NewFlagSet("create_ca", flag.ContinueOnError) + c.Flags.IntVar(&c.Config.Days, "days", 0, DayFlagUsage) + c.Flags.StringVar(&c.Config.OutputDir, "out", "./ca", OutDirFlagUsage) + c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "", CaPathFlagUsage) + c.Flags.StringVar(&c.Config.CAKeyPath, "ca-key", "", CaKeyFlagUsage) + c.Flags.StringVar(&c.Config.Name, "name", "ca", NameFlagUsage) + c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) + return c +} - if err := flags.Parse(args); err != nil { +func (c *CreateCA) Run(args []string) int { + if err := c.Flags.Parse(args); err != nil { + c.Ui.Error(err.Error()) return 1 } validationErrors := new(multierror.Error) - if config.Days < 0 { - multierror.Append(validationErrors, errors.New("days must be positive")) + if c.Config.Days < 0 { + _ = multierror.Append(validationErrors, errors.New("days must be positive")) } - caCertPathLen := len(config.CACertificatePath) - caKeyPathLen := len(config.CAKeyPath) + caCertPathLen := len(c.Config.CACertificatePath) + caKeyPathLen := len(c.Config.CAKeyPath) if (caCertPathLen > 0 && caKeyPathLen == 0) || (caKeyPathLen > 0 && caCertPathLen == 0) { - multierror.Append(validationErrors, errors.New("both -ca-certificate and -ca-key options are required")) + _ = multierror.Append(validationErrors, errors.New("both -ca-certificate and -ca-key options are required")) } if validationErrors.ErrorOrNil() != nil { @@ -62,14 +66,9 @@ func (c *CreateCA) Run(args []string) int { return 1 } - // check if certificates already exist - if fileExists(path.Join(config.OutputDir, "ca.key"), config.Force) { - c.Ui.Error(ErrFileExists) - return 1 - } - - if fileExists(path.Join(config.OutputDir, "ca.crt"), config.Force) { - c.Ui.Error(ErrFileExists) + certErr := checkCertificatesLocationWithForce(c.Config.OutputDir, c.Config.Name, c.Config.Force) + if certErr != nil { + c.Ui.Error(certErr.Error()) return 1 } @@ -77,8 +76,8 @@ func (c *CreateCA) Run(args []string) int { years := 5 days := 0 - if config.Days != 0 { - days = config.Days + if c.Config.Days != 0 { + days = c.Config.Days years = 0 } @@ -86,13 +85,13 @@ func (c *CreateCA) Run(args []string) int { var caKey *rsa.PrivateKey var err error if caCertPathLen > 0 { - caCert, err = readCertificateFromFile(config.CACertificatePath) + caCert, err = readCertificateFromFile(c.Config.CACertificatePath) if err != nil { c.Ui.Error(err.Error()) return 1 } - caKey, err = readRSAKeyFromFile(config.CAKeyPath) + caKey, err = readRSAKeyFromFile(c.Config.CAKeyPath) if err != nil { err := fmt.Errorf("error: %s. please note that only RSA keys are currently supported", err.Error()) c.Ui.Error(err.Error()) @@ -100,10 +99,11 @@ func (c *CreateCA) Run(args []string) int { } } - outputDir := config.OutputDir - err = generateCACertificate(years, days, outputDir, caCert, caKey, config.Force) + outputDir := c.Config.OutputDir + err = generateCACertificate(years, days, outputDir, c.Config.Name, caCert, caKey, c.Config.Force) if err != nil { c.Ui.Error(err.Error()) + return 1 } else { if isBoringEnabled() { c.Ui.Output(fmt.Sprintf("A CA certificate & key file have been generated in the '%s' directory (FIPS mode enabled).", outputDir)) @@ -111,10 +111,11 @@ func (c *CreateCA) Run(args []string) int { c.Ui.Output(fmt.Sprintf("A CA certificate & key file have been generated in the '%s' directory.", outputDir)) } } + return 0 } -func generateCACertificate(years int, days int, outputDir string, caCert *x509.Certificate, caPrivateKey *rsa.PrivateKey, force bool) error { +func generateCACertificate(years int, days int, outputDir string, name string, caCert *x509.Certificate, caPrivateKey *rsa.PrivateKey, force bool) error { serialNumber, err := generateSerialNumber(128) if err != nil { return fmt.Errorf("could not generate 128-bit serial number: %s", err.Error()) @@ -184,31 +185,17 @@ func generateCACertificate(years int, days int, outputDir string, caCert *x509.C return fmt.Errorf("could not encode certificate to PEM format: %s", err.Error()) } - err = writeCertAndKey(outputDir, "ca", certPem, privateKeyPem, force) + err = writeCertAndKey(outputDir, name, certPem, privateKeyPem, force) return err } func (c *CreateCA) Help() string { - var buffer bytes.Buffer - - w := tabwriter.NewWriter(&buffer, 0, 0, 2, ' ', 0) - - fmt.Fprintln(w, "Usage: create_ca [options]") - fmt.Fprintln(w, c.Synopsis()) - fmt.Fprintln(w, "Options:") - - writeHelpOption(w, "days", "The validity period of the certificate in days (default: 5 years).") - writeHelpOption(w, "out", "The output directory (default: ./ca).") - writeHelpOption(w, "ca-certificate", "The path to a CA certificate file for creating an intermediate CA certificate.") - writeHelpOption(w, "ca-key", "The path to a CA key file for creating an intermediate CA certificate.") - writeHelpOption(w, "force", forceOption) - - w.Flush() - - return strings.TrimSpace(buffer.String()) + var helpText bytes.Buffer + c.Flags.SetOutput(&helpText) + c.Flags.PrintDefaults() + return helpText.String() } - func (c *CreateCA) Synopsis() string { return "Generate a root/intermediate CA TLS certificate to be used with EventStoreDB" } diff --git a/certificates/create_ca_test.go b/certificates/create_ca_test.go index 36ab422..81e0417 100644 --- a/certificates/create_ca_test.go +++ b/certificates/create_ca_test.go @@ -1,77 +1,168 @@ package certificates import ( - "crypto/rsa" - "crypto/x509" - "path" - "testing" - + "bytes" + "fmt" + "github.com/mitchellh/cli" "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "testing" + "time" ) -func setupTestEnvironmentForCaTests(t *testing.T) (years int, days int, outputDir string, caCert *x509.Certificate, caKey *rsa.PrivateKey) { - years = 1 - days = 0 - outputDir = "./ca" - caCert = nil - caKey = nil +func TestCreateCACertificate(t *testing.T) { + t.Run("TestCreateCACertificate_NominalCase_ShouldSucceed", TestCreateCACertificate_NominalCase_ShouldSucceed) + t.Run("TestCreateCACertificate_DifferentOut_ShouldSucceed", TestCreateCACertificate_DifferentOut_ShouldSucceed) + t.Run("TestCreateCACertificate_WithNameFlag_ShouldCreateNamedCertificates", TestCreateCACertificate_WithNameFlag_ShouldCreateNamedCertificates) + t.Run("TestCreateCACertificate_WithForceFlag_ShouldRegenerate", TestCreateCACertificate_WithForceFlag_ShouldRegenerate) + t.Run("TestCreateIntermediateCertificate_WithoutRootCertificate_ShouldFail", TestCreateIntermediateCertificate_WithoutRootCertificate_ShouldFail) +} + +func TestCreateCACertificate_NominalCase_ShouldSucceed(t *testing.T) { + // Create CA certificate and key without any additional parameters. + + t.Parallel() + + cleanup, tempDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{ + OutputDir: "./ca", + }) + defer cleanup() + + var args []string + + result := createCa.Run(args) + assert.Equal(t, 0, result, "creat-ca should pass without any additional parameters") + + assert.FileExists(t, filepath.Join("./ca", "ca.crt"), "CA certificate should exist") + assert.FileExists(t, filepath.Join("./ca", "ca.key"), "CA key should exist") - cleanupDirsForTest(t, outputDir) - return + cert, err := readCertificateFromFile(filepath.Join(tempDir, "ca.crt")) + assert.NoError(t, err, "Failed to read and parse certificate file") + + // The certificate should be valid for 5 year + expectedNotAfter := time.Now().AddDate(5, 0, 0) + assert.WithinDuration(t, expectedNotAfter, cert.NotAfter, time.Second, "Certificate validity period does not match expected default") } -func testGenerateCACertificate(t *testing.T, years int, days int, outputDir string, caCert *x509.Certificate, caKey *rsa.PrivateKey, force bool) { - err := generateCACertificate(years, days, outputDir, caCert, caKey, force) - assert.NoError(t, err, "Expected no error in nominal case") +func TestCreateCACertificate_DifferentOut_ShouldSucceed(t *testing.T) { + // Create certificate with a different output directory. - certFilePath := path.Join(outputDir, "ca.crt") - keyFilePath := path.Join(outputDir, "ca.key") + t.Parallel() - certFile, err := readCertificateFromFile(certFilePath) - assert.NoError(t, err) - keyFile, err := readRSAKeyFromFile(keyFilePath) - assert.NoError(t, err) + cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{}) + defer cleanup() - err = generateCACertificate(years, days, outputDir, caCert, caKey, force) - if !force { - assert.Error(t, err, "Expected an error when directory exists and override is false") - } else { - assert.NoError(t, err, "Expected no error when directory exists and override is true") - } + args := []string{"-out", filepath.Join(tempCaDir, "my-custom-dir")} + + result := createCa.Run(args) + assert.Equal(t, 0, result, "creat-ca should pass with a different output") + + assert.FileExists(t, filepath.Join(tempCaDir, "my-custom-dir", "ca.crt"), "CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCaDir, "my-custom-dir", "ca.key"), "CA key should exist") +} + +func TestCreateCACertificate_WithNameFlag_ShouldCreateNamedCertificates(t *testing.T) { + // Create CA certificate and key with the name parameter. + // 1. It creates a certificate with the name parameter + // 2. The CA certificate and key should be named with the name parameter + + t.Parallel() + + cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{}) + defer cleanup() + + args := []string{"-out", tempCaDir, "-name", "my-custom-name"} + + result := createCa.Run(args) + assert.Equal(t, 0, result, "creat-ca should create a certificate with a different name") - certFileAfter, err := readCertificateFromFile(certFilePath) - assert.NoError(t, err) - keyFileAfter, err := readRSAKeyFromFile(keyFilePath) - assert.NoError(t, err) - - if !force { - assert.Equal(t, certFile, certFileAfter, "Expected CA certificate to be the same") - assert.Equal(t, keyFile, keyFileAfter, "Expected CA key to be the same") - } else { - assert.NotEqual(t, certFile, certFileAfter, "Expected CA certificate to be different") - assert.NotEqual(t, keyFile, keyFileAfter, "Expected CA key to be different") + assert.FileExists(t, filepath.Join(tempCaDir, "my-custom-name.crt"), "CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCaDir, "my-custom-name.key"), "CA certificate should exist") +} + +func TestCreateCACertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { + // Creation of a CA certificate with the force flag. + // 1. It first creates a certificate + // 2. Attempt to recreate the certificate with the force flag + // 3. Check that the content of the files are different + + t.Parallel() + + cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{}) + defer cleanup() + + // Create a CA certificate + result := createCa.Run([]string{"-out", tempCaDir}) + assert.Equal(t, 0, result, fmt.Sprintf("creat-ca should pass and create a certificate at %s", tempCaDir)) + + // Read the content of the key and crt files + originalCaCert, originalKeyCert := readAndDecodeCertificateAndKey(t, tempCaDir, "ca") + + // Try to create a CA certificate again with the force flag and override the existing one + args := []string{"-out", tempCaDir, "-force"} + result = createCa.Run(args) + assert.Equal(t, 0, result, fmt.Sprintf("creat-ca should pass and override certificate at %s", tempCaDir)) + + // Read the content of the key and crt files generated from the config file + newCaCert, newKeyCert := readAndDecodeCertificateAndKey(t, tempCaDir, "ca") + + // Check that the content of the files are different + assert.NotEqual(t, originalCaCert, newCaCert, "The content of the CA certificate should be different") + assert.NotEqual(t, originalKeyCert, newKeyCert, "The content of the CA key should be different") +} + +func TestCreateIntermediateCertificate_WithoutRootCertificate_ShouldFail(t *testing.T) { + // Create intermediate certificate without root certificate. + // 1. It creates an intermediate certificate without root certificate + // 2. It should return an error + + t.Parallel() + + cleanup, tempCaDir, _, errorBuffer, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{}) + defer cleanup() + + args := []string{ + "-out", tempCaDir, + "-ca-certificate", "unknown", + "-ca-key", "unknown", } + result := createCa.Run(args) + assert.Equal(t, 1, result, "creat-ca should fail without a root certificate") + + errors := extractErrors(errorBuffer.String()) + assert.Equal(t, 1, len(errors), "Expected 1 error") + assert.Equal(t, "error reading file: open unknown: no such file or directory", errors[0]) +} + +type TestEnvParams struct { + OutputDir string } -func TestGenerateCACertificate(t *testing.T) { - t.Run("nominal-case", func(t *testing.T) { - years, days, outputDir, caCert, caKey := setupTestEnvironmentForCaTests(t) +func setupCreateCaTestEnvironment(t *testing.T, params *TestEnvParams) (cleanupFunc func(), tempDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createCa *CreateCA) { + tempDir = params.OutputDir - err := generateCACertificate(years, days, outputDir, caCert, caKey, false) + if tempDir == "" { + var err error + tempDir, err = os.MkdirTemp(os.TempDir(), "ca-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + } - assert.NoError(t, err, "Expected no error in nominal case") + outputBuffer = new(bytes.Buffer) + errorBuffer = new(bytes.Buffer) - assert.FileExists(t, path.Join(outputDir, "ca.crt"), "CA certificate should exist") - assert.FileExists(t, path.Join(outputDir, "ca.key"), "CA key should exist") + createCa = NewCreateCA(&cli.BasicUi{ + Writer: outputBuffer, + ErrorWriter: errorBuffer, }) - t.Run("directory-exists", func(t *testing.T) { - years, days, outputDir, caCert, caKey := setupTestEnvironmentForCaTests(t) - testGenerateCACertificate(t, years, days, outputDir, caCert, caKey, false) - }) + cleanupFunc = func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Logf("Failed to remove temp directory (%s): %s", tempDir, err) + } + } - t.Run("directory-exists-force", func(t *testing.T) { - years, days, outputDir, caCert, caKey := setupTestEnvironmentForCaTests(t) - testGenerateCACertificate(t, years, days, outputDir, caCert, caKey, true) - }) + return cleanupFunc, tempDir, outputBuffer, errorBuffer, createCa } diff --git a/certificates/create_certs.go b/certificates/create_certs.go index 4594eeb..56f4790 100644 --- a/certificates/create_certs.go +++ b/certificates/create_certs.go @@ -4,19 +4,18 @@ import ( "bytes" "flag" "fmt" + "github.com/mitchellh/cli" + "gopkg.in/yaml.v3" "os" - "path" "reflect" "strings" "sync" - "text/tabwriter" - - "github.com/mitchellh/cli" - "gopkg.in/yaml.v3" ) type CreateCertificates struct { - Ui cli.Ui + Ui cli.Ui + Config CreateCertificateArguments + Flags *flag.FlagSet } type CreateCertificateArguments struct { @@ -31,17 +30,22 @@ type Config struct { } `yaml:"certificates"` } -func (c *CreateCertificates) Run(args []string) int { - var arguments CreateCertificateArguments - flags := flag.NewFlagSet("create_certs", flag.ContinueOnError) - flags.Usage = func() { c.Ui.Info(c.Help()) } - flags.StringVar(&arguments.ConfigPath, "config-file", "./certs.yml", "The config yml file") - flags.BoolVar(&arguments.Force, "force", false, forceOption) +func NewCreateCerts(ui cli.Ui) *CreateCertificates { + c := &CreateCertificates{Ui: ui} - if err := flags.Parse(args); err != nil { + c.Flags = flag.NewFlagSet("create_certs", flag.ContinueOnError) + c.Flags.StringVar(&c.Config.ConfigPath, "config-file", "./certs.yml", "The config yml file") + c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) + return c +} + +func (c *CreateCertificates) Run(args []string) int { + if err := c.Flags.Parse(args); err != nil { + c.Ui.Error(err.Error()) return 1 } - configData, err := os.ReadFile(arguments.ConfigPath) + + configData, err := os.ReadFile(c.Config.ConfigPath) if err != nil { c.Ui.Error(err.Error()) return 1 @@ -53,65 +57,55 @@ func (c *CreateCertificates) Run(args []string) int { return 1 } - if err := c.checkPaths(config, arguments.Force); err { - c.Ui.Error(ErrFileExists) + certErr := c.checkPaths(config, c.Config.Force) + if certErr != nil { + c.Ui.Error(certErr.Error()) return 1 } - if c.generateCaCerts(config, arguments.Force) != 0 || c.generateNodes(config, arguments.Force) != 0 { + if c.generateCaCerts(config, c.Config.Force) != 0 || c.generateNodes(config, c.Config.Force) != 0 { return 1 } return 0 } -func (c *CreateCertificates) checkPaths(config Config, force bool) bool { - // If any certs file exists and the force flag isn't provided, it returns an - // error. Otherwise, it returns false, indicating that certificate generation - // can proceed safely. - - var errorMutex sync.Mutex - var error bool +func (c *CreateCertificates) checkPaths(config Config, force bool) error { + var once sync.Once + var certError error var wg sync.WaitGroup - checkFile := func(filePath string) { + checkCertFiles := func(certificateName, dir string) { defer wg.Done() - if fileExists(filePath, force) { - errorMutex.Lock() - error = true - errorMutex.Unlock() + if err := checkCertificatesLocationWithForce(dir, certificateName, force); err != nil { + once.Do(func() { + certError = err + }) } } // Check CA certificate and key paths for _, caCert := range config.Certificates.CaCerts { - wg.Add(2) - go checkFile(caCert.CACertificatePath) - go checkFile(caCert.CAKeyPath) + wg.Add(1) + go checkCertFiles(caCert.Name, caCert.OutputDir) } // Check Node certificate and key paths for _, node := range config.Certificates.Nodes { - wg.Add(4) - go checkFile(node.CACertificatePath) - go checkFile(node.CAKeyPath) - go checkFile(path.Join(node.OutputDir, "node.crt")) - go checkFile(path.Join(node.OutputDir, "node.key")) + wg.Add(1) + go checkCertFiles(node.Name, node.OutputDir) } wg.Wait() - return error + return certError } - func (c *CreateCertificates) generateNodes(config Config, force bool) int { for _, node := range config.Certificates.Nodes { node.Force = force - createNode := CreateNode{ - Ui: &cli.ColoredUi{ - Ui: c.Ui, - OutputColor: cli.UiColorBlue, - }, - } + createNode := NewCreateNode(&cli.ColoredUi{ + Ui: c.Ui, + OutputColor: cli.UiColorBlue, + }) if createNode.Run(toArguments(node)) != 0 { return 1 } @@ -120,19 +114,16 @@ func (c *CreateCertificates) generateNodes(config Config, force bool) int { } func (c *CreateCertificates) generateCaCerts(config Config, force bool) int { - coloredUI := &cli.ColoredUi{ - Ui: c.Ui, - OutputColor: cli.UiColorBlue, - } - for _, caCert := range config.Certificates.CaCerts { caCert.Force = force - caCreator := CreateCA{Ui: coloredUI} + caCreator := NewCreateCA(&cli.ColoredUi{ + Ui: c.Ui, + OutputColor: cli.UiColorBlue, + }) if caCreator.Run(toArguments(caCert)) != 0 { return 1 } } - return 0 } @@ -157,20 +148,10 @@ func toArguments(config interface{}) []string { } func (c *CreateCertificates) Help() string { - var buffer bytes.Buffer - - w := tabwriter.NewWriter(&buffer, 0, 0, 2, ' ', 0) - - fmt.Fprintln(w, "Usage: create_certs [options]") - fmt.Fprintln(w, c.Synopsis()) - fmt.Fprintln(w, "Options:") - - writeHelpOption(w, "config-file", "The path to the yml config file.") - writeHelpOption(w, "force", forceOption) - - w.Flush() - - return strings.TrimSpace(buffer.String()) + var helpText bytes.Buffer + c.Flags.SetOutput(&helpText) + c.Flags.PrintDefaults() + return helpText.String() } func (c *CreateCertificates) Synopsis() string { diff --git a/certificates/create_certs_test.go b/certificates/create_certs_test.go new file mode 100644 index 0000000..2c45b4e --- /dev/null +++ b/certificates/create_certs_test.go @@ -0,0 +1,354 @@ +package certificates + +import ( + "bytes" + "fmt" + "github.com/mitchellh/cli" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestCreateCertificates(t *testing.T) { + t.Run("TestCreateCertificates_ValidConfigFile_ShouldSucceed", TestCreateCertificates_ValidConfigFile_ShouldSucceed) + t.Run("TestCreateCertificates_ExistingCertificatesWithoutForceFlag_ShouldFail", TestCreateCertificates_ExistingCertificatesWithoutForceFlag_ShouldFail) + t.Run("TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate", TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate) + t.Run("TestCreateCertificates_ValidConfigWithCustomNames_ShouldCreateNamedCertificates", TestCreateCertificates_ValidConfigWithCustomNames_ShouldCreateNamedCertificates) + t.Run("TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError", TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError) +} + +func TestCreateCertificates_ValidConfigFile_ShouldSucceed(t *testing.T) { + // Create certificates from a certs.yml file + + t.Parallel() + + cleanup, tempCertsDir, _, _, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + certsFileWithName := "certs.yml" + + // Create a certs.yml file + createConfigFile(t, tempCertsDir, certsFileWithName, certsFile, tempCertsDir) + + fmt.Println(tempCertsDir, certsFileWithName) + + args := []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileWithName), + } + + result := createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should succeed") + + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.crt"), "Root CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.key"), "Root CA key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.crt"), "Intermediate certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.key"), "Intermediate certificate key should exist") + + nodes := []string{"node1", "node2", "node3"} + for _, node := range nodes { + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.crt"), fmt.Sprintf("%s certificate should exist", node)) + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.key"), fmt.Sprintf("%s certificate key should exist", node)) + } +} + +func TestCreateCertificates_ExistingCertificatesWithoutForceFlag_ShouldFail(t *testing.T) { + // Create certificates from config file and should fail because the certificates already exist + // 1. Successfully create certificates from config file + // 2. Run create-certs again without the force flag + // 3. Expect an error suggesting that the certificates already exist and that the force flag should be used + + t.Parallel() + + cleanup, tempCertsDir, _, errorBuffer, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + createConfigFile(t, tempCertsDir, "certs.yml", certsFile, tempCertsDir) + + args := []string{ + "-config-file", tempCertsDir + "/certs.yml", + } + + result := createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should succeed the first time it is run since the certificates do not exist") + + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.crt"), "Root CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.key"), "Root CA key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.crt"), "Intermediate certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.key"), "Intermediate certificate key should exist") + + nodes := []string{"node1", "node2", "node3"} + for _, node := range nodes { + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.crt"), fmt.Sprintf("%s certificate should exist", node)) + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.key"), fmt.Sprintf("%s certificate key should exist", node)) + } + + // Try to generate the certificates again and expect and error + result = createCerts.Run(args) + assert.Equal(t, 1, result, "The create-certs command should fail the second time it is run since the certificates already exist") + errors := extractErrors(errorBuffer.String()) + + assert.Equal(t, 1, len(errors), "Expected 1 error") + assert.Equal(t, "existing files would be overwritten. Use -force to proceed", errors[0]) +} + +func TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate(t *testing.T) { + // Create certificates from a certs.yml file with the force flag + // Expect all certificates to be regenerated and different from the original ones + + t.Parallel() + + cleanup, tempCertsDir, _, _, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + certsFileWithName := "certs.yml" + + // Create a certs.yml file + createConfigFile(t, tempCertsDir, certsFileWithName, certsFile, tempCertsDir) + + args := []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileWithName), + } + + result := createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should succeed the first time it is run since the certificates do not exist") + + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.crt"), "Root CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "root_ca", "ca.key"), "Root CA key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.crt"), "Intermediate certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "intermediate_ca", "ca.key"), "Intermediate certificate key should exist") + + nodes := []string{"node1", "node2", "node3"} + for _, node := range nodes { + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.crt"), fmt.Sprintf("%s certificate should exist", node)) + assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.key"), fmt.Sprintf("%s certificate key should exist", node)) + } + + // Read the content of the key and crt files generated from the config file + originalCaCert, originalKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "root_ca"), "ca") + originalIntermediateCaCert, originalIntermediateKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "intermediate_ca"), "ca") + + originalCerts := make(map[string][2]interface{}) + + for _, node := range nodes { + originalCaCert, originalKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, node), "node") + originalCerts[node] = [2]interface{}{originalCaCert, originalKeyCert} + } + + args = []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileWithName), + "-force", + } + + result = createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should succeed with the force flag and "+ + "override the existing certificates defined in the config file") + + newRootCaCert, newRootCaKey := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "root_ca"), "ca") + newIntermediateCaCert, newIntermediateKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "intermediate_ca"), "ca") + + assert.NotEqual(t, originalCaCert, newRootCaCert, "Root CA certificate should be regenerated") + assert.NotEqual(t, originalKeyCert, newRootCaKey, "Root CA key should be regenerated") + + assert.NotEqual(t, originalIntermediateCaCert, newIntermediateCaCert, "Intermediate CA certificate should be regenerated") + assert.NotEqual(t, originalIntermediateKeyCert, newIntermediateKeyCert, "Intermediate CA key should be regenerated") + + for _, node := range nodes { + newCAHash, newKeyHash := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, node), "node") + assert.NotEqual(t, originalCerts[node][0], newCAHash, fmt.Sprintf("%s certificate should be regenerated", node)) + assert.NotEqual(t, originalCerts[node][1], newKeyHash, fmt.Sprintf("%s certificate key should be regenerated", node)) + } +} + +func TestCreateCertificates_ValidConfigWithCustomNames_ShouldCreateNamedCertificates(t *testing.T) { + // Create certificates from a certs.yml file with the name parameter + // Expect all certificates to be named with the name parameter + + t.Parallel() + + cleanup, tempCertsDir, _, _, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + certsFileName := "certs-with-name.yml" + + createConfigFile(t, tempCertsDir, certsFileName, certsFileWithName, tempCertsDir) + + args := []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileName), + } + + result := createCerts.Run(args) + assert.Equal(t, 0, result, "The create-certs command should create certificates with custom names") + + assert.FileExists(t, filepath.Join(tempCertsDir, "custom_root", "custom_root.crt"), "Root CA certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "custom_root", "custom_root.key"), "Root CA key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "custom_intermediate", "custom_intermediate.crt"), "Intermediate certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "custom_intermediate", "custom_intermediate.key"), "Intermediate certificate key should exist") + + nodes := []string{"custom_node1", "custom_node2", "custom_node3"} + for _, node := range nodes { + assert.FileExists(t, filepath.Join(tempCertsDir, node, fmt.Sprintf("%s.crt", node)), fmt.Sprintf("%s certificate should exist", node)) + assert.FileExists(t, filepath.Join(tempCertsDir, node, fmt.Sprintf("%s.key", node)), fmt.Sprintf("%s certificate key should exist", node)) + } +} + +func TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError(t *testing.T) { + // An invalid path is defined at ca-certificate in the config. + // The intermediate certificate uses an invalid path for the root CA certificate. + // This should result in an error suggesting that ca.crt is not found. + + t.Parallel() + + cleanup, tempCertsDir, _, errorBuffer, createCerts := setupCertificateTestEnvironment(t) + defer cleanup() + + certsFileName := "certs.yml" + + createConfigFile(t, tempCertsDir, certsFileName, certsFileWithInvalidPath, tempCertsDir) + + args := []string{ + "-config-file", filepath.Join(tempCertsDir, certsFileName), + } + + result := createCerts.Run(args) + assert.Equal(t, 1, result, "The create-certs command should fail with code 1 when an invalid path is defined in the config") + + errors := extractErrors(errorBuffer.String()) + + assert.Equal(t, 1, len(errors), "Expected 1 error") + + expectedErrorMessage := fmt.Sprintf("error reading file: open %s: no such file or directory", filepath.Join(tempCertsDir, "invalid_root_ca", "ca.crt")) + assert.Equal(t, errors[0], expectedErrorMessage) + + // The root CA will be created + assert.DirExists(t, filepath.Join(tempCertsDir, "root_ca")) + + // Intermediate and node1 will not be created + assert.NoDirExists(t, filepath.Join(tempCertsDir, "intermediate_ca"), "Intermediate certificate should not exist") + assert.NoDirExists(t, filepath.Join(tempCertsDir, "node1"), "Intermediate certificate key should not exist") +} + +// Valid certificate file +var certsFile = `certificates: + ca-certs: + - out: "./root_ca" + - out: "./intermediate_ca" + ca-certificate: "./root_ca/ca.crt" + ca-key: "./root_ca/ca.key" + days: 5 + node-certs: + - out: "./node1" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com" + - out: "./node2" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.2,172.20.240.2" + dns-names: "localhost,eventstore-node2.localhost.com" + - out: "./node3" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.3,172.20.240.3" + dns-names: "localhost,eventstore-node2.localhost.com"` + +// Invalid path defined at ca-certificate in the config +var certsFileWithInvalidPath = `certificates: + ca-certs: + - out: "./root_ca" + - out: "./intermediate_ca" + ca-certificate: "./invalid_root_ca/ca.crt" + ca-key: "./root_ca/ca.key" + days: 5 + node-certs: + - out: "./node1" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com"` + +// Each certificate have a name parameter +var certsFileWithName = `certificates: + ca-certs: + - out: "./custom_root" + name: "custom_root" + - out: "./custom_intermediate" + name: "custom_intermediate" + ca-certificate: "./custom_root/custom_root.crt" + ca-key: "./custom_root/custom_root.key" + days: 5 + node-certs: + - out: "./custom_node1" + name: "custom_node1" + ca-certificate: "./custom_intermediate/custom_intermediate.crt" + ca-key: "./custom_intermediate/custom_intermediate.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com" + - out: "./custom_node2" + name: "custom_node2" + ca-certificate: "./custom_intermediate/custom_intermediate.crt" + ca-key: "./custom_intermediate/custom_intermediate.key" + ip-addresses: "127.0.0.2,172.20.240.2" + dns-names: "localhost,eventstore-node2.localhost.com" + - out: "./custom_node3" + name: "custom_node3" + ca-certificate: "./custom_intermediate/custom_intermediate.crt" + ca-key: "./custom_intermediate/custom_intermediate.key" + ip-addresses: "127.0.0.3,172.20.240.3" + dns-names: "localhost,eventstore-node2.localhost.com"` + +func setupCertificateTestEnvironment(t *testing.T) (cleanupFunc func(), tempCertsDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createCerts *CreateCertificates) { + tempCertsDir, err := os.MkdirTemp(os.TempDir(), "certs-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + outputBuffer = new(bytes.Buffer) + errorBuffer = new(bytes.Buffer) + + createCerts = NewCreateCerts(&cli.BasicUi{ + Writer: outputBuffer, + ErrorWriter: errorBuffer, + }) + + cleanupFunc = func() { + if err := os.RemoveAll(tempCertsDir); err != nil { + t.Logf("Failed to remove temp directory (%s): %s", tempCertsDir, err) + } + } + + return cleanupFunc, tempCertsDir, outputBuffer, errorBuffer, createCerts +} + +func createConfigFile(t *testing.T, dirPath string, fileName string, content string, newParentDir string) { + updatedContent := strings.ReplaceAll(content, "./", fmt.Sprintf("%s/", newParentDir)) + + filePath := filepath.Join(dirPath, fileName) + + // Create the directory if it does not exist + if _, err := os.Stat(dirPath); os.IsNotExist(err) { + err := os.MkdirAll(dirPath, 0755) + if err != nil { + t.Errorf("Error creating directory: %s", err) + } + } + + f, err := os.Create(filePath) + if err != nil { + panic(err) + } + + defer func(f *os.File) { + err := f.Close() + if err != nil { + t.Error(err) + } + }(f) + + _, err = f.WriteString(updatedContent) + if err != nil { + panic(err) + } +} diff --git a/certificates/create_node.go b/certificates/create_node.go index bf05ea6..5c04559 100644 --- a/certificates/create_node.go +++ b/certificates/create_node.go @@ -12,10 +12,8 @@ import ( "fmt" "net" "os" - "path" "strconv" "strings" - "text/tabwriter" "time" multierror "github.com/hashicorp/go-multierror" @@ -23,7 +21,9 @@ import ( ) type CreateNode struct { - Ui cli.Ui + Ui cli.Ui + Flags *flag.FlagSet + Config CreateNodeArguments } type CreateNodeArguments struct { @@ -34,9 +34,26 @@ type CreateNodeArguments struct { Days int `yaml:"days"` OutputDir string `yaml:"out"` CommonName string `yaml:"common-name"` + Name string `yaml:"name"` Force bool `yaml:"force"` } +func NewCreateNode(ui cli.Ui) *CreateNode { + c := &CreateNode{Ui: ui} + + c.Flags = flag.NewFlagSet("create_node", flag.ContinueOnError) + c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "./ca/ca.crt", CaPathFlagUsage) + c.Flags.StringVar(&c.Config.CommonName, "common-name", "eventstoredb-node", "the certificate subject common name") + c.Flags.StringVar(&c.Config.CAKeyPath, "ca-key", "./ca/ca.key", CaKeyFlagUsage) + c.Flags.StringVar(&c.Config.IPAddresses, "ip-addresses", "", "comma-separated list of IP addresses of the node") + c.Flags.StringVar(&c.Config.DNSNames, "dns-names", "", "comma-separated list of DNS names of the node") + c.Flags.IntVar(&c.Config.Days, "days", 0, DayFlagUsage) + c.Flags.StringVar(&c.Config.OutputDir, "out", "", OutDirFlagUsage) + c.Flags.StringVar(&c.Config.Name, "name", "node", NameFlagUsage) + c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) + return c +} + func parseIPAddresses(ipAddresses string) ([]net.IP, error) { if len(ipAddresses) == 0 { return []net.IP{}, nil @@ -62,7 +79,7 @@ func parseDNSNames(dnsNames string) ([]string, error) { return dns, nil } -func getNodeOutputDirectory() (string, error) { +func getOutputDirectory() (string, error) { for i := 1; i <= 100; i++ { dir := "node" + strconv.Itoa(i) if _, err := os.Stat(dir); os.IsNotExist(err) { @@ -73,38 +90,26 @@ func getNodeOutputDirectory() (string, error) { } func (c *CreateNode) Run(args []string) int { - var config CreateNodeArguments - - flags := flag.NewFlagSet("create_node", flag.ContinueOnError) - flags.Usage = func() { c.Ui.Info(c.Help()) } - flags.StringVar(&config.CACertificatePath, "ca-certificate", "./ca/ca.crt", "the path to the CA certificate file") - flags.StringVar(&config.CommonName, "common-name", "eventstoredb-node", "the certificate subject common name") - flags.StringVar(&config.CAKeyPath, "ca-key", "./ca/ca.key", "the path to the CA key file") - flags.StringVar(&config.IPAddresses, "ip-addresses", "", "comma-separated list of IP addresses of the node") - flags.StringVar(&config.DNSNames, "dns-names", "", "comma-separated list of DNS names of the node") - flags.IntVar(&config.Days, "days", 0, "the validity period of the certificate in days") - flags.StringVar(&config.OutputDir, "out", "", "The output directory") - flags.BoolVar(&config.Force, "force", false, forceOption) - - if err := flags.Parse(args); err != nil { + if err := c.Flags.Parse(args); err != nil { + c.Ui.Error(err.Error()) return 1 } validationErrors := new(multierror.Error) - if len(config.CACertificatePath) == 0 { - multierror.Append(validationErrors, errors.New("ca-certificate is a required field")) + if len(c.Config.CACertificatePath) == 0 { + _ = multierror.Append(validationErrors, errors.New("ca-certificate is a required field")) } - if len(config.CAKeyPath) == 0 { - multierror.Append(validationErrors, errors.New("ca-key is a required field")) + if len(c.Config.CAKeyPath) == 0 { + _ = multierror.Append(validationErrors, errors.New("ca-key is a required field")) } - if len(config.IPAddresses) == 0 && len(config.DNSNames) == 0 { - multierror.Append(validationErrors, errors.New("at least one IP address or DNS name needs to be specified with --ip-addresses or --dns-names")) + if len(c.Config.IPAddresses) == 0 && len(c.Config.DNSNames) == 0 { + _ = multierror.Append(validationErrors, errors.New("at least one IP address or DNS name needs to be specified with --ip-addresses or --dns-names")) } - if config.Days < 0 { - multierror.Append(validationErrors, errors.New("days must be positive")) + if c.Config.Days < 0 { + _ = multierror.Append(validationErrors, errors.New("days must be positive")) } if validationErrors.ErrorOrNil() != nil { @@ -112,36 +117,36 @@ func (c *CreateNode) Run(args []string) int { return 1 } - caCert, err := readCertificateFromFile(config.CACertificatePath) + caCert, err := readCertificateFromFile(c.Config.CACertificatePath) if err != nil { c.Ui.Error(err.Error()) return 1 } - caKey, err := readRSAKeyFromFile(config.CAKeyPath) + caKey, err := readRSAKeyFromFile(c.Config.CAKeyPath) if err != nil { err := fmt.Errorf("error: %s. please note that only RSA keys are currently supported", err.Error()) c.Ui.Error(err.Error()) return 1 } - ips, err := parseIPAddresses(config.IPAddresses) + ips, err := parseIPAddresses(c.Config.IPAddresses) if err != nil { c.Ui.Error(err.Error()) return 1 } - dnsNames, err := parseDNSNames(config.DNSNames) + dnsNames, err := parseDNSNames(c.Config.DNSNames) if err != nil { c.Ui.Error(err.Error()) return 1 } - outputDir := config.OutputDir - outputBaseFileName := "node" + outputDir := c.Config.OutputDir + outputBaseFileName := c.Config.Name if len(outputDir) == 0 { - outputDir, err = getNodeOutputDirectory() + outputDir, err = getOutputDirectory() if err != nil { c.Ui.Error(err.Error()) return 1 @@ -149,17 +154,9 @@ func (c *CreateNode) Run(args []string) int { outputBaseFileName = outputDir } - // check if certificates already exist - keyPath := path.Join(config.OutputDir, fmt.Sprintf("%s.key", outputBaseFileName)) - crtPath := path.Join(config.OutputDir, fmt.Sprintf("%s.crt", outputBaseFileName)) - - if fileExists(keyPath, config.Force) { - c.Ui.Error(ErrFileExists) - return 1 - } - - if fileExists(crtPath, config.Force) { - c.Ui.Error(ErrFileExists) + certErr := checkCertificatesLocationWithForce(outputDir, outputBaseFileName, c.Config.Force) + if certErr != nil { + c.Ui.Error(certErr.Error()) return 1 } @@ -167,12 +164,12 @@ func (c *CreateNode) Run(args []string) int { years := 1 days := 0 - if config.Days != 0 { - days = config.Days + if c.Config.Days != 0 { + days = c.Config.Days years = 0 } - err = generateNodeCertificate(caCert, caKey, ips, dnsNames, years, days, outputDir, outputBaseFileName, config.CommonName, config.Force) + err = generateNodeCertificate(caCert, caKey, ips, dnsNames, years, days, outputDir, outputBaseFileName, c.Config.CommonName, c.Config.Force) if err != nil { c.Ui.Error(err.Error()) return 1 @@ -187,7 +184,18 @@ func (c *CreateNode) Run(args []string) int { return 0 } -func generateNodeCertificate(caCert *x509.Certificate, caPrivateKey *rsa.PrivateKey, ips []net.IP, dnsNames []string, years int, days int, outputDir string, outputBaseFileName string, commonName string, force bool) error { +func generateNodeCertificate( + caCert *x509.Certificate, + caPrivateKey *rsa.PrivateKey, + ips []net.IP, + dnsNames []string, + years int, + days int, + outputDir string, + outputBaseFileName string, + commonName string, + force bool, +) error { serialNumber, err := generateSerialNumber(128) if err != nil { return fmt.Errorf("could not generate 128-bit serial number: %s", err.Error()) @@ -219,7 +227,7 @@ func generateNodeCertificate(caCert *x509.Certificate, caPrivateKey *rsa.Private } privateKeyPem := new(bytes.Buffer) - pem.Encode(privateKeyPem, &pem.Block{ + err = pem.Encode(privateKeyPem, &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey), }) @@ -247,28 +255,10 @@ func generateNodeCertificate(caCert *x509.Certificate, caPrivateKey *rsa.Private } func (c *CreateNode) Help() string { - var buffer bytes.Buffer - - w := tabwriter.NewWriter(&buffer, 0, 0, 2, ' ', 0) // 2 spaces minimum gap between columns - - fmt.Fprintln(w, "Usage: create_node [options]") - fmt.Fprintln(w, c.Synopsis()) - fmt.Fprintln(w, "Options:") - - writeHelpOption(w, "ca-certificate", "The path to the CA certificate file (default: ./ca/ca.crt).") - writeHelpOption(w, "ca-key", "The path to the CA key file (default: ./ca/ca.key).") - writeHelpOption(w, "days", "The validity period of the certificates in days (default: 1 year).") - writeHelpOption(w, "out", "The output directory (default: ./nodeX where X is an auto-generated number).") - writeHelpOption(w, "ip-addresses", "Comma-separated list of IP addresses of the node.") - writeHelpOption(w, "dns-names", "Comma-separated list of DNS names of the node.") - writeHelpOption(w, "common-name", "The certificate subject common name.") - writeHelpOption(w, "force", forceOption) - - fmt.Fprintln(w, "\nAt least one IP address or DNS name needs to be specified.") - - w.Flush() - - return strings.TrimSpace(buffer.String()) + var helpText bytes.Buffer + c.Flags.SetOutput(&helpText) + c.Flags.PrintDefaults() + return helpText.String() } func (c *CreateNode) Synopsis() string { diff --git a/certificates/create_node_test.go b/certificates/create_node_test.go index 278cc01..fa05404 100644 --- a/certificates/create_node_test.go +++ b/certificates/create_node_test.go @@ -1,101 +1,157 @@ package certificates import ( + "bytes" "crypto/x509" - "path" - "testing" - + "fmt" + "github.com/mitchellh/cli" "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "testing" + "time" ) -func setupTestEnvironmentForNodeTests(t *testing.T) (years int, days int, outputDirCa string, outputDirNode string, nodeCertFileName string, ipAddresses string, commonName string, dnsNames []string) { - years = 1 - days = 0 - outputDirCa = "./ca" - outputDirNode = "./node" - nodeCertFileName = "node" - ipAddresses = "127.0.0.1" - commonName = "EventStoreDB" - dnsNames = []string{"localhost"} - - cleanupDirsForTest(t, outputDirCa, outputDirNode) - return +func TestCreateNodeCertificate(t *testing.T) { + t.Run("TestCreateNodeCertificate_NoParams_ShouldFail", TestCreateNodeCertificate_NoParams_ShouldFail) + t.Run("TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed", TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed) + t.Run("TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate", TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate) +} + +func TestCreateNodeCertificate_NoParams_ShouldFail(t *testing.T) { + // Create a node certificate with no params + + t.Parallel() + + cleanup, _, _, _, errorBuffer, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() + + var args []string + result := createNode.Run(args) + assert.Equal(t, 1, result, "The 'create-node' operation should fail due to the absence of required parameters.") + + errors := extractErrors(errorBuffer.String()) + assert.Equal(t, 1, len(errors)) + assert.Equal(t, "at least one IP address or DNS name needs to be specified with --ip-addresses or --dns-names", errors[0]) + assert.Equal(t, 1, result) +} + +func TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed(t *testing.T) { + t.Parallel() + + cleanup, tempNodeDir, tempCaDir, _, _, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() + + args := []string{ + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-out", tempNodeDir, + "-ip-addresses", "127.0.0.1", + "-dns-names", "localhost", + } + if result := createNode.Run(args); result != 0 { + t.Fatalf("Expected 0, got %d", result) + } + + assert.FileExists(t, filepath.Join(tempNodeDir, "node.crt"), "Node certificate should exist") + assert.FileExists(t, filepath.Join(tempNodeDir, "node.key"), "Node key should exist") + + cert, err := readCertificateFromFile(filepath.Join(tempNodeDir, "node.crt")) + assert.NoError(t, err, "Failed to read and parse certificate file") + + // The certificate should be valid for 1 year + expectedNotAfter := time.Now().AddDate(1, 0, 0) + assert.WithinDuration(t, expectedNotAfter, cert.NotAfter, time.Second, "Certificate validity period does not match expected default") + + // Now we verify if the certificate is signed by the provided root CA + caCert, err := readCertificateFromFile(filepath.Join(tempCaDir, "ca.crt")) + assert.NoError(t, err, "Failed to read and parse CA certificate file") + + roots := x509.NewCertPool() + roots.AddCert(caCert) + + _, err = cert.Verify(x509.VerifyOptions{Roots: roots}) + assert.NoError(t, err, "Node certificate should be signed by the provided root CA") } -func TestGenerateNodeCertificate(t *testing.T) { +func TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { + // Create a node certificate with the force flag - t.Run("nominal-case", func(t *testing.T) { - years, days, outputDirCa, outputDirNode, nodeCertFileName, ipAddresses, commonName, dnsNames := setupTestEnvironmentForNodeTests(t) + t.Parallel() - caCertificate, caPrivateKey := generateAndAssertCACert(t, years, days, outputDirCa, false) - ips, err := parseIPAddresses(ipAddresses) - assert.NoError(t, err) + cleanup, tempNodeDir, tempCaDir, _, _, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() - certificateError := generateNodeCertificate(caCertificate, caPrivateKey, ips, dnsNames, years, days, outputDirNode, nodeCertFileName, commonName, false) - assert.NoError(t, certificateError) + args := []string{ + "-ca-certificate", fmt.Sprintf("%s/ca.crt", tempCaDir), + "-ca-key", fmt.Sprintf("%s/ca.key", tempCaDir), + "-out", tempNodeDir, + "-ip-addresses", "127.0.0.1", + "-dns-names", "localhost", + } - nodeCertPath := path.Join(outputDirNode, nodeCertFileName+".crt") - nodeKeyPath := path.Join(outputDirNode, nodeCertFileName+".key") - assertFilesExist(t, nodeCertPath, nodeKeyPath) + result := createNode.Run(args) + assert.Equal(t, 0, result, "The 'create-node' operation without the --force flag should succeed the first time") - nodeCertificate, err := readCertificateFromFile(nodeCertPath) - assert.NoError(t, err) + assert.FileExists(t, filepath.Join(tempNodeDir, "node.crt"), "Node certificate should exist") + assert.FileExists(t, filepath.Join(tempNodeDir, "node.key"), "Node key should exist") - // verify the subject - assert.Equal(t, "CN=EventStoreDB", nodeCertificate.Subject.String()) + // Read the content of the key and crt files + originalCaCert, originalKeyCert := readAndDecodeCertificateAndKey(t, tempNodeDir, "node") - // verify the issuer - assert.Equal(t, caCertificate.Issuer.String(), nodeCertificate.Issuer.String()) + // Create the node certificate again with the force flag + updatedArgs := append(args, "-force") - // verify the EKUs - assert.Equal(t, 2, len(nodeCertificate.ExtKeyUsage)) - assert.Equal(t, x509.ExtKeyUsageClientAuth, nodeCertificate.ExtKeyUsage[0]) - assert.Equal(t, x509.ExtKeyUsageServerAuth, nodeCertificate.ExtKeyUsage[1]) - assert.Equal(t, 0, len(nodeCertificate.UnknownExtKeyUsage)) + result = createNode.Run(updatedArgs) + assert.Equal(t, 0, result, "The 'create-node' should override the existing certificate with the --force flag") - // verify the IP SANs - assert.Equal(t, 1, len(nodeCertificate.IPAddresses)) - assert.Equal(t, "127.0.0.1", nodeCertificate.IPAddresses[0].String()) + // Read the content of the key and crt files again + newCaCert, newKeyCert := readAndDecodeCertificateAndKey(t, tempNodeDir, "node") - // verify the DNS SANs - assert.Equal(t, 1, len(nodeCertificate.DNSNames)) - assert.Equal(t, "localhost", nodeCertificate.DNSNames[0]) + assert.NotEqual(t, originalCaCert, newCaCert, "The CA certificate should be different") + assert.NotEqual(t, originalKeyCert, newKeyCert, "The CA key should be different") +} + +func setupCreateNodeTestEnvironment(t *testing.T) (cleanupFunc func(), tempNodeDir, tempCaDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createNode *CreateNode) { + var err error + + tempNodeDir, err = os.MkdirTemp(os.TempDir(), "node-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + tempCaDir, err = os.MkdirTemp(os.TempDir(), "ca-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + outputBuffer = new(bytes.Buffer) + errorBuffer = new(bytes.Buffer) + + createNode = NewCreateNode(&cli.BasicUi{ + Writer: outputBuffer, + ErrorWriter: errorBuffer, }) - t.Run("force-flag", func(t *testing.T) { - years, days, outputDirCa, outputDirNode, nodeCertFileName, ipAddresses, commonName, dnsNames := setupTestEnvironmentForNodeTests(t) - - caCertificate, caPrivateKey := generateAndAssertCACert(t, years, days, outputDirCa, false) - ips, err := parseIPAddresses(ipAddresses) - assert.NoError(t, err) - - nodeCertFilePath := path.Join(outputDirNode, nodeCertFileName+".crt") - nodeKeyFilePath := path.Join(outputDirNode, nodeCertFileName+".key") - - generateNodeCertificate(caCertificate, caPrivateKey, ips, dnsNames, years, days, outputDirNode, nodeCertFileName, commonName, false) - nodeCertFile, err := readCertificateFromFile(nodeCertFilePath) - assert.NoError(t, err) - nodeKeyFile, err := readRSAKeyFromFile(nodeKeyFilePath) - assert.NoError(t, err) - - // try to generate again without force - err = generateNodeCertificate(caCertificate, caPrivateKey, ips, dnsNames, years, days, outputDirNode, nodeCertFileName, commonName, false) - assert.Error(t, err) - nodeCertFileAfter, err := readCertificateFromFile(nodeCertFilePath) - assert.NoError(t, err) - nodeKeyFileAfter, err := readRSAKeyFromFile(nodeKeyFilePath) - assert.NoError(t, err) - assert.Equal(t, nodeCertFile, nodeCertFileAfter, "Expected node certificate to be the same") - assert.Equal(t, nodeKeyFile, nodeKeyFileAfter, "Expected node key to be the same") - - // try to generate again with force - err = generateNodeCertificate(caCertificate, caPrivateKey, ips, dnsNames, years, days, outputDirNode, nodeCertFileName, commonName, true) - assert.NoError(t, err) - nodeCertFileAfterWithForce, err := readCertificateFromFile(nodeCertFilePath) - assert.NoError(t, err) - nodeKeyFileAfterWithForce, err := readRSAKeyFromFile(nodeKeyFilePath) - assert.NoError(t, err) - assert.NotEqual(t, nodeCertFileAfter, nodeCertFileAfterWithForce, "Expected node certificate to be different") - assert.NotEqual(t, nodeKeyFileAfter, nodeKeyFileAfterWithForce, "Expected node key to be different") + // We need to create a root CA file to be able to create a node certificate + createCa := NewCreateCA(&cli.BasicUi{ + Writer: new(bytes.Buffer), + ErrorWriter: new(bytes.Buffer), }) + + args := []string{"-out", tempCaDir} + if result := createCa.Run(args); result != 0 { + t.Fatalf("Expected 0, got %d", result) + } + + cleanupFunc = func() { + if err := os.RemoveAll(tempNodeDir); err != nil { + t.Logf("Failed to remove temp node directory (%s): %s", tempNodeDir, err) + } + if err := os.RemoveAll(tempCaDir); err != nil { + t.Logf("Failed to remove temp ca directory (%s): %s", tempCaDir, err) + } + } + + return cleanupFunc, tempNodeDir, tempCaDir, outputBuffer, errorBuffer, createNode } diff --git a/main.go b/main.go index 54b3e7a..105e867 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,10 @@ func main() { flags := flag.NewFlagSet("config", flag.ContinueOnError) if !c.IsVersion() && !c.IsHelp() { - flags.Parse(os.Args[1:]) + err := flags.Parse(os.Args[1:]) + if err != nil { + ui.Error(err.Error()) + } args = flags.Args() } @@ -38,28 +41,25 @@ func main() { c.Commands = map[string]cli.CommandFactory{ "create-ca": func() (cli.Command, error) { - return &certificates.CreateCA{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil + return certificates.NewCreateCA(&cli.ColoredUi{ + Ui: ui, + OutputColor: cli.UiColorBlue, + }), nil }, "create-node": func() (cli.Command, error) { - return &certificates.CreateNode{ - Ui: &cli.ColoredUi{ + return certificates.NewCreateNode( + &cli.ColoredUi{ Ui: ui, OutputColor: cli.UiColorBlue, }, - }, nil + ), nil }, "create-certs": func() (cli.Command, error) { - return &certificates.CreateCertificates{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil + return certificates.NewCreateCerts(&cli.ColoredUi{ + Ui: ui, + OutputColor: cli.UiColorBlue, + }, + ), nil }, "create-user": func() (cli.Command, error) { return &certificates.CreateUser{ diff --git a/references/certs.yml b/references/certs.yml new file mode 100644 index 0000000..67732f7 --- /dev/null +++ b/references/certs.yml @@ -0,0 +1,23 @@ +certificates: + ca-certs: + - out: "./root_ca" + - out: "./intermediate_ca" + ca-certificate: "./root_ca/ca.crt" + ca-key: "./root_ca/ca.key" + days: 5 + node-certs: + - out: "./node1" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com" + - out: "./node2" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.2,172.20.240.2" + dns-names: "localhost,eventstore-node2.localhost.com" + - out: "./node3" + ca-certificate: "./intermediate_ca/ca.crt" + ca-key: "./intermediate_ca/ca.key" + ip-addresses: "127.0.0.3,172.20.240.3" + dns-names: "localhost,eventstore-node2.localhost.com" diff --git a/references/named_certs.yml b/references/named_certs.yml new file mode 100644 index 0000000..97dae19 --- /dev/null +++ b/references/named_certs.yml @@ -0,0 +1,28 @@ +certificates: + ca-certs: + - out: "./root_ca" + name: "root" + - out: "./intermediate_ca" + name: "intermediate" + ca-certificate: "./root_ca/root.crt" + ca-key: "./root_ca/root.key" + days: 5 + node-certs: + - out: "./node1" + name: "node1" + ca-certificate: "./intermediate_ca/intermediate.crt" + ca-key: "./intermediate_ca/intermediate.key" + ip-addresses: "127.0.0.1,172.20.240.1" + dns-names: "localhost,eventstore-node1.localhost.com" + - out: "./node2" + name: "node2" + ca-certificate: "./intermediate_ca/intermediate.crt" + ca-key: "./intermediate_ca/intermediate.key" + ip-addresses: "127.0.0.2,172.20.240.2" + dns-names: "localhost,eventstore-node2.localhost.com" + - out: "./node3" + name: "node3" + ca-certificate: "./intermediate_ca/intermediate.crt" + ca-key: "./intermediate_ca/intermediate.key" + ip-addresses: "127.0.0.3,172.20.240.3" + dns-names: "localhost,eventstore-node2.localhost.com" From 1e621e4aa01d63d1b9a4fc59ed5124128c8869b9 Mon Sep 17 00:00:00 2001 From: Joseph Cummings Date: Tue, 19 Mar 2024 11:43:18 +0000 Subject: [PATCH 2/4] Fixup and rebase onto master --- certificates/common.go | 6 +- certificates/common_test.go | 3 + certificates/create_ca.go | 3 +- certificates/create_ca_test.go | 14 +- certificates/create_certs_test.go | 20 ++- certificates/create_node.go | 2 +- certificates/create_node_test.go | 28 +--- certificates/create_user.go | 101 ++++++--------- certificates/create_user_test.go | 209 +++++++++++++++++++++--------- main.go | 11 +- 10 files changed, 226 insertions(+), 171 deletions(-) diff --git a/certificates/common.go b/certificates/common.go index 625f22a..fb0587d 100644 --- a/certificates/common.go +++ b/certificates/common.go @@ -20,7 +20,7 @@ const ( OutDirFlagUsage = "The output directory" DayFlagUsage = "the validity period of the certificate in days" CaKeyFlagUsage = "the path to the CA key file" - CaPathFlagUsage = "the path to the CA certificate file" + CaCertFlagUsage = "the path to the CA certificate file" ) const defaultKeySize = 2048 @@ -69,12 +69,12 @@ func writeCertAndKey(outputDir string, fileName string, certPem, privateKeyPem * err := writeFileWithDir(certFile, certPem.Bytes(), 0444) if err != nil { - return fmt.Errorf("error writing CA certificate to %s: %s", certFile, err.Error()) + return fmt.Errorf("error writing certificate to %s: %s", certFile, err.Error()) } err = writeFileWithDir(keyFile, privateKeyPem.Bytes(), 0400) if err != nil { - return fmt.Errorf("error writing CA's private key to %s: %s", keyFile, err.Error()) + return fmt.Errorf("error writing private key to %s: %s", keyFile, err.Error()) } return nil diff --git a/certificates/common_test.go b/certificates/common_test.go index 76f21e3..c70d87a 100644 --- a/certificates/common_test.go +++ b/certificates/common_test.go @@ -22,6 +22,9 @@ func extractErrors(errorMessage string) []string { multiLinePattern := regexp.MustCompile(`\* (.+)`) multiLineMatches := multiLinePattern.FindAllStringSubmatch(errorMessage, -1) + ansiCodePattern := regexp.MustCompile(`\x1b\[[0-9;]*m`) + errorMessage = ansiCodePattern.ReplaceAllString(errorMessage, "") + if len(multiLineMatches) > 0 { for _, match := range multiLineMatches { if len(match) > 1 { diff --git a/certificates/create_ca.go b/certificates/create_ca.go index 867f410..52287cc 100644 --- a/certificates/create_ca.go +++ b/certificates/create_ca.go @@ -37,7 +37,7 @@ func NewCreateCA(ui cli.Ui) *CreateCA { c.Flags = flag.NewFlagSet("create_ca", flag.ContinueOnError) c.Flags.IntVar(&c.Config.Days, "days", 0, DayFlagUsage) c.Flags.StringVar(&c.Config.OutputDir, "out", "./ca", OutDirFlagUsage) - c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "", CaPathFlagUsage) + c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "", CaCertFlagUsage) c.Flags.StringVar(&c.Config.CAKeyPath, "ca-key", "", CaKeyFlagUsage) c.Flags.StringVar(&c.Config.Name, "name", "ca", NameFlagUsage) c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) @@ -196,6 +196,7 @@ func (c *CreateCA) Help() string { c.Flags.PrintDefaults() return helpText.String() } + func (c *CreateCA) Synopsis() string { return "Generate a root/intermediate CA TLS certificate to be used with EventStoreDB" } diff --git a/certificates/create_ca_test.go b/certificates/create_ca_test.go index 81e0417..dae4561 100644 --- a/certificates/create_ca_test.go +++ b/certificates/create_ca_test.go @@ -24,7 +24,7 @@ func TestCreateCACertificate_NominalCase_ShouldSucceed(t *testing.T) { t.Parallel() - cleanup, tempDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{ + cleanup, tempDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{ OutputDir: "./ca", }) defer cleanup() @@ -50,7 +50,7 @@ func TestCreateCACertificate_DifferentOut_ShouldSucceed(t *testing.T) { t.Parallel() - cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{}) + cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{}) defer cleanup() args := []string{"-out", filepath.Join(tempCaDir, "my-custom-dir")} @@ -69,7 +69,7 @@ func TestCreateCACertificate_WithNameFlag_ShouldCreateNamedCertificates(t *testi t.Parallel() - cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{}) + cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{}) defer cleanup() args := []string{"-out", tempCaDir, "-name", "my-custom-name"} @@ -89,7 +89,7 @@ func TestCreateCACertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { t.Parallel() - cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{}) + cleanup, tempCaDir, _, _, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{}) defer cleanup() // Create a CA certificate @@ -119,7 +119,7 @@ func TestCreateIntermediateCertificate_WithoutRootCertificate_ShouldFail(t *test t.Parallel() - cleanup, tempCaDir, _, errorBuffer, createCa := setupCreateCaTestEnvironment(t, &TestEnvParams{}) + cleanup, tempCaDir, _, errorBuffer, createCa := setupCreateCaTestEnvironment(t, &TestCreateCAParams{}) defer cleanup() args := []string{ @@ -135,11 +135,11 @@ func TestCreateIntermediateCertificate_WithoutRootCertificate_ShouldFail(t *test assert.Equal(t, "error reading file: open unknown: no such file or directory", errors[0]) } -type TestEnvParams struct { +type TestCreateCAParams struct { OutputDir string } -func setupCreateCaTestEnvironment(t *testing.T, params *TestEnvParams) (cleanupFunc func(), tempDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createCa *CreateCA) { +func setupCreateCaTestEnvironment(t *testing.T, params *TestCreateCAParams) (cleanupFunc func(), tempDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createCa *CreateCA) { tempDir = params.OutputDir if tempDir == "" { diff --git a/certificates/create_certs_test.go b/certificates/create_certs_test.go index 2c45b4e..4b6ed83 100644 --- a/certificates/create_certs_test.go +++ b/certificates/create_certs_test.go @@ -30,9 +30,7 @@ func TestCreateCertificates_ValidConfigFile_ShouldSucceed(t *testing.T) { certsFileWithName := "certs.yml" // Create a certs.yml file - createConfigFile(t, tempCertsDir, certsFileWithName, certsFile, tempCertsDir) - - fmt.Println(tempCertsDir, certsFileWithName) + createConfigFile(t, tempCertsDir, certsFileWithName, validCertificatesYaml, tempCertsDir) args := []string{ "-config-file", filepath.Join(tempCertsDir, certsFileWithName), @@ -64,7 +62,7 @@ func TestCreateCertificates_ExistingCertificatesWithoutForceFlag_ShouldFail(t *t cleanup, tempCertsDir, _, errorBuffer, createCerts := setupCertificateTestEnvironment(t) defer cleanup() - createConfigFile(t, tempCertsDir, "certs.yml", certsFile, tempCertsDir) + createConfigFile(t, tempCertsDir, "certs.yml", validCertificatesYaml, tempCertsDir) args := []string{ "-config-file", tempCertsDir + "/certs.yml", @@ -105,7 +103,7 @@ func TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate(t certsFileWithName := "certs.yml" // Create a certs.yml file - createConfigFile(t, tempCertsDir, certsFileWithName, certsFile, tempCertsDir) + createConfigFile(t, tempCertsDir, certsFileWithName, validCertificatesYaml, tempCertsDir) args := []string{ "-config-file", filepath.Join(tempCertsDir, certsFileWithName), @@ -172,7 +170,7 @@ func TestCreateCertificates_ValidConfigWithCustomNames_ShouldCreateNamedCertific certsFileName := "certs-with-name.yml" - createConfigFile(t, tempCertsDir, certsFileName, certsFileWithName, tempCertsDir) + createConfigFile(t, tempCertsDir, certsFileName, certificatesYamlWithOverrideName, tempCertsDir) args := []string{ "-config-file", filepath.Join(tempCertsDir, certsFileName), @@ -205,7 +203,7 @@ func TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError(t *testing.T certsFileName := "certs.yml" - createConfigFile(t, tempCertsDir, certsFileName, certsFileWithInvalidPath, tempCertsDir) + createConfigFile(t, tempCertsDir, certsFileName, certificatesYamlWithInvalidPath, tempCertsDir) args := []string{ "-config-file", filepath.Join(tempCertsDir, certsFileName), @@ -219,7 +217,7 @@ func TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError(t *testing.T assert.Equal(t, 1, len(errors), "Expected 1 error") expectedErrorMessage := fmt.Sprintf("error reading file: open %s: no such file or directory", filepath.Join(tempCertsDir, "invalid_root_ca", "ca.crt")) - assert.Equal(t, errors[0], expectedErrorMessage) + assert.Equal(t, expectedErrorMessage, errors[0]) // The root CA will be created assert.DirExists(t, filepath.Join(tempCertsDir, "root_ca")) @@ -230,7 +228,7 @@ func TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError(t *testing.T } // Valid certificate file -var certsFile = `certificates: +var validCertificatesYaml = `certificates: ca-certs: - out: "./root_ca" - out: "./intermediate_ca" @@ -255,7 +253,7 @@ var certsFile = `certificates: dns-names: "localhost,eventstore-node2.localhost.com"` // Invalid path defined at ca-certificate in the config -var certsFileWithInvalidPath = `certificates: +var certificatesYamlWithInvalidPath = `certificates: ca-certs: - out: "./root_ca" - out: "./intermediate_ca" @@ -270,7 +268,7 @@ var certsFileWithInvalidPath = `certificates: dns-names: "localhost,eventstore-node1.localhost.com"` // Each certificate have a name parameter -var certsFileWithName = `certificates: +var certificatesYamlWithOverrideName = `certificates: ca-certs: - out: "./custom_root" name: "custom_root" diff --git a/certificates/create_node.go b/certificates/create_node.go index 5c04559..bc52e80 100644 --- a/certificates/create_node.go +++ b/certificates/create_node.go @@ -42,7 +42,7 @@ func NewCreateNode(ui cli.Ui) *CreateNode { c := &CreateNode{Ui: ui} c.Flags = flag.NewFlagSet("create_node", flag.ContinueOnError) - c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "./ca/ca.crt", CaPathFlagUsage) + c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "./ca/ca.crt", CaCertFlagUsage) c.Flags.StringVar(&c.Config.CommonName, "common-name", "eventstoredb-node", "the certificate subject common name") c.Flags.StringVar(&c.Config.CAKeyPath, "ca-key", "./ca/ca.key", CaKeyFlagUsage) c.Flags.StringVar(&c.Config.IPAddresses, "ip-addresses", "", "comma-separated list of IP addresses of the node") diff --git a/certificates/create_node_test.go b/certificates/create_node_test.go index fa05404..f47f0d1 100644 --- a/certificates/create_node_test.go +++ b/certificates/create_node_test.go @@ -13,14 +13,12 @@ import ( ) func TestCreateNodeCertificate(t *testing.T) { - t.Run("TestCreateNodeCertificate_NoParams_ShouldFail", TestCreateNodeCertificate_NoParams_ShouldFail) + t.Run("TestCreateNodeCertificate_WithoutParams_ShouldFail", TestCreateNodeCertificate_WithoutParams_ShouldFail) t.Run("TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed", TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed) t.Run("TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate", TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate) } -func TestCreateNodeCertificate_NoParams_ShouldFail(t *testing.T) { - // Create a node certificate with no params - +func TestCreateNodeCertificate_WithoutParams_ShouldFail(t *testing.T) { t.Parallel() cleanup, _, _, _, errorBuffer, createNode := setupCreateNodeTestEnvironment(t) @@ -33,7 +31,6 @@ func TestCreateNodeCertificate_NoParams_ShouldFail(t *testing.T) { errors := extractErrors(errorBuffer.String()) assert.Equal(t, 1, len(errors)) assert.Equal(t, "at least one IP address or DNS name needs to be specified with --ip-addresses or --dns-names", errors[0]) - assert.Equal(t, 1, result) } func TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed(t *testing.T) { @@ -59,11 +56,9 @@ func TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed(t *testing.T) cert, err := readCertificateFromFile(filepath.Join(tempNodeDir, "node.crt")) assert.NoError(t, err, "Failed to read and parse certificate file") - // The certificate should be valid for 1 year expectedNotAfter := time.Now().AddDate(1, 0, 0) assert.WithinDuration(t, expectedNotAfter, cert.NotAfter, time.Second, "Certificate validity period does not match expected default") - // Now we verify if the certificate is signed by the provided root CA caCert, err := readCertificateFromFile(filepath.Join(tempCaDir, "ca.crt")) assert.NoError(t, err, "Failed to read and parse CA certificate file") @@ -75,8 +70,6 @@ func TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed(t *testing.T) } func TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { - // Create a node certificate with the force flag - t.Parallel() cleanup, tempNodeDir, tempCaDir, _, _, createNode := setupCreateNodeTestEnvironment(t) @@ -91,25 +84,16 @@ func TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { } result := createNode.Run(args) - assert.Equal(t, 0, result, "The 'create-node' operation without the --force flag should succeed the first time") + originalNodeCert, originalNodeKey := readAndDecodeCertificateAndKey(t, tempNodeDir, "node") - assert.FileExists(t, filepath.Join(tempNodeDir, "node.crt"), "Node certificate should exist") - assert.FileExists(t, filepath.Join(tempNodeDir, "node.key"), "Node key should exist") - - // Read the content of the key and crt files - originalCaCert, originalKeyCert := readAndDecodeCertificateAndKey(t, tempNodeDir, "node") - - // Create the node certificate again with the force flag updatedArgs := append(args, "-force") - result = createNode.Run(updatedArgs) assert.Equal(t, 0, result, "The 'create-node' should override the existing certificate with the --force flag") - // Read the content of the key and crt files again - newCaCert, newKeyCert := readAndDecodeCertificateAndKey(t, tempNodeDir, "node") + newNodeCert, newNodeKey := readAndDecodeCertificateAndKey(t, tempNodeDir, "node") - assert.NotEqual(t, originalCaCert, newCaCert, "The CA certificate should be different") - assert.NotEqual(t, originalKeyCert, newKeyCert, "The CA key should be different") + assert.NotEqual(t, originalNodeCert, newNodeCert, "The Node certificate should be different") + assert.NotEqual(t, originalNodeKey, newNodeKey, "The Node key should be different") } func setupCreateNodeTestEnvironment(t *testing.T) (cleanupFunc func(), tempNodeDir, tempCaDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createNode *CreateNode) { diff --git a/certificates/create_user.go b/certificates/create_user.go index 903e915..cd6a16e 100644 --- a/certificates/create_user.go +++ b/certificates/create_user.go @@ -10,9 +10,6 @@ import ( "errors" "flag" "fmt" - "path" - "strings" - "text/tabwriter" "time" multierror "github.com/hashicorp/go-multierror" @@ -20,7 +17,9 @@ import ( ) type CreateUser struct { - Ui cli.Ui + Ui cli.Ui + Config CreateUserArguments + Flags *flag.FlagSet } type CreateUserArguments struct { @@ -32,38 +31,43 @@ type CreateUserArguments struct { Force bool `yaml:"force"` } +func NewCreateUser(ui cli.Ui) *CreateUser { + c := &CreateUser{Ui: ui} + + c.Flags = flag.NewFlagSet("create_user", flag.ContinueOnError) + c.Flags.Usage = func() { c.Ui.Info(c.Help()) } + c.Flags.StringVar(&c.Config.Username, "username", "", "the EventStoreDB user") + c.Flags.StringVar(&c.Config.CACertificatePath, "ca-certificate", "./ca/ca.crt", CaCertFlagUsage) + c.Flags.StringVar(&c.Config.CAKeyPath, "ca-key", "./ca/ca.key", CaKeyFlagUsage) + c.Flags.IntVar(&c.Config.Days, "days", 0, DayFlagUsage) + c.Flags.StringVar(&c.Config.OutputDir, "out", "", OutDirFlagUsage) + c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) + + return c +} + func (c *CreateUser) Run(args []string) int { - var config CreateUserArguments - - flags := flag.NewFlagSet("create_user", flag.ContinueOnError) - flags.Usage = func() { c.Ui.Info(c.Help()) } - flags.StringVar(&config.Username, "username", "", "the EventStoreDB user") - flags.StringVar(&config.CACertificatePath, "ca-certificate", "./ca/ca.crt", "the path to the CA certificate file") - flags.StringVar(&config.CAKeyPath, "ca-key", "./ca/ca.key", "the path to the CA key file") - flags.IntVar(&config.Days, "days", 0, "the validity period of the certificate in days") - flags.StringVar(&config.OutputDir, "out", "", "The output directory") - flags.BoolVar(&config.Force, "force", false, forceOption) - - if err := flags.Parse(args); err != nil { + if err := c.Flags.Parse(args); err != nil { + c.Ui.Error(err.Error()) return 1 } validationErrors := new(multierror.Error) - if len(config.Username) == 0 { - multierror.Append(validationErrors, errors.New("username is a required field")) + if len(c.Config.Username) == 0 { + _ = multierror.Append(validationErrors, errors.New("username is a required field")) } - if len(config.CACertificatePath) == 0 { - multierror.Append(validationErrors, errors.New("ca-certificate is a required field")) + if len(c.Config.CACertificatePath) == 0 { + _ = multierror.Append(validationErrors, errors.New("ca-certificate is a required field")) } - if len(config.CAKeyPath) == 0 { - multierror.Append(validationErrors, errors.New("ca-key is a required field")) + if len(c.Config.CAKeyPath) == 0 { + _ = multierror.Append(validationErrors, errors.New("ca-key is a required field")) } - if config.Days < 0 { - multierror.Append(validationErrors, errors.New("days must be positive")) + if c.Config.Days < 0 { + _ = multierror.Append(validationErrors, errors.New("days must be positive")) } if validationErrors.ErrorOrNil() != nil { @@ -71,34 +75,29 @@ func (c *CreateUser) Run(args []string) int { return 1 } - caCert, err := readCertificateFromFile(config.CACertificatePath) + caCert, err := readCertificateFromFile(c.Config.CACertificatePath) if err != nil { c.Ui.Error(err.Error()) return 1 } - caKey, err := readRSAKeyFromFile(config.CAKeyPath) + caKey, err := readRSAKeyFromFile(c.Config.CAKeyPath) if err != nil { err := fmt.Errorf("error: %s. please note that only RSA keys are currently supported", err.Error()) c.Ui.Error(err.Error()) return 1 } - outputDir := config.OutputDir - outputBaseFileName := "user-" + config.Username + outputDir := c.Config.OutputDir + outputBaseFileName := "user-" + c.Config.Username if len(outputDir) == 0 { outputDir = outputBaseFileName } - // check if user certificates already exist - if fileExists(path.Join(outputDir, outputBaseFileName+".crt"), config.Force) { - c.Ui.Error(ErrFileExists) - return 1 - } - - if fileExists(path.Join(outputDir, outputBaseFileName+".key"), config.Force) { - c.Ui.Error(ErrFileExists) + certErr := checkCertificatesLocationWithForce(outputDir, outputBaseFileName, c.Config.Force) + if certErr != nil { + c.Ui.Error(certErr.Error()) return 1 } @@ -106,12 +105,12 @@ func (c *CreateUser) Run(args []string) int { years := 1 days := 0 - if config.Days != 0 { - days = config.Days + if c.Config.Days != 0 { + days = c.Config.Days years = 0 } - err = generateUserCertificate(config.Username, outputBaseFileName, caCert, caKey, years, days, outputDir, config.Force) + err = generateUserCertificate(c.Config.Username, outputBaseFileName, caCert, caKey, years, days, outputDir, c.Config.Force) if err != nil { c.Ui.Error(err.Error()) return 1 @@ -156,7 +155,7 @@ func generateUserCertificate(username string, outputBaseFileName string, caCert } privateKeyPem := new(bytes.Buffer) - pem.Encode(privateKeyPem, &pem.Block{ + err = pem.Encode(privateKeyPem, &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey), }) @@ -184,24 +183,10 @@ func generateUserCertificate(username string, outputBaseFileName string, caCert } func (c *CreateUser) Help() string { - var buffer bytes.Buffer - - w := tabwriter.NewWriter(&buffer, 0, 0, 2, ' ', 0) - - fmt.Fprintln(w, "Usage: create_user [options]") - fmt.Fprintln(w, c.Synopsis()) - fmt.Fprintln(w, "Options:") - - writeHelpOption(w, "username", "The name of the EventStoreDB user to generate a certificate for.") - writeHelpOption(w, "ca-certificate", "The path to the CA certificate file (default: ./ca/ca.crt).") - writeHelpOption(w, "ca-key", "The path to the CA key file (default: ./ca/ca.key).") - writeHelpOption(w, "days", "The validity period of the certificates in days (default: 1 year).") - writeHelpOption(w, "out", "The output directory (default: ./user-).") - writeHelpOption(w, "force", forceOption) - - w.Flush() - - return strings.TrimSpace(buffer.String()) + var helpText bytes.Buffer + c.Flags.SetOutput(&helpText) + c.Flags.PrintDefaults() + return helpText.String() } func (c *CreateUser) Synopsis() string { diff --git a/certificates/create_user_test.go b/certificates/create_user_test.go index f015fdc..f635bfe 100644 --- a/certificates/create_user_test.go +++ b/certificates/create_user_test.go @@ -1,86 +1,171 @@ package certificates import ( + "bytes" "crypto/x509" - "path" + "github.com/mitchellh/cli" + "os" + "path/filepath" "testing" + "time" "github.com/stretchr/testify/assert" ) -func setupTestEnvironmentForUserTests(t *testing.T) (years int, days int, username string, userCertFileName string, outputDirCa string, outputDirUser string) { - years = 1 - days = 0 - username = "bob" - userCertFileName = "user-" + username - outputDirCa = "./ca" - outputDirUser = "./" + userCertFileName +func TestCreateUserCertificate(t *testing.T) { + t.Run("TestCreateUserCertificate_WithoutParams_ShouldFail", TestCreateUserCertificate_WithoutParams_ShouldFail) + t.Run("TestCreateUserCertificate_WithAllRequiredParams_ShouldSucceed", TestCreateUserCertificate_WithAllRequiredParams_ShouldSucceed) + t.Run("TestCreateUserCertificate_WithNegativeDays_ShouldFail", TestCreateUserCertificate_WithNegativeDays_ShouldFail) + t.Run("TestCreateUserCertificate_WithForceFlag_ShouldRegenerate", TestCreateUserCertificate_WithForceFlag_ShouldRegenerate) +} + +func TestCreateUserCertificate_WithoutParams_ShouldFail(t *testing.T) { + t.Parallel() + + cleanup, _, _, _, errorBuffer, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() + + var args []string + result := createUser.Run(args) - cleanupDirsForTest(t, outputDirCa, outputDirUser) - return + assert.Equal(t, 1, result, "The 'create-user' operation should fail due to the absence of required parameters.") + + errors := extractErrors(errorBuffer.String()) + assert.Equal(t, 1, len(errors)) + assert.Equal(t, "username is a required field", errors[0]) } -func TestGenerateUserCertificate(t *testing.T) { +func TestCreateUserCertificate_WithAllRequiredParams_ShouldSucceed(t *testing.T) { + t.Parallel() + + cleanup, tempUserDir, tempCaDir, _, _, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() + + username := "ouro" + args := []string{ + "-username", username, + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-out", tempUserDir, + } + + if result := createUser.Run(args); result != 0 { + t.Fatalf("Expected 0, got %d", result) + } - t.Run("nominal-case", func(t *testing.T) { - years, days, username, userCertFileName, outputDirCa, outputDirUser := setupTestEnvironmentForUserTests(t) + userFmt := "user-" + username + userCertPath := filepath.Join(tempUserDir, userFmt+".crt") + userKeyPath := filepath.Join(tempUserDir, userFmt+".key") - caCertificate, caPrivateKey := generateAndAssertCACert(t, years, days, outputDirCa, false) + assert.FileExists(t, userCertPath, "User certificate should exist") + assert.FileExists(t, userKeyPath, "User key should exist") - err := generateUserCertificate(username, userCertFileName, caCertificate, caPrivateKey, years, days, outputDirUser, false) - assert.NoError(t, err) + cert, err := readCertificateFromFile(userCertPath) + assert.NoError(t, err, "Failed to read and parse certificate file") - userCertPath := path.Join(outputDirUser, userCertFileName+".crt") - userKeyPath := path.Join(outputDirUser, userCertFileName+".key") - assertFilesExist(t, userCertPath, userKeyPath) + expectedNotAfter := time.Now().AddDate(1, 0, 0) + assert.WithinDuration(t, expectedNotAfter, cert.NotAfter, time.Second, "Certificate validity period does not match expected default") - userCertificate, _ := readCertificateFromFile(userCertPath) + caCert, err := readCertificateFromFile(filepath.Join(tempCaDir, "ca.crt")) + assert.NoError(t, err, "Failed to read and parse CA certificate file") - // verify the subject - assert.Equal(t, "CN="+username, userCertificate.Subject.String()) + roots := x509.NewCertPool() + roots.AddCert(caCert) - // verify the issuer - assert.Equal(t, caCertificate.Issuer.String(), userCertificate.Issuer.String()) + _, err = cert.Verify(x509.VerifyOptions{Roots: roots, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}}) + assert.NoError(t, err, "User certificate should be signed by the provided root CA") +} + +func TestCreateUserCertificate_WithNegativeDays_ShouldFail(t *testing.T) { + t.Parallel() + + cleanup, _, tempCaDir, _, errorBuffer, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() + + args := []string{ + "-username", "ouro", + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-days", "-1", + } + result := createUser.Run(args) + + assert.Equal(t, 1, result, "The 'create-user' operation should fail when days is negative.") + + errors := extractErrors(errorBuffer.String()) + assert.Equal(t, 1, len(errors)) + assert.Equal(t, "days must be positive", errors[0]) +} - // verify the EKUs - assert.Equal(t, 1, len(userCertificate.ExtKeyUsage)) - assert.Equal(t, x509.ExtKeyUsageClientAuth, userCertificate.ExtKeyUsage[0]) - assert.Equal(t, 0, len(userCertificate.UnknownExtKeyUsage)) +func TestCreateUserCertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { + t.Parallel() + + cleanup, tempUserDir, tempCaDir, _, _, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() + + username := "ouro" + args := []string{ + "-username", username, + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-out", tempUserDir, + } + + result := createUser.Run(args) + + userFmt := "user-" + username + originalUserCert, originalUserKey := readAndDecodeCertificateAndKey(t, tempUserDir, userFmt) + + updatedArgs := append(args, "-force") + result = createUser.Run(updatedArgs) + assert.Equal(t, 0, result, "The 'create-user' should override the existing certificate with the --force flag") + + newUserCert, newUserKey := readAndDecodeCertificateAndKey(t, tempUserDir, userFmt) + + assert.NotEqual(t, originalUserCert, newUserCert, "The User certificate should be different") + assert.NotEqual(t, originalUserKey, newUserKey, "The User key should be different") +} + +func setupCreateUserTestEnvironment(t *testing.T) (cleanupFunc func(), tempUserDir string, tempCaDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createUser *CreateUser) { + var err error + + tempUserDir, err = os.MkdirTemp(os.TempDir(), "user-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + tempCaDir, err = os.MkdirTemp(os.TempDir(), "ca-*") + if err != nil { + t.Fatalf("Failed to create temp directory: %s", err) + } + + outputBuffer = new(bytes.Buffer) + errorBuffer = new(bytes.Buffer) + + createUser = NewCreateUser(&cli.BasicUi{ + Writer: outputBuffer, + ErrorWriter: errorBuffer, }) - t.Run("force-flag", func(t *testing.T) { - years, days, username, userCertFileName, outputDirCa, outputDirUser := setupTestEnvironmentForUserTests(t) - - caCertificate, caPrivateKey := generateAndAssertCACert(t, years, days, outputDirCa, false) - - err := generateUserCertificate(username, userCertFileName, caCertificate, caPrivateKey, years, days, outputDirUser, false) - assert.NoError(t, err) - - userCertPath := path.Join(outputDirUser, userCertFileName+".crt") - userKeyPath := path.Join(outputDirUser, userCertFileName+".key") - assertFilesExist(t, userCertPath, userKeyPath) - - userCertificate, _ := readCertificateFromFile(userCertPath) - userCertificateKey, _ := readRSAKeyFromFile(userKeyPath) - - // try to generate again without force - err = generateUserCertificate(username, userCertFileName, caCertificate, caPrivateKey, years, days, outputDirUser, false) - assert.Error(t, err) - userCertificateAfter, err := readCertificateFromFile(userCertPath) - assert.NoError(t, err) - userCertificateKeyAfter, err := readRSAKeyFromFile(userKeyPath) - assert.NoError(t, err) - assert.Equal(t, userCertificate, userCertificateAfter, "Expected user certificate to be the same") - assert.Equal(t, userCertificateKey, userCertificateKeyAfter, "Expected user key to be the same") - - // try to generate again with force - err = generateUserCertificate(username, userCertFileName, caCertificate, caPrivateKey, years, days, outputDirUser, true) - assert.NoError(t, err) - userCertificateAfterWithForce, err := readCertificateFromFile(userCertPath) - assert.NoError(t, err) - userCertificateKeyAfterWithForce, err := readRSAKeyFromFile(userKeyPath) - assert.NoError(t, err) - assert.NotEqual(t, userCertificate, userCertificateAfterWithForce, "Expected user certificate to be different") - assert.NotEqual(t, userCertificateKey, userCertificateKeyAfterWithForce, "Expected user key to be different") + // We need to create a root CA file to be able to create a user certificate + createCa := NewCreateCA(&cli.BasicUi{ + Writer: new(bytes.Buffer), + ErrorWriter: new(bytes.Buffer), }) + + args := []string{"-out", tempCaDir} + if result := createCa.Run(args); result != 0 { + t.Fatalf("Expected 0, got %d", result) + } + + cleanupFunc = func() { + if err := os.RemoveAll(tempUserDir); err != nil { + t.Logf("Failed to remove temp user directory (%s): %s", tempUserDir, err) + } + if err := os.RemoveAll(tempCaDir); err != nil { + t.Logf("Failed to remove temp ca directory (%s): %s", tempCaDir, err) + } + } + + return cleanupFunc, tempUserDir, tempCaDir, outputBuffer, errorBuffer, createUser } diff --git a/main.go b/main.go index 105e867..b9c5d0c 100644 --- a/main.go +++ b/main.go @@ -62,12 +62,11 @@ func main() { ), nil }, "create-user": func() (cli.Command, error) { - return &certificates.CreateUser{ - Ui: &cli.ColoredUi{ - Ui: ui, - OutputColor: cli.UiColorBlue, - }, - }, nil + return certificates.NewCreateUser(&cli.ColoredUi{ + Ui: ui, + OutputColor: cli.UiColorBlue, + }, + ), nil }, } c.HelpFunc = createGeneralHelpFunc(appName, flags) From 0bc87a887c4022dc1c7fa593b96a4b0385d0aa36 Mon Sep 17 00:00:00 2001 From: William Chong Date: Fri, 22 Mar 2024 09:08:15 +0400 Subject: [PATCH 3/4] Fixup * Add more create node test * Fix typo * Make sure tests runs on Windows --- .github/workflows/ci.yml | 12 +++++++ certificates/common.go | 4 +-- certificates/create_ca_test.go | 9 ++--- certificates/create_certs_test.go | 6 ++-- certificates/create_node.go | 3 ++ certificates/create_node_test.go | 59 +++++++++++++++++++++++++++++++ 6 files changed, 84 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d36436b..db9f113 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -89,3 +89,15 @@ jobs: tags: ghcr.io/${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} platforms: linux/amd64,linux/arm64 + + test-windows: + runs-on: windows-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.21 + - name: Run Tests + run: go test ./... diff --git a/certificates/common.go b/certificates/common.go index fb0587d..ab5e054 100644 --- a/certificates/common.go +++ b/certificates/common.go @@ -51,8 +51,8 @@ func writeFileWithDir(filePath string, data []byte, perm os.FileMode) error { } func writeCertAndKey(outputDir string, fileName string, certPem, privateKeyPem *bytes.Buffer, force bool) error { - certFile := filepath.Join(outputDir, fileName+".crt") - keyFile := filepath.Join(outputDir, fileName+".key") + certFile := filepath.ToSlash(fmt.Sprintf("%s/%s.crt", outputDir, fileName)) + keyFile := filepath.ToSlash(fmt.Sprintf("%s/%s.key", outputDir, fileName)) if force { if _, err := os.Stat(certFile); err == nil { diff --git a/certificates/create_ca_test.go b/certificates/create_ca_test.go index dae4561..35ccb23 100644 --- a/certificates/create_ca_test.go +++ b/certificates/create_ca_test.go @@ -32,7 +32,7 @@ func TestCreateCACertificate_NominalCase_ShouldSucceed(t *testing.T) { var args []string result := createCa.Run(args) - assert.Equal(t, 0, result, "creat-ca should pass without any additional parameters") + assert.Equal(t, 0, result, "create-ca should pass without any additional parameters") assert.FileExists(t, filepath.Join("./ca", "ca.crt"), "CA certificate should exist") assert.FileExists(t, filepath.Join("./ca", "ca.key"), "CA key should exist") @@ -41,7 +41,8 @@ func TestCreateCACertificate_NominalCase_ShouldSucceed(t *testing.T) { assert.NoError(t, err, "Failed to read and parse certificate file") // The certificate should be valid for 5 year - expectedNotAfter := time.Now().AddDate(5, 0, 0) + now := time.Now().Truncate(time.Second) + expectedNotAfter := now.AddDate(5, 0, 0) assert.WithinDuration(t, expectedNotAfter, cert.NotAfter, time.Second, "Certificate validity period does not match expected default") } @@ -128,11 +129,11 @@ func TestCreateIntermediateCertificate_WithoutRootCertificate_ShouldFail(t *test "-ca-key", "unknown", } result := createCa.Run(args) - assert.Equal(t, 1, result, "creat-ca should fail without a root certificate") + assert.Equal(t, 1, result, "create-ca should fail without a root certificate") errors := extractErrors(errorBuffer.String()) assert.Equal(t, 1, len(errors), "Expected 1 error") - assert.Equal(t, "error reading file: open unknown: no such file or directory", errors[0]) + assert.Contains(t, errors[0], "error reading file") } type TestCreateCAParams struct { diff --git a/certificates/create_certs_test.go b/certificates/create_certs_test.go index 4b6ed83..b05553e 100644 --- a/certificates/create_certs_test.go +++ b/certificates/create_certs_test.go @@ -216,8 +216,8 @@ func TestCreateCertificates_InvalidPathInConfig_ShouldFailWithError(t *testing.T assert.Equal(t, 1, len(errors), "Expected 1 error") - expectedErrorMessage := fmt.Sprintf("error reading file: open %s: no such file or directory", filepath.Join(tempCertsDir, "invalid_root_ca", "ca.crt")) - assert.Equal(t, expectedErrorMessage, errors[0]) + assert.Contains(t, errors[0], "error reading file") + assert.Contains(t, errors[0], filepath.ToSlash(fmt.Sprintf("%s/invalid_root_ca/ca.crt", tempCertsDir))) // The root CA will be created assert.DirExists(t, filepath.Join(tempCertsDir, "root_ca")) @@ -321,7 +321,7 @@ func setupCertificateTestEnvironment(t *testing.T) (cleanupFunc func(), tempCert } func createConfigFile(t *testing.T, dirPath string, fileName string, content string, newParentDir string) { - updatedContent := strings.ReplaceAll(content, "./", fmt.Sprintf("%s/", newParentDir)) + updatedContent := strings.ReplaceAll(content, "./", fmt.Sprintf("%s/", filepath.ToSlash(newParentDir))) filePath := filepath.Join(dirPath, fileName) diff --git a/certificates/create_node.go b/certificates/create_node.go index bc52e80..103a7e6 100644 --- a/certificates/create_node.go +++ b/certificates/create_node.go @@ -151,6 +151,9 @@ func (c *CreateNode) Run(args []string) int { c.Ui.Error(err.Error()) return 1 } + } + + if len(outputBaseFileName) == 0 { outputBaseFileName = outputDir } diff --git a/certificates/create_node_test.go b/certificates/create_node_test.go index f47f0d1..b8f2970 100644 --- a/certificates/create_node_test.go +++ b/certificates/create_node_test.go @@ -15,6 +15,8 @@ import ( func TestCreateNodeCertificate(t *testing.T) { t.Run("TestCreateNodeCertificate_WithoutParams_ShouldFail", TestCreateNodeCertificate_WithoutParams_ShouldFail) t.Run("TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed", TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed) + t.Run("TestCreateNodeCertificate_WithNameFlagAndOutput_ShouldCreateNamedCertificate", TestCreateNodeCertificate_WithNameFlagAndOutput_ShouldCreateNamedCertificate) + t.Run("TestCreateNodeCertificate_WithNameFlagWithoutOutput_ShouldCreateNamedCertificate", TestCreateNodeCertificate_WithNameFlagWithoutOutput_ShouldCreateNamedCertificate) t.Run("TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate", TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate) } @@ -69,6 +71,63 @@ func TestCreateNodeCertificate_WithAllRequiredParams_ShouldSucceed(t *testing.T) assert.NoError(t, err, "Node certificate should be signed by the provided root CA") } +func TestCreateNodeCertificate_WithNameFlagAndOutput_ShouldCreateNamedCertificate(t *testing.T) { + t.Parallel() + + cleanup, tempNodeDir, tempCaDir, _, _, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() + + args := []string{ + "-ca-certificate", fmt.Sprintf("%s/ca.crt", tempCaDir), + "-ca-key", fmt.Sprintf("%s/ca.key", tempCaDir), + "-out", tempNodeDir, + "-ip-addresses", "127.0.0.1", + "-dns-names", "localhost", + "-name", "renamed", + } + + result := createNode.Run(args) + assert.Equal(t, 0, result, "The 'create-node' operation should succeed with the --name flag") + + assert.FileExists(t, filepath.Join(tempNodeDir, "renamed.crt"), "Renamed certificate should exist") + assert.FileExists(t, filepath.Join(tempNodeDir, "renamed.key"), "Renamed key should exist") +} + +func TestCreateNodeCertificate_WithNameFlagWithoutOutput_ShouldCreateNamedCertificate(t *testing.T) { + t.Parallel() + + cleanup, tempNodeDir, tempCaDir, _, _, createNode := setupCreateNodeTestEnvironment(t) + defer cleanup() + + originalDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get current directory: %s", err) + } + defer func(dir string) { + err := os.Chdir(dir) + if err != nil { + t.Fatalf("Failed to change to orignal directory: %s", err) + } + }(originalDir) + + if err := os.Chdir(tempNodeDir); err != nil { + t.Fatalf("Failed to change current directory: %s", err) + } + + args := []string{ + "-ca-certificate", fmt.Sprintf("%s/ca.crt", tempCaDir), + "-ca-key", fmt.Sprintf("%s/ca.key", tempCaDir), + "-ip-addresses", "127.0.0.1", + "-name", "renamed_without_output", + } + + result := createNode.Run(args) + assert.Equal(t, 0, result, "The 'create-node' operation should succeed with the --name flag") + + assert.FileExists(t, filepath.Join(tempNodeDir, "node1", "renamed_without_output.crt"), "Renamed certificate should exist") + assert.FileExists(t, filepath.Join(tempNodeDir, "node1", "renamed_without_output.key"), "Renamed key should exist") +} + func TestCreateNodeCertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { t.Parallel() From 0415a77801835ed4087dea00dd4dba0f0b20a07d Mon Sep 17 00:00:00 2001 From: William Chong Date: Fri, 22 Mar 2024 09:58:10 +0400 Subject: [PATCH 4/4] Add support for user certs in create-certs command --- certificates/create_certs.go | 18 +++++++++++++++++- certificates/create_certs_test.go | 31 +++++++++++++++++++++++++++++-- certificates/create_user.go | 12 +++++++++--- certificates/create_user_test.go | 31 +++++++++++++++++++++++++++++++ references/certs.yml | 5 +++++ references/named_certs.yml | 6 ++++++ 6 files changed, 97 insertions(+), 6 deletions(-) diff --git a/certificates/create_certs.go b/certificates/create_certs.go index 56f4790..ccc6eea 100644 --- a/certificates/create_certs.go +++ b/certificates/create_certs.go @@ -27,6 +27,7 @@ type Config struct { Certificates struct { CaCerts []CreateCAArguments `yaml:"ca-certs"` Nodes []CreateNodeArguments `yaml:"node-certs"` + Users []CreateUserArguments `yaml:"user-certs"` } `yaml:"certificates"` } @@ -63,7 +64,7 @@ func (c *CreateCertificates) Run(args []string) int { return 1 } - if c.generateCaCerts(config, c.Config.Force) != 0 || c.generateNodes(config, c.Config.Force) != 0 { + if c.generateCaCerts(config, c.Config.Force) != 0 || c.generateNodes(config, c.Config.Force) != 0 || c.generateUsers(config, c.Config.Force) != 0 { return 1 } @@ -99,6 +100,21 @@ func (c *CreateCertificates) checkPaths(config Config, force bool) error { wg.Wait() return certError } + +func (c *CreateCertificates) generateUsers(config Config, force bool) int { + for _, user := range config.Certificates.Users { + user.Force = force + createUser := NewCreateUser(&cli.ColoredUi{ + Ui: c.Ui, + OutputColor: cli.UiColorBlue, + }) + if createUser.Run(toArguments(user)) != 0 { + return 1 + } + } + return 0 +} + func (c *CreateCertificates) generateNodes(config Config, force bool) int { for _, node := range config.Certificates.Nodes { node.Force = force diff --git a/certificates/create_certs_test.go b/certificates/create_certs_test.go index b05553e..c34f8b7 100644 --- a/certificates/create_certs_test.go +++ b/certificates/create_certs_test.go @@ -49,6 +49,9 @@ func TestCreateCertificates_ValidConfigFile_ShouldSucceed(t *testing.T) { assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.crt"), fmt.Sprintf("%s certificate should exist", node)) assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.key"), fmt.Sprintf("%s certificate key should exist", node)) } + + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.crt"), "User admin certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.key"), "User admin private key should exist") } func TestCreateCertificates_ExistingCertificatesWithoutForceFlag_ShouldFail(t *testing.T) { @@ -82,6 +85,9 @@ func TestCreateCertificates_ExistingCertificatesWithoutForceFlag_ShouldFail(t *t assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.key"), fmt.Sprintf("%s certificate key should exist", node)) } + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.crt"), "User admin certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.key"), "User admin private key should exist") + // Try to generate the certificates again and expect and error result = createCerts.Run(args) assert.Equal(t, 1, result, "The create-certs command should fail the second time it is run since the certificates already exist") @@ -123,9 +129,13 @@ func TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate(t assert.FileExists(t, filepath.Join(tempCertsDir, node, "node.key"), fmt.Sprintf("%s certificate key should exist", node)) } + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.crt"), "User admin certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "user-admin.key"), "User admin private key should exist") + // Read the content of the key and crt files generated from the config file originalCaCert, originalKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "root_ca"), "ca") originalIntermediateCaCert, originalIntermediateKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "intermediate_ca"), "ca") + originalUserCert, originalUserCertKey := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "user-admin"), "user-admin") originalCerts := make(map[string][2]interface{}) @@ -145,6 +155,7 @@ func TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate(t newRootCaCert, newRootCaKey := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "root_ca"), "ca") newIntermediateCaCert, newIntermediateKeyCert := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "intermediate_ca"), "ca") + newUserCert, newUserCertKey := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, "user-admin"), "user-admin") assert.NotEqual(t, originalCaCert, newRootCaCert, "Root CA certificate should be regenerated") assert.NotEqual(t, originalKeyCert, newRootCaKey, "Root CA key should be regenerated") @@ -152,6 +163,9 @@ func TestCreateCertificates_ForceFlagWithExistingCertificates_ShouldRegenerate(t assert.NotEqual(t, originalIntermediateCaCert, newIntermediateCaCert, "Intermediate CA certificate should be regenerated") assert.NotEqual(t, originalIntermediateKeyCert, newIntermediateKeyCert, "Intermediate CA key should be regenerated") + assert.NotEqual(t, originalUserCert, newUserCert, "User certificate should be regenerated") + assert.NotEqual(t, originalUserCertKey, newUserCertKey, "User certificate key should be regenerated") + for _, node := range nodes { newCAHash, newKeyHash := readAndDecodeCertificateAndKey(t, filepath.Join(tempCertsDir, node), "node") assert.NotEqual(t, originalCerts[node][0], newCAHash, fmt.Sprintf("%s certificate should be regenerated", node)) @@ -183,6 +197,8 @@ func TestCreateCertificates_ValidConfigWithCustomNames_ShouldCreateNamedCertific assert.FileExists(t, filepath.Join(tempCertsDir, "custom_root", "custom_root.key"), "Root CA key should exist") assert.FileExists(t, filepath.Join(tempCertsDir, "custom_intermediate", "custom_intermediate.crt"), "Intermediate certificate should exist") assert.FileExists(t, filepath.Join(tempCertsDir, "custom_intermediate", "custom_intermediate.key"), "Intermediate certificate key should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "renamed.crt"), "User admin certificate should exist") + assert.FileExists(t, filepath.Join(tempCertsDir, "user-admin", "renamed.key"), "Intermediate certificate key should exist") nodes := []string{"custom_node1", "custom_node2", "custom_node3"} for _, node := range nodes { @@ -250,7 +266,12 @@ var validCertificatesYaml = `certificates: ca-certificate: "./intermediate_ca/ca.crt" ca-key: "./intermediate_ca/ca.key" ip-addresses: "127.0.0.3,172.20.240.3" - dns-names: "localhost,eventstore-node2.localhost.com"` + dns-names: "localhost,eventstore-node2.localhost.com" + user-certs: + - out: "./user-admin" + username: "admin" + ca-certificate: "./root_ca/ca.crt" + ca-key: "./root_ca/ca.key"` // Invalid path defined at ca-certificate in the config var certificatesYamlWithInvalidPath = `certificates: @@ -295,7 +316,13 @@ var certificatesYamlWithOverrideName = `certificates: ca-certificate: "./custom_intermediate/custom_intermediate.crt" ca-key: "./custom_intermediate/custom_intermediate.key" ip-addresses: "127.0.0.3,172.20.240.3" - dns-names: "localhost,eventstore-node2.localhost.com"` + dns-names: "localhost,eventstore-node2.localhost.com" + user-certs: + - out: "./user-admin" + username: "admin" + name: "renamed" + ca-certificate: "./custom_root/custom_root.crt" + ca-key: "./custom_root/custom_root.key"` func setupCertificateTestEnvironment(t *testing.T) (cleanupFunc func(), tempCertsDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createCerts *CreateCertificates) { tempCertsDir, err := os.MkdirTemp(os.TempDir(), "certs-*") diff --git a/certificates/create_user.go b/certificates/create_user.go index cd6a16e..d4063bd 100644 --- a/certificates/create_user.go +++ b/certificates/create_user.go @@ -10,6 +10,7 @@ import ( "errors" "flag" "fmt" + "path/filepath" "time" multierror "github.com/hashicorp/go-multierror" @@ -28,6 +29,7 @@ type CreateUserArguments struct { CAKeyPath string `yaml:"ca-key"` Days int `yaml:"days"` OutputDir string `yaml:"out"` + Name string `yaml:"name"` Force bool `yaml:"force"` } @@ -41,6 +43,7 @@ func NewCreateUser(ui cli.Ui) *CreateUser { c.Flags.StringVar(&c.Config.CAKeyPath, "ca-key", "./ca/ca.key", CaKeyFlagUsage) c.Flags.IntVar(&c.Config.Days, "days", 0, DayFlagUsage) c.Flags.StringVar(&c.Config.OutputDir, "out", "", OutDirFlagUsage) + c.Flags.StringVar(&c.Config.Name, "name", "", NameFlagUsage) c.Flags.BoolVar(&c.Config.Force, "force", false, ForceFlagUsage) return c @@ -89,10 +92,13 @@ func (c *CreateUser) Run(args []string) int { } outputDir := c.Config.OutputDir - outputBaseFileName := "user-" + c.Config.Username + outputBaseFileName := c.Config.Name + if outputBaseFileName == "" { + outputBaseFileName = "user-" + c.Config.Username + } - if len(outputDir) == 0 { - outputDir = outputBaseFileName + if outputDir == "" { + outputDir = filepath.Dir(outputBaseFileName) } certErr := checkCertificatesLocationWithForce(outputDir, outputBaseFileName, c.Config.Force) diff --git a/certificates/create_user_test.go b/certificates/create_user_test.go index f635bfe..0cf1160 100644 --- a/certificates/create_user_test.go +++ b/certificates/create_user_test.go @@ -17,6 +17,7 @@ func TestCreateUserCertificate(t *testing.T) { t.Run("TestCreateUserCertificate_WithAllRequiredParams_ShouldSucceed", TestCreateUserCertificate_WithAllRequiredParams_ShouldSucceed) t.Run("TestCreateUserCertificate_WithNegativeDays_ShouldFail", TestCreateUserCertificate_WithNegativeDays_ShouldFail) t.Run("TestCreateUserCertificate_WithForceFlag_ShouldRegenerate", TestCreateUserCertificate_WithForceFlag_ShouldRegenerate) + t.Run("TestCreateUserCertificate_WithNameFlag_ShouldSucceed", TestCreateUserCertificate_WithNameFlag_ShouldSucceed) } func TestCreateUserCertificate_WithoutParams_ShouldFail(t *testing.T) { @@ -74,6 +75,7 @@ func TestCreateUserCertificate_WithAllRequiredParams_ShouldSucceed(t *testing.T) _, err = cert.Verify(x509.VerifyOptions{Roots: roots, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}}) assert.NoError(t, err, "User certificate should be signed by the provided root CA") + assert.Equal(t, username, cert.Subject.CommonName, "The common name of the certificate should be the same as the provided username") } func TestCreateUserCertificate_WithNegativeDays_ShouldFail(t *testing.T) { @@ -126,6 +128,35 @@ func TestCreateUserCertificate_WithForceFlag_ShouldRegenerate(t *testing.T) { assert.NotEqual(t, originalUserKey, newUserKey, "The User key should be different") } +func TestCreateUserCertificate_WithNameFlag_ShouldSucceed(t *testing.T) { + t.Parallel() + + cleanup, tempUserDir, tempCaDir, _, _, createUser := setupCreateUserTestEnvironment(t) + defer cleanup() + + username := "ouro" + name := "testing" + args := []string{ + "-username", username, + "-name", name, + "-ca-certificate", filepath.Join(tempCaDir, "ca.crt"), + "-ca-key", filepath.Join(tempCaDir, "ca.key"), + "-out", tempUserDir, + } + + result := createUser.Run(args) + + assert.Equal(t, 0, result, "The 'create-user' create the certificates with the provided name") + + assert.FileExists(t, filepath.Join(tempUserDir, name+".crt"), "User certificate should exist") + assert.FileExists(t, filepath.Join(tempUserDir, name+".key"), "User key should exist") + + cert, err := readCertificateFromFile(filepath.Join(tempUserDir, name+".crt")) + assert.NoError(t, err, "Failed to read and parse certificate file") + + assert.Equal(t, username, cert.Subject.CommonName, "The common name of the certificate should be the same as the provided username") +} + func setupCreateUserTestEnvironment(t *testing.T) (cleanupFunc func(), tempUserDir string, tempCaDir string, outputBuffer *bytes.Buffer, errorBuffer *bytes.Buffer, createUser *CreateUser) { var err error diff --git a/references/certs.yml b/references/certs.yml index 67732f7..8b19725 100644 --- a/references/certs.yml +++ b/references/certs.yml @@ -21,3 +21,8 @@ certificates: ca-key: "./intermediate_ca/ca.key" ip-addresses: "127.0.0.3,172.20.240.3" dns-names: "localhost,eventstore-node2.localhost.com" + user-certs: + - out: "./user-admin" + username: "admin" + ca-certificate: "./root_ca/ca.crt" + ca-key: "./root_ca/ca.key" diff --git a/references/named_certs.yml b/references/named_certs.yml index 97dae19..af06aa8 100644 --- a/references/named_certs.yml +++ b/references/named_certs.yml @@ -26,3 +26,9 @@ certificates: ca-key: "./intermediate_ca/intermediate.key" ip-addresses: "127.0.0.3,172.20.240.3" dns-names: "localhost,eventstore-node2.localhost.com" + user-certs: + - out: "./user-admin" + username: "admin" + name: "admin" + ca-certificate: "./root_ca/root.crt" + ca-key: "./root_ca/root.key" \ No newline at end of file