|
1 | 1 | import io
|
2 | 2 |
|
3 |
| -import zstandard |
| 3 | +from pyzstd import ZstdFile, ZstdError |
4 | 4 | import tarfile
|
5 | 5 |
|
6 | 6 |
|
7 | 7 | class ExtTarFile(tarfile.TarFile):
|
8 | 8 | """Extends TarFile to support zstandard"""
|
9 | 9 |
|
10 | 10 | @classmethod
|
11 |
| - def zstdopen(cls, name, mode="r", fileobj=None, cctx=None, dctx=None, **kwargs): # type: ignore |
12 |
| - """Open zstd compressed tar archive name for reading or writing. |
13 |
| - Appending is not allowed. |
14 |
| - """ |
15 |
| - if mode not in ("r"): |
16 |
| - raise ValueError("mode must be 'r'") |
| 11 | + def zstdopen(cls, name, mode="r", fileobj=None, **kwargs): # type: ignore |
| 12 | + """Open zstd compressed tar archive""" |
17 | 13 |
|
| 14 | + if mode not in ("r", "w", "x", "a"): |
| 15 | + raise ValueError("mode must be 'r', 'w' or 'x' or 'a'") |
| 16 | + |
| 17 | + zstfileobj = None |
18 | 18 | try:
|
19 |
| - zobj = zstandard.open(fileobj or name, mode + "b", cctx=cctx, dctx=dctx) |
20 |
| - with zobj: |
21 |
| - data = zobj.read() |
22 |
| - except (zstandard.ZstdError, EOFError) as e: |
| 19 | + zstfileobj = ZstdFile(fileobj or name, mode) |
| 20 | + if "r" in mode: |
| 21 | + zstfileobj.peek(1) # raises ZstdError if not a zstd file |
| 22 | + except ZstdError as e: |
| 23 | + if zstfileobj is not None: |
| 24 | + zstfileobj.close() |
23 | 25 | raise tarfile.ReadError("not a zstd file") from e
|
24 | 26 |
|
25 |
| - fileobj = io.BytesIO(data) |
26 |
| - t = cls.taropen(name, mode, fileobj, **kwargs) |
| 27 | + try: |
| 28 | + t = cls.taropen(name, mode, zstfileobj, **kwargs) |
| 29 | + except Exception: |
| 30 | + zstfileobj.close() |
| 31 | + raise |
| 32 | + |
27 | 33 | t._extfileobj = False
|
28 | 34 | return t
|
29 | 35 |
|
|
0 commit comments