Commit a092b63

Tuple -> tuple, List -> list [pr] (tinygrad#8936)
1 parent d5183e1 commit a092b63
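
The pattern applied across all nine files: since PEP 585 (Python 3.9+), the built-in list and tuple types accept subscripts directly, so the typing.List / typing.Tuple aliases can be dropped. A minimal before/after sketch of the change (the pairs function below is illustrative, not from the diff):

# before: typing aliases
from typing import List, Tuple
def pairs(xs: List[int]) -> List[Tuple[int, int]]:
  return [(x, x + 1) for x in xs]

# after: builtin generics (PEP 585), no typing import needed for these annotations
def pairs(xs: list[int]) -> list[tuple[int, int]]:
  return [(x, x + 1) for x in xs]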

File tree: 9 files changed, +43 -47 lines


examples/handcode_opt.py (+1 -2)

@@ -1,4 +1,3 @@
-from typing import List, Tuple
 from extra.models.resnet import ResNet50
 from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
@@ -79,7 +78,7 @@ def get_sched_bert():
 rawbufs = bufs_from_lin(Kernel(si.ast))

 # "linearize" the op into uops in different ways
-lins: List[Tuple[Kernel, str]] = []
+lins: list[tuple[Kernel, str]] = []

 # always try hand coded opt
 lin = Kernel(si.ast, opts=device.renderer)

examples/mlperf/initializers.py (+2 -2)

@@ -1,5 +1,5 @@
 import math
-from typing import Union, Tuple
+from typing import Union

 from tinygrad import Tensor, nn, dtypes
 from tinygrad.helpers import prod, argfix
@@ -56,7 +56,7 @@ def __call__(self, idx:Tensor) -> Tensor:
 return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)

 class LayerNormBert:
-def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
+def __init__(self, normalized_shape:Union[int, tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
 self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
 self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
 self.weight, self.bias = (Tensor.ones(*self.normalized_shape, dtype=dtypes.float32), Tensor.zeros(*self.normalized_shape, dtype=dtypes.float32)) if elementwise_affine else (None, None)

extra/models/llama.py (+5 -5)

@@ -1,4 +1,4 @@
-from typing import Tuple, Union, Optional, Dict, Any
+from typing import Union, Optional, Any
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv

@@ -15,7 +15,7 @@ def complex_mult(A, c, d):
 co = a*d + b*c
 return ro.cat(co, dim=-1)

-def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
+def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> tuple[Tensor, Tensor]:
 assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
 xq = xq.reshape(*xq.shape[0:-1], -1, 2)
 xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -181,7 +181,7 @@ def __call__(self, tokens:Tensor, start_pos:int, temperature:float=0.0, top_k:in

 # *** helpers ***

-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
+def convert_from_huggingface(weights:dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, permute_layers: bool = True):
 def permute(v: Tensor, n_heads: int):
 return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])

@@ -207,7 +207,7 @@ def permute(v: Tensor, n_heads: int):
 sd[keymap[k]] = v
 return sd

-def convert_from_gguf(weights:Dict[str, Tensor], model: Transformer):
+def convert_from_gguf(weights:dict[str, Tensor], model: Transformer):
 keymap = {
 "token_embd.weight": "tok_embeddings.weight",
 **{f"blk.{l}.attn_norm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
@@ -222,7 +222,7 @@ def convert_from_gguf(weights:Dict[str, Tensor], model: Transformer):
 sd["output.weight"] = weights["token_embd.weight"]
 return sd

-def fix_bf16(weights:Dict[Any, Tensor]):
+def fix_bf16(weights:dict[Any, Tensor]):
 if getenv("SUPPORT_BF16", 1):
 # TODO: without casting to float16, 70B llama OOM on tinybox.
 return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}

extra/optimization/helpers.py (-1)

@@ -1,5 +1,4 @@
 # stuff needed to unpack a kernel
-from typing import Tuple
 from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.ops import UOp, Ops, KernelInfo

test/helpers.py (+3 -3)

@@ -1,5 +1,5 @@
 import time
-from typing import Callable, Optional, Tuple
+from typing import Callable, Optional
 import numpy as np
 from tinygrad import Tensor, dtypes
 from tinygrad.ops import UOp, Ops, sint
@@ -40,14 +40,14 @@ def rand_for_dtype(dt:DType, size:int):
 return np.random.choice([True, False], size=size)
 return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))

-def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
+def ast_const(dtype:DType, val:ConstType, shape:tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[tuple[UOp]]=None) -> UOp:
 if st_src is None:
 st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
 st = unwrap(st_src[0].st)
 if all(v.mask is None for v in st.views): return UOp.const(dtype, val).replace(src=(st.to_uop(),))
 return UOp.const(dtype, val).valid(st)

-def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
+def timeit(fxn:Callable[..., T], *args, **kwargs) -> tuple[T, float]:
 st = time.perf_counter_ns()
 ret = fxn(*args, **kwargs)
 return ret, (time.perf_counter_ns()-st)*1e-6

test/test_linearizer.py (+10 -11)

@@ -1,4 +1,4 @@
-from typing import List, Tuple, Union
+from typing import Union
 import numpy as np
 import unittest
 from dataclasses import replace
@@ -10,13 +10,12 @@
 from tinygrad.device import Device, Buffer, is_dtype_supported
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-# from tinygrad.ops import Variable
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
 from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX
 from tinygrad.dtype import DType, dtypes

-def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]:
+def helper_realized_ast(r:Union[Tensor, list[Tensor]]) -> tuple[UOp, list[Buffer]]:
 if isinstance(r, Tensor): r = [r]
 s = Tensor.schedule(*r)
 run_schedule(s[:-1]) # run all kernels except the last one
@@ -1752,30 +1751,30 @@ def test_matvec(self):
 assert k.local_dims == 1
 assert k.upcasted == 1

-def helper_linearizer_ast(ast:UOp, inputs:List[Tensor], *args, **kwargs):
+def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
 assert isinstance(ast, UOp), "ast must be UOp"
 inbufs = [x.lazydata.base.buffer for x in inputs]
 outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.st_arg.size, out.src[2].dtype).allocate() \
 for out in ast.src]
 return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)

-def helper_linearizer_opt(r:Union[Tensor, List[Tensor]], *args, **kwargs):
+def helper_linearizer_opt(r:Union[Tensor, list[Tensor]], *args, **kwargs):
 realized_ast, real_bufs = helper_realized_ast(r)
 return _helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs)

-def copyout_outputs(lin:Kernel, outbufs:List[Buffer]) -> List[np.ndarray]:
+def copyout_outputs(lin:Kernel, outbufs:list[Buffer]) -> list[np.ndarray]:
 ret = []
 for i,x in enumerate(outbufs):
-shape: Tuple[int, ...] = lin.ast.src[i].st_arg.shape
+shape: tuple[int, ...] = lin.ast.src[i].st_arg.shape
 ret.append(np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)).reshape(shape))
 return ret

-def reset_bufs(bufs:List[Buffer]):
+def reset_bufs(bufs:list[Buffer]):
 for buf in bufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled

-def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:List[Buffer], opts=[],
-apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Kernel]:
-lins: List[Kernel] = []
+def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[],
+apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> list[Kernel]:
+lins: list[Kernel] = []
 outbufs = [real_bufs[x.src[0].arg] for x in realized_ast.src]

 def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=Device.DEFAULT))

test/test_uops.py (+3 -3)

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Any, List
+from typing import Optional, Any
 import unittest, math
 import numpy as np
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -17,7 +17,7 @@
 from tinygrad.device import is_dtype_supported
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps

-def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
+def to_uops_list(u:list[UOp], opts=None, skip_check=False) -> list[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)

 def _uops_to_prg(uops_list):
 uops = linearize_uop(full_graph_rewrite(UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
@@ -26,7 +26,7 @@ def _uops_to_prg(uops_list):
 return CompiledRunner(ProgramSpec("test", src, Device.DEFAULT, uops=uops,
 global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))

-def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
+def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
 uops.append(UOp(uop, dtype, tuple(src), arg))
 return uops[-1]


tinygrad/runtime/ops_dsp.py (+5 -6)

@@ -1,5 +1,4 @@
 from __future__ import annotations
-from typing import Tuple, Any, List
 import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, time, struct
 assert sys.platform != 'win32'
 from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator
@@ -20,7 +19,7 @@ class DSPRenderer(ClangRenderer):
 Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
 Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"}

-def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
+def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
 ret = super().render_kernel(function_name, kernel, bufs, uops, prefix)
 msrc = ['''struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency; _Bool set_dcvs_params;
 short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3]; };''', 'int HAP_power_set(void*, void*);',
@@ -55,7 +54,7 @@ class DSPProgram:
 def __init__(self, dev:DSPDevice, name:str, lib:bytes):
 self.dev, self.lib = dev, lib

-def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False):
 if len(bufs) >= 16: raise RuntimeError(f"Too many buffers to execute: {len(bufs)}")

 pra, fds, attrs, _ = rpc_prep_args(ins=[var_vals_mv:=memoryview(bytearray((len(bufs)+len(vals))*4)), off_mv:=memoryview(bytearray(len(bufs)*4))],
@@ -66,7 +65,7 @@ def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
 return timer[0] / 1e6

 class DSPBuffer:
-def __init__(self, va_addr:int, size:int, share_info:Any, offset:int=0):
+def __init__(self, va_addr:int, size:int, share_info, offset:int=0):
 self.va_addr, self.size, self.share_info, self.offset = va_addr, size, share_info, offset

 class DSPAllocator(Allocator):
@@ -229,7 +228,7 @@ def run(self):
 # ***** mock DSP *****

 class MockDSPRenderer(DSPRenderer):
-def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
+def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
 ret = ClangRenderer.render_kernel(self, function_name, kernel, bufs, uops, prefix)
 # https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
 msrc = ['''static long syscall(long r0, long r1, long r2, long r3, long r4, long r5, long r6) {
@@ -254,7 +253,7 @@ def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str

 class MockDSPProgram:
 def __init__(self, name:str, lib:bytes): self.lib = lib
-def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False):
 with tempfile.NamedTemporaryFile(suffix=".out") as dsp_lib:
 dsp_lib.write(self.lib)
 dsp_lib.flush()

tinygrad/tensor.py (+14 -14)

@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
+from typing import Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
@@ -68,7 +68,7 @@ def get_shape(x) -> tuple[int, ...]:
 if not all_same(subs:=[get_shape(xi) for xi in x]): raise ValueError(f"inhomogeneous shape from {x}")
 return (len(subs),) + (subs[0] if subs else ())

-def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> UOp:
+def _frompy(x:Union[list, tuple, bytes], dtype:DType) -> UOp:
 if isinstance(x, bytes): ret, data = UOp.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
 else:
 ret = UOp.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON")
@@ -131,7 +131,7 @@ class Tensor(SimpleMathTrait):
 training: ClassVar[bool] = False
 no_grad: ClassVar[bool] = False

-def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
+def __init__(self, data:Union[None, ConstType, bytes, list, tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
 device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
 if dtype is not None: dtype = to_dtype(dtype)
 if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
@@ -329,7 +329,7 @@ def item(self) -> ConstType:
 assert self.numel() == 1, "must have one element for item"
 return self.data()[(0,) * len(self.shape)]

-# TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The List is Sequence because mypy expects memoryview.tolist() -> list[int]
+# TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
 # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
 def tolist(self) -> Union[Sequence[ConstType], ConstType]:
 """
@@ -1185,7 +1185,7 @@ def __getitem__(self, indices) -> Tensor:
 """
 Retrieve a sub-tensor using indexing.

-Supported Index Types: `int | slice | Tensor | None | List | Tuple | Ellipsis`
+Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`

 Examples:
 ```python exec="true" source="above" session="tensor" result="python"
@@ -2036,7 +2036,7 @@ def _resolve_pool_pads(self, padding:Union[int, Sequence[int]], dims:int) -> Seq
 raise ValueError(f"Padding must be an int or a sequence of length {dims} or {2*dims}, but got {padding=} for {self.shape=} with {dims=}.")
 return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])

-def _apply_ceil_mode(self, pads:Sequence[int], k_:Tuple[sint, ...], s_:Union[Tuple[int, ...], int], d_:Union[Tuple[int, ...], int]) -> List[int]:
+def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:Union[tuple[int, ...], int], d_:Union[tuple[int, ...], int]) -> list[int]:
 (d_,s_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_)), self.shape[-len(k_):]
 pads, grouped_pads = list(pads), _flat_to_grouped(pads)
 # https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
@@ -2059,10 +2059,10 @@ def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil
 1. `int` (single value):
 Applies the same padding value uniformly to all spatial dimensions.

-2. `Tuple[int, ...]` (length = number of spatial dimensions):
+2. `tuple[int, ...]` (length = number of spatial dimensions):
 Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

-3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
+3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
 Specifies explicit padding for each side of each spatial dimension in the form
 `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

@@ -2106,10 +2106,10 @@ def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil
 1. `int` (single value):
 Applies the same padding value uniformly to all spatial dimensions.

-2. `Tuple[int, ...]` (length = number of spatial dimensions):
+2. `tuple[int, ...]` (length = number of spatial dimensions):
 Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

-3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
+3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
 Specifies explicit padding for each side of each spatial dimension in the form
 `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

@@ -2144,10 +2144,10 @@ def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1,
 1. `int` (single value):
 Applies the same padding value uniformly to all spatial dimensions.

-2. `Tuple[int, ...]` (length = number of spatial dimensions):
+2. `tuple[int, ...]` (length = number of spatial dimensions):
 Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

-3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
+3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
 Specifies explicit padding for each side of each spatial dimension in the form
 `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

@@ -2217,10 +2217,10 @@ def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1,
 1. `int` (single value):
 Applies the same padding value uniformly to all spatial dimensions.

-2. `Tuple[int, ...]` (length = number of spatial dimensions):
+2. `tuple[int, ...]` (length = number of spatial dimensions):
 Specifies a distinct padding value for each spatial dimension in the form `(padding_height, padding_width, ...)`.

-3. `Tuple[int, ...]` (length = 2 * number of spatial dimensions):
+3. `tuple[int, ...]` (length = 2 * number of spatial dimensions):
 Specifies explicit padding for each side of each spatial dimension in the form
 `(padding_left, padding_right, padding_top, padding_bottom, ...)`.

