diff --git a/zfs/zfs.go b/zfs/zfs.go
index 5daa37c0..fb401d8c 100644
--- a/zfs/zfs.go
+++ b/zfs/zfs.go
@@ -3,6 +3,7 @@ package zfs
 import (
 	"bufio"
 	"bytes"
+	"cmp"
 	"context"
 	"encoding/json"
 	"errors"
@@ -13,14 +14,13 @@ import (
 	"os"
 	"os/exec"
 	"regexp"
-	"sort"
+	"slices"
 	"strconv"
 	"strings"
 	"sync"
 
 	"github.com/prometheus/client_golang/prometheus"
 
-	"github.com/dsh2dsh/zrepl/util/circlog"
 	"github.com/dsh2dsh/zrepl/util/envconst"
 	"github.com/dsh2dsh/zrepl/util/nodefault"
 	zfsprop "github.com/dsh2dsh/zrepl/zfs/property"
@@ -294,16 +294,15 @@ func absVersion(fs string, v *ZFSSendArgVersion) (full string, err error) {
 	return fmt.Sprintf("%s%s", fs, v.RelName), nil
 }
 
-func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) {
-	if capacity < 0 {
-		panic(fmt.Sprintf("capacity must be non-negative, got %v", capacity))
-	}
-	stdoutReader, stdoutWriter, err := os.Pipe()
-	if err != nil {
-		return nil, nil, err
+func NewSendStream(cmd *zfscmd.Cmd, r io.ReadCloser, stderrBuf *bytes.Buffer,
+	cancel context.CancelFunc,
+) *SendStream {
+	return &SendStream{
+		cmd:          cmd,
+		stdoutReader: r,
+		stderrBuf:    stderrBuf,
+		cancel:       cancel,
 	}
-	trySetPipeCapacity(stdoutWriter, capacity)
-	return stdoutReader, stdoutWriter, nil
 }
 
 type sendStreamState int
@@ -316,47 +315,47 @@ const (
 
 type SendStream struct {
 	cmd          *zfscmd.Cmd
 	stdoutReader io.ReadCloser // not *os.File for mocking during platformtest
-	stderrBuf    *circlog.CircularLog
+	stderrBuf    *bytes.Buffer
+	cancel       context.CancelFunc
 
-	mtx     sync.Mutex
-	state   sendStreamState
-	exitErr *ZFSError
+	mtx      sync.Mutex
+	state    sendStreamState
+	exitErr  *ZFSError
+	testMode bool
 }
 
-func (s *SendStream) Read(p []byte) (n int, _ error) {
+func (s *SendStream) Read(p []byte) (int, error) {
 	s.mtx.Lock()
 	defer s.mtx.Unlock()
 
-	switch s.state {
-	case sendStreamClosed:
+	if s.state == sendStreamClosed {
 		return 0, os.ErrClosed
+	} else if s.state != sendStreamOpen {
+		panic("unreachable")
+	}
 
-	case sendStreamOpen:
-		n, readErr := s.stdoutReader.Read(p)
-		if readErr != nil {
-			debug("sendStream: read: readErr=%T %s", readErr, readErr)
-			if readErr == io.EOF {
-				// io.EOF must be bubbled up as is so that consumers can handle it properly.
-				return n, readErr
-			}
-			// Assume that the error is not retryable.
-			// Try to kill now so that we can return a nice *ZFSError with captured stderr.
-			// If the kill doesn't work, it doesn't matter because the caller must by contract call Close() anyways.
-			killErr := s.killAndWait()
-			debug("sendStream: read: killErr=%T %s", killErr, killErr)
-			if killErr == nil {
-				s.state = sendStreamClosed
-				return n, s.exitErr // return the nice error
-			} else {
-				// we remain open so that we retry
-				return n, readErr // return the normal error
-			}
-		}
-		return n, readErr
+	n, err := s.stdoutReader.Read(p)
+	if err == nil {
+		return n, err
+	}
 
-	default:
-		panic("unreachable")
+	debug("sendStream: read: err=%T %s", err, err)
+	if err == io.EOF {
+		// io.EOF must be bubbled up as is so that consumers can handle it
+		// properly.
+		return n, err
+	} else if !s.testMode {
+		s.cmd.Log().WithError(err).Info("failed read from pipe")
+	}
+
+	// Assume that the error is not retryable. Try to kill now so that we can
+	// return a nice *ZFSError with captured stderr. If the kill doesn't work,
+	// it doesn't matter because the caller must by contract call Close()
+	// anyways.
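+	//
+	// closeWait closes the pipe, reaps the child and cancels its context. If
+	// the child exited with an error, it returns the *ZFSError built from the
+	// captured stderr; on a clean exit it returns nil and we fall through to
+	// return the original read error.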
+	if err := s.closeWait(); err != nil {
+		return n, err
+	}
+	return n, err
 }
 
 func (s *SendStream) Close() error {
@@ -364,112 +363,37 @@ func (s *SendStream) Close() error {
 	s.mtx.Lock()
 	defer s.mtx.Unlock()
 
-	switch s.state {
-	case sendStreamOpen:
-		err := s.killAndWait()
-		if err != nil {
-			return err
-		} else {
-			s.state = sendStreamClosed
-			return nil
-		}
-	case sendStreamClosed:
+	if s.state == sendStreamClosed {
 		return os.ErrClosed
-	default:
+	} else if s.state != sendStreamOpen {
 		panic("unreachable")
 	}
+	return s.closeWait()
 }
 
-// returns nil iff the child process is gone (has been successfully waited upon)
-// in that case, s.exitErr is set
-func (s *SendStream) killAndWait() error {
-	debug("sendStream: killAndWait enter")
-	defer debug("sendStream: killAndWait leave")
-
-	// ensure this function is only called once
-	if s.state != sendStreamOpen {
-		panic(s.state)
-	}
-
-	// send SIGKILL
-	// In an earlier version, we used the starting context.Context's cancel function
-	// for this, but in Go > 1.19, doing so will cause .Wait() to return the
-	// context cancel error instead of the *exec.ExitError.
-	err := s.cmd.Process().Kill()
-	if err != nil {
-		if err == os.ErrProcessDone {
-			// This can happen if
-			// (1) the process has already been .Wait()ed, or
-			// (2) some other goroutine cancels the context, likely further up
-			//     the context tree.
-			// Case (1) can't happen to us because we only call
-			// this function in sendStreamOpen state.
-			// In Case (2), it's still our job to .Wait(), so, fallthrough.
-		} else {
-			return err
-		}
-	}
-
-	// Close our read-end of the pipe.
-	//
-	// We must do this before .Wait() because in some (not all) versions/build configs of ZFS,
-	// `zfs send` uses a separate kernel thread (taskq) to write the send stream (function `dump_bytes`).
-	// The `zfs send` thread then waits uinterruptably for the taskq thread to finish the write.
-	// And signalling the `zfs send` thread doesn't propagate to the taskq thread.
-	// So we end up in a state where we .Wait() forever.
-	// (See https://github.com/openzfs/zfs/issues/12500 and
-	// https://github.com/zrepl/zrepl/issues/495#issuecomment-902530043)
-	//
-	// By closing our read end of the pipe before .Wait(), we unblock the taskq thread if there is any.
-	// If there is no separate taskq thread, the SIGKILL to `zfs end` would suffice and be most precise,
-	// but due to the circumstances above, there is no other portable & robust way.
-	//
-	// However, the fallout from closing the pipe is that (in non-taskq builds) `zfs sends` will get a SIGPIPE.
-	// And on Linux, that SIGPIPE appears to win over the previously issued SIGKILL.
-	// And thus, on Linux, the `zfs send` will be killed by the default SIGPIPE handler.
-	// We can observe this in the WaitStatus below.
-	// This behavior is slightly annoying because the *exec.ExitError's message ("signal: broken pipe")
-	// isn't as clear as ("signal: killed").
-	// However, it seems like we just have to live with that. (covered by platformtest)
-	var closePipeErr error
-	if s.stdoutReader != nil {
-		closePipeErr = s.stdoutReader.Close()
-		if closePipeErr == nil {
-			// avoid double-closes in case waiting below doesn't work
-			// and someone attempts Close again
-			s.stdoutReader = nil
-		} else {
-			return closePipeErr
-		}
-	}
+func (s *SendStream) closeWait() error {
+	defer s.cancel()
+	s.stdoutReader.Close()
+	s.state = sendStreamClosed
 
-	waitErr := s.cmd.Wait()
-	// distinguish between ExitError (which is actually a non-problem for us)
-	// vs failed wait syscall (for which we give upper layers the chance to retyr)
-	var exitErr *exec.ExitError
-	if waitErr != nil {
-		if ee, ok := waitErr.(*exec.ExitError); ok {
-			exitErr = ee
-		} else {
-			return waitErr
-		}
+	if err := s.wait(); err != nil {
+		debug("sendStream: wait: err=%T %s", err, err)
+		s.exitErr = NewZfsError(err, s.stderrBuf.Bytes())
+		return s.exitErr
 	}
+	return nil
+}
 
-	// invariant: at this point, the child is gone and we cleaned up everything related to the SendStream
-
-	if exitErr != nil {
-		// zfs send exited with an error or was killed by a signal.
-		s.exitErr = NewZfsError(exitErr, s.stderrBuf.Bytes())
-	} else {
-		// zfs send exited successfully (we know that since waitErr was either nil or wasn't an *exec.ExitError)
-		s.exitErr = nil
+func (s *SendStream) wait() error {
+	if s.testMode {
+		return s.cmd.WaitPipe()
 	}
-
-	return nil
+	return s.cmd.Wait()
 }
 
-func (s *SendStream) TestOnly_ReplaceStdoutReader(f io.ReadCloser) (prev io.ReadCloser) {
-	prev = s.stdoutReader
+func (s *SendStream) TestOnly_ReplaceStdoutReader(f io.ReadCloser,
) io.ReadCloser {
+	prev := s.stdoutReader
 	s.stdoutReader = f
 	return prev
 }
@@ -887,9 +811,8 @@ func (a ZFSSendArgsUnvalidated) validateEncryptionFlagsCorrespondToResumeToken(c
 	}
 }
 
-var zfsSendStderrCaptureMaxSize = envconst.Int("ZREPL_ZFS_SEND_STDERR_MAX_CAPTURE_SIZE", 1<<15)
-
-var ErrEncryptedSendNotSupported = errors.New("raw sends which are required for encrypted zfs send are not supported")
+var ErrEncryptedSendNotSupported = errors.New(
+	"raw sends which are required for encrypted zfs send are not supported")
 
 // if token != "", then send -t token is used
 // otherwise send [-i from] to is used
@@ -900,12 +823,9 @@ var ErrEncryptedSendNotSupported = errors.New("raw sends which are required for
 func ZFSSend(
 	ctx context.Context, sendArgs ZFSSendArgsValidated, pipeCmds ...[]string,
 ) (*SendStream, error) {
-	args := make([]string, 0)
-	args = append(args, "send")
-
-	// pre-validation of sendArgs for plain ErrEncryptedSendNotSupported error
-	// we tie BackupProperties (send -b) and SendRaw (-w, same as with Encrypted) to this
-	// since these were released together.
+	// Pre-validation of sendArgs for plain ErrEncryptedSendNotSupported error. We
+	// tie BackupProperties (send -b) and SendRaw (-w, same as with Encrypted) to
+	// this since these were released together.
 	if sendArgs.Encrypted.B {
 		if encryptionSupported, err := EncryptionCLISupported(ctx); err != nil {
 			return nil, fmt.Errorf("cannot determine CLI native encryption support: %w", err)
@@ -918,44 +838,22 @@ func ZFSSend(
 	if err != nil {
 		return nil, err
 	}
+	args := make([]string, 0, len(sargs)+1)
+	args = append(args, "send")
 	args = append(args, sargs...)
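+	// The cancel func of the context created below is handed to the
+	// SendStream, which invokes it from closeWait() once the zfs send child
+	// has been waited for.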
 	ctx, cancel := context.WithCancel(ctx)
-
-	// setup stdout with an os.Pipe to control pipe buffer size
-	stdoutReader, stdoutWriter, err := pipeWithCapacityHint(getPipeCapacityHint("ZFS_SEND_PIPE_CAPACITY_HINT"))
-	if err != nil {
-		cancel()
-		return nil, err
-	}
-	stderrBuf := circlog.MustNewCircularLog(zfsSendStderrCaptureMaxSize)
-
 	cmd := zfscmd.New(ctx).WithPipeLen(len(pipeCmds)).WithCommand(ZfsBin, args)
-	pipeReader, err := cmd.Pipe(stdoutReader, stdoutWriter, stderrBuf, pipeCmds)
+	stderrBuf := new(bytes.Buffer)
+	pipeReader, err := cmd.PipeTo(pipeCmds, nil, stderrBuf)
 	if err != nil {
 		cancel()
-		stdoutWriter.Close()
-		stdoutReader.Close()
 		return nil, err
-	}
-
-	if err := cmd.Start(); err != nil {
+	} else if err := cmd.Start(); err != nil {
 		cancel()
-		stdoutWriter.Close()
-		stdoutReader.Close()
 		return nil, fmt.Errorf("cannot start zfs send command: %w", err)
 	}
 
-	// close our writing-end of the pipe so that we don't wait for ourselves when reading from the reading end
-	stdoutWriter.Close()
-
-	stream := &SendStream{
-		cmd:          cmd,
-		stdoutReader: pipeReader,
-		stderrBuf:    stderrBuf,
-	}
-	_ = cancel // the SendStream.killAndWait() will kill the process
-
-	return stream, nil
+	return NewSendStream(cmd, pipeReader, stderrBuf, cancel), nil
 }
 
 type DrySendType string
@@ -1161,16 +1059,13 @@ func (opts RecvOptions) buildRecvFlags() []string {
 	return args
 }
 
-const RecvStderrBufSiz = 1 << 15
-
 func ZFSRecv(
 	ctx context.Context, fs string, v *ZFSSendArgVersion, stream io.ReadCloser,
 	opts RecvOptions, pipeCmds ...[]string,
-) (err error) {
+) error {
 	if err := v.ValidateInMemory(fs); err != nil {
 		return fmt.Errorf("invalid version: %w", err)
-	}
-	if !v.IsSnapshot() {
+	} else if !v.IsSnapshot() {
 		return errors.New("must receive into a snapshot")
 	}
 
@@ -1180,34 +1075,11 @@ func ZFSRecv(
 	}
 
 	if opts.RollbackAndForceRecv {
-		// destroy all snapshots before `recv -F` because `recv -F`
-		// does not perform a rollback unless `send -R` was used (which we assume hasn't been the case)
-		snaps, err := ZFSListFilesystemVersions(ctx, fsdp, ListFilesystemVersionsOptions{
-			Types: Snapshots,
-		})
-		if _, ok := err.(*DatasetDoesNotExist); ok {
-			snaps = []FilesystemVersion{}
-		} else if err != nil {
-			return fmt.Errorf("cannot list versions for rollback for forced receive: %s", err)
-		}
-		sort.Slice(snaps, func(i, j int) bool {
-			return snaps[i].CreateTXG < snaps[j].CreateTXG
-		})
-		// bookmarks are rolled back automatically
-		if len(snaps) > 0 {
-			// use rollback to efficiently destroy all but the earliest snapshot
-			// then destroy that earliest snapshot
-			// afterwards, `recv -F` will work
-			rollbackTarget := snaps[0]
-			rollbackTargetAbs := rollbackTarget.ToAbsPath(fsdp)
-			debug("recv: rollback to %q", rollbackTargetAbs)
-			if err := ZFSRollback(ctx, fsdp, rollbackTarget, "-r"); err != nil {
-				return fmt.Errorf("cannot rollback %s to %s for forced receive: %s", fsdp.ToString(), rollbackTarget, err)
-			}
-			debug("recv: destroy %q", rollbackTargetAbs)
-			if err := ZFSDestroy(ctx, rollbackTargetAbs); err != nil {
-				return fmt.Errorf("cannot destroy %s for forced receive: %s", rollbackTargetAbs, err)
-			}
+		// Destroy all snapshots before `recv -F`, because `recv -F` does not
+		// perform a rollback unless `send -R` was used (which we assume hasn't
+		// been the case).
+		if err := zfsRollbackForceRecv(ctx, fsdp); err != nil {
+			return err
 		}
 	}
 
@@ -1217,78 +1089,91 @@ func ZFSRecv(
 		}
 	}
 
-	args := []string{"recv"}
-	args = append(args, opts.buildRecvFlags()...)
+	recvFlags := opts.buildRecvFlags()
+	args := make([]string, 0, len(recvFlags)+2)
+	args = append(args, "recv")
+	args = append(args, recvFlags...)
 	args = append(args, fs)
 
-	ctx, cancelCmd := context.WithCancel(ctx)
-	defer cancelCmd()
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
 
 	cmd := zfscmd.New(ctx).WithPipeLen(len(pipeCmds)).WithCommand(ZfsBin, args)
 
-	// TODO report bug upstream
-	// Setup an unused stdout buffer.
-	// Otherwise, ZoL v0.6.5.9-1 3.16.0-4-amd64 writes the following error to stderr and exits with code 1
-	// cannot receive new filesystem stream: invalid backup stream
-	stdout := bytes.NewBuffer(make([]byte, 0, 1024))
-
-	stderr := bytes.NewBuffer(make([]byte, 0, RecvStderrBufSiz))
-
-	stdin, stdinWriter, err := pipeWithCapacityHint(getPipeCapacityHint("ZFS_RECV_PIPE_CAPACITY_HINT"))
-	if err != nil {
-		return err
-	}
+	// TODO: report bug upstream. Set up an unused stdout buffer; otherwise,
+	// ZoL v0.6.5.9-1 3.16.0-4-amd64 writes the following error to stderr and
+	// exits with code 1:
+	//
+	//	cannot receive new filesystem stream: invalid backup stream
+	var stdout, stderr bytes.Buffer
 
-	if err := cmd.PipeFrom(stdin, stdout, stderr, pipeCmds); err != nil {
-		stdinWriter.Close()
-		stdin.Close()
+	if err := cmd.PipeFrom(pipeCmds, stream, &stdout, &stderr); err != nil {
 		return err
-	}
-
-	if err = cmd.Start(); err != nil {
-		stdinWriter.Close()
-		stdin.Close()
+	} else if err = cmd.Start(); err != nil {
 		return err
 	}
-	stdin.Close()
-	defer stdinWriter.Close()
 
 	pid := cmd.Process().Pid
 	debug := func(format string, args ...interface{}) {
 		debug("recv: pid=%v: %s", pid, fmt.Sprintf(format, args...))
 	}
 
-	debug("started")
-
-	_, copierErr := io.Copy(stdinWriter, stream)
-	debug("copierErr: %T %s", copierErr, copierErr)
-	stdinWriter.Close()
+	if err := cmd.Wait(); err != nil {
+		err = parseZfsRecvErr(ctx, err, stderr.Bytes())
+		debug("wait err: %T %s", err, err)
+		// The parsed error is almost always the more interesting info.
+		// NOTE: do not wrap!
+		return err
+	}
+	return nil
+}
 
-	if copierErr != nil {
-		debug("killing zfs recv command after copierErr")
-		cancelCmd()
+func zfsRollbackForceRecv(ctx context.Context, fsdp *DatasetPath) error {
+	snaps, err := ZFSListFilesystemVersions(ctx, fsdp,
+		ListFilesystemVersionsOptions{Types: Snapshots})
+	if _, ok := err.(*DatasetDoesNotExist); ok {
+		snaps = []FilesystemVersion{}
+	} else if err != nil {
+		return fmt.Errorf(
+			"cannot list versions for rollback for forced receive: %s", err)
+	} else if len(snaps) == 0 {
+		return nil
 	}
 
-	var waitErr error
-	if err = cmd.Wait(); err != nil {
-		if rtErr := tryRecvErrorWithResumeToken(ctx, stderr.String()); rtErr != nil {
-			waitErr = rtErr
-		} else if owErr := tryRecvDestroyOrOverwriteEncryptedErr(stderr.Bytes()); owErr != nil {
-			waitErr = owErr
-		} else if readErr := tryRecvCannotReadFromStreamErr(stderr.Bytes()); readErr != nil {
-			waitErr = readErr
-		} else {
-			waitErr = NewZfsError(err, stderr.Bytes())
-		}
+	slices.SortFunc(snaps, func(a, b FilesystemVersion) int {
+		return cmp.Compare(a.CreateTXG, b.CreateTXG)
+	})
+
+	// Bookmarks are rolled back automatically.
+	//
+	// Use rollback to efficiently destroy all but the earliest snapshot, then
+	// destroy that earliest snapshot; afterwards, `recv -F` will work.
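+	//
+	// snaps is now sorted by CreateTXG in ascending order, so snaps[0] is the
+	// earliest snapshot and the target for `zfs rollback -r`.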
+	rollbackTarget := snaps[0]
+	rollbackTargetAbs := rollbackTarget.ToAbsPath(fsdp)
+	debug("recv: rollback to %q", rollbackTargetAbs)
+	if err := ZFSRollback(ctx, fsdp, rollbackTarget, "-r"); err != nil {
+		return fmt.Errorf(
+			"cannot rollback %s to %s for forced receive: %s",
+			fsdp.ToString(), rollbackTarget, err)
 	}
-	debug("waitErr: %T %s", waitErr, waitErr)
 
-	if copierErr == nil && waitErr == nil {
-		return nil
-	} else if _, isReadErr := waitErr.(*RecvCannotReadFromStreamErr); isReadErr {
-		return copierErr // likely network error reading from stream
-	} else {
-		return waitErr // almost always more interesting info. NOTE: do not wrap!
+	debug("recv: destroy %q", rollbackTargetAbs)
+	if err := ZFSDestroy(ctx, rollbackTargetAbs); err != nil {
+		return fmt.Errorf(
+			"cannot destroy %s for forced receive: %s",
+			rollbackTargetAbs, err)
+	}
+	return nil
+}
+
+func parseZfsRecvErr(ctx context.Context, err error, stderr []byte) error {
+	if err := tryRecvErrorWithResumeToken(ctx, string(stderr)); err != nil {
+		return err
+	} else if err := tryRecvDestroyOrOverwriteEncryptedErr(stderr); err != nil {
+		return err
+	} else if err := tryRecvCannotReadFromStreamErr(stderr); err != nil {
+		return err
 	}
+	return NewZfsError(err, stderr)
 }
 
 type RecvFailedWithResumeTokenErr struct {
diff --git a/zfs/zfs_pipe.go b/zfs/zfs_pipe.go
deleted file mode 100644
index 9f903402..00000000
--- a/zfs/zfs_pipe.go
+++ /dev/null
@@ -1,23 +0,0 @@
-//go:build !linux
-// +build !linux
-
-package zfs
-
-import (
-	"os"
-	"sync"
-)
-
-func getPipeCapacityHint(envvar string) int {
-	return 0 // not supported
-}
-
-var zfsPipeCapacityNotSupported sync.Once
-
-func trySetPipeCapacity(p *os.File, capacity int) {
-	if debugEnabled && capacity != 0 {
-		zfsPipeCapacityNotSupported.Do(func() {
-			debug("trySetPipeCapacity error: OS does not support setting pipe capacity")
-		})
-	}
-}
diff --git a/zfs/zfs_pipe_linux.go b/zfs/zfs_pipe_linux.go
deleted file mode 100644
index 47d4e9f1..00000000
--- a/zfs/zfs_pipe_linux.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package zfs
-
-import (
-	"errors"
-	"fmt"
-	"os"
-	"strconv"
-	"strings"
-
-	"golang.org/x/sys/unix"
-
-	"github.com/dsh2dsh/zrepl/util/envconst"
-)
-
-func getPipeCapacityHint(envvar string) int {
-	var capacity int64 = 1 << 25
-
-	// Work around a race condition in Linux >= 5.8 related to pipe resizing.
-	// https://github.com/zrepl/zrepl/issues/424#issuecomment-800370928
-	// https://bugzilla.kernel.org/show_bug.cgi?id=212295
-	if _, err := os.Stat("/proc/sys/fs/pipe-max-size"); err == nil {
-		if dat, err := os.ReadFile("/proc/sys/fs/pipe-max-size"); err == nil {
-			if capacity, err = strconv.ParseInt(strings.TrimSpace(string(dat)), 10, 64); err != nil {
-				capacity = 1 << 25
-			}
-		}
-	}
-
-	return int(envconst.Int64(envvar, capacity))
-}
-
-func trySetPipeCapacity(p *os.File, capacity int) {
-	res, err := unix.FcntlInt(p.Fd(), unix.F_SETPIPE_SZ, capacity)
-	if err != nil {
-		err = fmt.Errorf("cannot set pipe capacity to %v", capacity)
-	} else if res == -1 {
-		err = errors.New("cannot set pipe capacity: fcntl returned -1")
-	}
-	if debugEnabled && err != nil {
-		debug("trySetPipeCapacity error: %s\n", err)
-	}
-}
diff --git a/zfs/zfs_test.go b/zfs/zfs_test.go
index c500bff1..e6a510b7 100644
--- a/zfs/zfs_test.go
+++ b/zfs/zfs_test.go
@@ -1,7 +1,9 @@
 package zfs
 
 import (
+	"bytes"
 	"context"
+	"io"
 	"strings"
 	"testing"
 
@@ -10,6 +12,7 @@ import (
 
 	"github.com/dsh2dsh/zrepl/util/nodefault"
 	zfsprop "github.com/dsh2dsh/zrepl/zfs/property"
+	"github.com/dsh2dsh/zrepl/zfs/zfscmd"
 )
 
 // FIXME make this a platformtest
@@ -401,7 +404,6 @@ size	1500
 
 func TestTryRecvDestroyOrOverwriteEncryptedErr(t *testing.T) {
 	msg := "cannot receive new filesystem stream: zfs receive -F cannot be used to destroy an encrypted filesystem or overwrite an unencrypted one with an encrypted one\n"
-	assert.GreaterOrEqual(t, RecvStderrBufSiz, len(msg))
 
 	err := tryRecvDestroyOrOverwriteEncryptedErr([]byte(msg))
 	require.NotNil(t, err)
@@ -534,3 +536,46 @@ func TestZFSCommonRecvArgsBuild(t *testing.T) {
 		})
 	}
 }
+
+func TestSendStream_Close_afterRead(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	const foobar = "foobar"
+	cmd := zfscmd.CommandContext(ctx, "echo", "-n", foobar)
+	var stderrBuf bytes.Buffer
+	pipeReader, err := cmd.PipeTo(nil, nil, &stderrBuf)
+	require.NoError(t, err)
+	require.NoError(t, cmd.StartPipe())
+
+	stream := NewSendStream(cmd, pipeReader, &stderrBuf, cancel)
+	stream.testMode = true
+	var stdout bytes.Buffer
+	n, err := io.Copy(&stdout, stream)
+	require.NoError(t, err)
+	assert.Equal(t, int64(len(foobar)), n)
+	assert.Equal(t, foobar, stdout.String())
+
+	require.NoError(t, stream.Close())
+	assert.Empty(t, stderrBuf.Bytes())
+}
+
+func TestSendStream_Close_noRead(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	const foobar = "foobar"
+	cmd := zfscmd.CommandContext(ctx, "echo", "-n", foobar)
+	var stderrBuf bytes.Buffer
+	pipeReader, err := cmd.PipeTo(nil, nil, &stderrBuf)
+	require.NoError(t, err)
+	require.NoError(t, cmd.StartPipe())
+
+	stream := NewSendStream(cmd, pipeReader, &stderrBuf, cancel)
+	stream.testMode = true
+	var zfsError *ZFSError
+	require.ErrorAs(t, stream.Close(), &zfsError)
+	t.Log(zfsError)
+	assert.Contains(t, zfsError.Error(), "signal: broken pipe")
+	assert.Empty(t, stderrBuf.Bytes())
+}
diff --git a/zfs/zfscmd/zfscmd.go b/zfs/zfscmd/zfscmd.go
index 4fc2a301..4837172b 100644
--- a/zfs/zfscmd/zfscmd.go
+++ b/zfs/zfscmd/zfscmd.go
@@ -107,9 +107,7 @@ func (c *Cmd) String() string {
 	return s.String()
 }
 
-func (c *Cmd) log() Logger {
-	return getLogger(c.ctx)
-}
+func (c *Cmd) Log() Logger { return c.logWithCmd() }
 
 func (c *Cmd) logWithCmd() Logger {
 	if c.cmdLogger == nil {
@@ -118,6 +116,10 @@ func (c *Cmd) logWithCmd() Logger {
 	return c.cmdLogger
 }
 
+func (c *Cmd) log() Logger {
+	return getLogger(c.ctx)
+}
+
 // Start the command.
 //
 // This creates a new trace.WithTask as a child task of the ctx passed to
@@ -129,7 +131,7 @@ func (c *Cmd) logWithCmd() Logger {
 // be called repeatedly.
 func (c *Cmd) Start() error {
 	c.startPre(true)
-	err := c.startPipe()
+	err := c.StartPipe()
 	if err != nil {
 		_ = c.WaitPipe()
 	}
@@ -266,29 +268,20 @@ func (c *Cmd) TestOnly_ExecCmd() *exec.Cmd {
 	return c.cmd
 }
 
-func (c *Cmd) Pipe(
-	stdin io.ReadCloser, stdout, stderr io.Writer, cmds [][]string,
+func (c *Cmd) PipeTo(cmds [][]string, stdout io.ReadCloser, stderr io.Writer,
 ) (io.ReadCloser, error) {
-	if len(c.cmds) > 0 && c.cmds[0] == c.cmd {
-		c.SetStdio(Stdio{
-			Stdin:  nil,
-			Stdout: stdout,
-			Stderr: stderr,
-		})
-	}
-
-	for _, pipeCmd := range c.buildPipe(cmds) {
-		r, err := pipeCmd.StdoutPipe()
+	c.cmds = append(c.cmds, c.buildPipe(cmds)...)
+	for _, cmd := range c.cmds {
+		r, err := cmd.StdoutPipe()
 		if err != nil {
 			return nil, fmt.Errorf(
-				"create stdout pipe from %q: %w", pipeCmd.String(), err)
+				"create stdout pipe from %q: %w", cmd.String(), err)
 		}
-		pipeCmd.Stdin = stdin
-		pipeCmd.Stderr = stderr
-		c.cmds = append(c.cmds, pipeCmd)
-		stdin = r
+		cmd.Stderr = stderr
+		cmd.Stdin = stdout
+		stdout = r
 	}
-	return stdin, nil
+	return stdout, nil
 }
 
 func (c *Cmd) buildPipe(cmds [][]string) []*exec.Cmd {
@@ -304,25 +297,25 @@ func (c *Cmd) buildPipe(cmds [][]string) []*exec.Cmd {
 	return pipeCmds
 }
 
-func (c *Cmd) PipeFrom(stdin io.ReadCloser, stdout, stderr io.Writer,
-	cmds [][]string,
+func (c *Cmd) PipeFrom(cmds [][]string, stdin io.ReadCloser, stdout,
+	stderr io.Writer,
 ) error {
 	c.cmds = c.cmds[:0]
-	stdin, err := c.Pipe(stdin, stdout, stderr, cmds)
+	r, err := c.PipeTo(cmds, stdin, stderr)
 	if err != nil {
 		return err
 	}
 
 	c.cmds = append(c.cmds, c.cmd)
 	c.SetStdio(Stdio{
-		Stdin:  stdin,
+		Stdin:  r,
 		Stdout: stdout,
 		Stderr: stderr,
 	})
 	return nil
 }
 
-func (c *Cmd) startPipe() error {
+func (c *Cmd) StartPipe() error {
 	for _, cmd := range c.cmds {
 		if err := cmd.Start(); err != nil {
 			return fmt.Errorf("start %q: %w", cmd.String(), err)
diff --git a/zfs/zfscmd/zfscmd_platform_test.go b/zfs/zfscmd/zfscmd_platform_test.go
index ea148eea..c13f3485 100644
--- a/zfs/zfscmd/zfscmd_platform_test.go
+++ b/zfs/zfscmd/zfscmd_platform_test.go
@@ -179,37 +179,32 @@ func TestCmd_Pipe(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			r, w, err := os.Pipe()
-			require.NoError(t, err)
-			t.Cleanup(func() { w.Close(); r.Close() })
-
 			cmd := New(ctx).WithPipeLen(len(tt.pipe)).
 				WithCommand(tt.cmd[0], tt.cmd[1:])
-			var stdout io.ReadCloser
+			var stdout bytes.Buffer
 			var stderr bytes.Buffer
+			var pipeReader io.Reader
+			var err error
+
 			if tt.pipeFrom {
-				var b bytes.Buffer
-				require.NoError(t, cmd.PipeFrom(r, &b, &stderr, tt.pipe))
-				stdout = io.NopCloser(&b)
+				require.NoError(t, cmd.PipeFrom(tt.pipe, nil, &stdout, &stderr))
+				pipeReader = &stdout
 			} else {
-				stdout, err = cmd.Pipe(r, w, &stderr, tt.pipe)
+				pipeReader, err = cmd.PipeTo(tt.pipe, nil, &stderr)
 				require.NoError(t, err)
 			}
 			assert.Equal(t, tt.wantCmdStr, cmd.String())
 
 			if tt.startErr {
-				require.Error(t, cmd.startPipe())
+				require.Error(t, cmd.StartPipe())
 				require.NoError(t, cmd.WaitPipe())
 				return
 			}
+			require.NoError(t, cmd.StartPipe())
 
-			require.NoError(t, cmd.startPipe())
-			require.NoError(t, w.Close())
-
-			b, err := io.ReadAll(stdout)
+			b, err := io.ReadAll(pipeReader)
 			require.NoError(t, err)
-			require.NoError(t, r.Close())
 
 			if tt.waitErr {
 				require.Error(t, cmd.WaitPipe())
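
A minimal usage sketch of the reworked pipe API, for reference: it assumes zfscmd.CommandContext, PipeTo, StartPipe and WaitPipe compose exactly as exercised by the tests above (the tests do not combine CommandContext with a non-empty pipe list, so treat that combination as an assumption); echo and tr stand in for `zfs send` and a user-configured pipe command.

package zfscmd_test

import (
	"bytes"
	"context"
	"fmt"
	"io"

	"github.com/dsh2dsh/zrepl/zfs/zfscmd"
)

// Chain the producer's stdout through one filter command, read the combined
// output, then reap the whole pipeline. Equivalent shell:
//
//	echo -n foobar | tr a-z A-Z
func Example_pipeTo() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	cmd := zfscmd.CommandContext(ctx, "echo", "-n", "foobar")
	var stderr bytes.Buffer
	// PipeTo wires each filter's Stdin to the previous command's stdout and
	// returns the read end of the last command in the chain.
	r, err := cmd.PipeTo([][]string{{"tr", "a-z", "A-Z"}}, nil, &stderr)
	if err != nil {
		panic(err)
	}
	if err := cmd.StartPipe(); err != nil {
		panic(err)
	}

	var out bytes.Buffer
	if _, err := io.Copy(&out, r); err != nil {
		panic(err)
	}
	// WaitPipe reaps every process in the chain, mirroring the cleanup path
	// that SendStream.closeWait() takes in test mode.
	if err := cmd.WaitPipe(); err != nil {
		panic(err)
	}
	fmt.Println(out.String())
	// Output: FOOBAR
}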