diff --git a/broker/session.py b/broker/session.py index 1e55d104..4fb72f1d 100644 --- a/broker/session.py +++ b/broker/session.py @@ -81,6 +81,8 @@ def sftp_read(self, source, destination=None): """read a remote file into a local destination""" if not destination: destination = source + elif destination.endswith("/"): + destination = destination + Path(source).name # create the destination path if it doesn't exist destination = Path(destination) destination.parent.mkdir(parents=True, exist_ok=True) @@ -94,19 +96,25 @@ def sftp_read(self, source, destination=None): for size, data in remote: local.write(data) - def sftp_write(self, source, destination=None): + def sftp_write(self, source, destination=None, ensure_dir=True): """sftp write a local file to a remote destination""" if not destination: destination = source + elif destination.endswith("/"): + destination = destination + Path(source).name data = Path(source).read_bytes() + if ensure_dir: + self.run(f"mkdir -p {Path(destination).absolute().parent}") sftp = self.session.sftp_init() with sftp.open(destination, FILE_FLAGS, SFTP_MODE) as remote: remote.write(data) - def remote_copy(self, source, dest_host): + def remote_copy(self, source, dest_host, ensure_dir=True): """Copy a file from this host to another""" sftp_down = self.session.sftp_init() sftp_up = dest_host.session.session.sftp_init() + if ensure_dir: + dest_host.run(f"mkdir -p {Path(source).absolute().parent}") with sftp_down.open( source, ssh2_sftp.LIBSSH2_FXF_READ, ssh2_sftp.LIBSSH2_SFTP_S_IRUSR ) as download: @@ -114,10 +122,12 @@ def remote_copy(self, source, dest_host): for size, data in download: upload.write(data) - def scp_write(self, source, destination=None): + def scp_write(self, source, destination=None, ensure_dir=True): """scp write a local file to a remote destination""" if not destination: destination = source + elif destination.endswith("/"): + destination = destination + Path(source).name fileinfo = os.stat(source) chan = self.session.scp_send64( destination, @@ -126,6 +136,8 @@ def scp_write(self, source, destination=None): fileinfo.st_mtime, fileinfo.st_atime, ) + if ensure_dir: + self.run(f"mkdir -p {Path(destination).absolute().parent}") with open(source, "rb") as local: for data in local: chan.write(data) @@ -221,7 +233,7 @@ def disconnect(self): """Needed for simple compatability with Session""" pass - def sftp_write(self, source, destination=None): + def sftp_write(self, source, destination=None, ensure_dir=True): """Add one of more files to the container""" # ensure source is a list of Path objects if not isinstance(source, list): @@ -232,15 +244,17 @@ def sftp_write(self, source, destination=None): for src in source: if not Path(src).exists(): raise FileNotFoundError(src) - destination = Path(destination) or source[0].parent + destination = destination or f"{source[0].parent}/" # Files need to be added to a tarfile with helpers.temporary_tar(source) as tar: logger.debug( f"{self._cont_inst.hostname} adding file(s) {source} to {destination}" ) - # if the destination is a file, create the parent path - if destination.is_file(): - self.execute(f"mkdir -p {destination.parent}") + if ensure_dir: + if destination.endswith("/"): + self.run(f"mkdir -m 666 -p {destination}") + else: + self.run(f"mkdir -m 666 -p {Path(destination).parent}") self._cont_inst._cont_inst.put_archive(str(destination), tar.read_bytes()) def sftp_read(self, source, destination=None): diff --git a/tests/functional/test_containers.py b/tests/functional/test_containers.py index 2f853665..5b6c5fb0 100644 --- a/tests/functional/test_containers.py +++ b/tests/functional/test_containers.py @@ -84,6 +84,6 @@ def test_container_e2e_mp(): res = c_host.execute("hostname") assert res.stdout.strip() == c_host.hostname # Test that a file can be uploaded to the container - c_host.session.sftp_write("broker_settings.yaml", "/root") - res = c_host.execute("ls") + c_host.session.sftp_write("broker_settings.yaml", "/tmp/fake/") + res = c_host.execute("ls /tmp/fake") assert "broker_settings.yaml" in res.stdout diff --git a/tests/functional/test_satlab.py b/tests/functional/test_satlab.py index df4271cb..e837dbc4 100644 --- a/tests/functional/test_satlab.py +++ b/tests/functional/test_satlab.py @@ -68,3 +68,6 @@ def test_tower_host(): with Broker(workflow="deploy-base-rhel") as r_host: res = r_host.execute("hostname") assert res.stdout.strip() == r_host.hostname + r_host.session.sftp_write("broker_settings.yaml", "/tmp/fake/") + res = r_host.execute("ls /tmp/fake") + assert "broker_settings.yaml" in res.stdout