-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathsplit_combine.py
106 lines (92 loc) · 3.59 KB
/
split_combine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# -*- coding: utf-8 -*-
"""
Created on 2018/12/19 14:00
# SplitComb是一个类,类的主要参数有:side_len=144, max_stride=16, stride=4, margin=32, pad_value=170
# SplitComb.split对数据进行padding,以及z、x、y轴上的处理操作,SplitComb.combine对patch数据进行合并操作
"""
import torch
import numpy as np
class SplitComb:
    """Tile a volume into overlapping cubic patches and stitch per-patch
    results back into a single grid.

    ``split`` cuts a 4-D array into cubes whose cores are ``side_len`` voxels
    wide, each extended by ``margin`` voxels of context on every side.
    ``combine`` reverses the tiling for per-patch outputs whose spatial axes
    are ``stride`` times smaller than the input: it drops the margin of each
    patch and writes the remaining core into its cell of the output grid.
    """

    def __init__(self, side_len, max_stride, stride, margin, pad_value):
        self.side_len = side_len      # core edge length of a patch, in input voxels
        self.max_stride = max_stride  # side_len and margin must be multiples of this
        self.stride = stride          # input voxels per output cell, used by combine()
        self.margin = margin          # context voxels added on each side of a patch
        # Stored for callers; split() pads with edge values and does not read it.
        self.pad_value = pad_value

    def split(self, data, side_len=None, max_stride=None, margin=None):
        """Tile ``data`` into overlapping cubic patches.

        Each spatial axis is edge-padded by ``margin`` at the front, and by
        ``margin`` plus whatever is needed to complete the last tile at the
        back. The padded volume is then cut into cubes of
        ``side_len + 2 * margin`` voxels whose ``side_len`` cores do not
        overlap.

        Args:
            data: 4-D array indexed as ``(leading, z, h, w)``. The leading axis
                (presumably channels) is copied unchanged into every patch.
            side_len: core edge length; defaults to ``self.side_len``.
            max_stride: divisibility constraint; defaults to ``self.max_stride``.
            margin: context width; defaults to ``self.margin``.

        Returns:
            ``(splits, nzhw)``.
            - ``splits`` has shape
              ``(nz * nh * nw, leading, side_len + 2 * margin, ...)``, ordered
              z-major, then h, then w.
            - ``nzhw == [nz, nh, nw]`` is the patch-grid size. It is also
              stored on ``self.nzhw`` so that ``combine`` can reuse it.

        Raises:
            AssertionError: if ``side_len <= margin``, or if ``side_len`` or
                ``margin`` is not a multiple of ``max_stride``.
        """
        if side_len is None:
            side_len = self.side_len
        if max_stride is None:
            max_stride = self.max_stride
        if margin is None:
            margin = self.margin

        assert side_len > margin, "side_len must exceed margin"
        assert side_len % max_stride == 0, "side_len must be a multiple of max_stride"
        assert margin % max_stride == 0, "margin must be a multiple of max_stride"

        _, z, h, w = data.shape
        nz = int(np.ceil(float(z) / side_len))
        nh = int(np.ceil(float(h) / side_len))
        nw = int(np.ceil(float(w) / side_len))
        nzhw = [nz, nh, nw]
        self.nzhw = nzhw

        # The front pad is `margin`. The back pad is `margin` plus the
        # remainder needed to fill the last tile, so every patch slice below
        # stays in bounds.
        pad = [[0, 0],
               [margin, nz * side_len - z + margin],
               [margin, nh * side_len - h + margin],
               [margin, nw * side_len - w + margin]]
        data = np.pad(data, pad, 'edge')

        patch = side_len + 2 * margin
        splits = []
        for iz in range(nz):
            for ih in range(nh):
                for iw in range(nw):
                    sz, sh, sw = iz * side_len, ih * side_len, iw * side_len
                    splits.append(data[np.newaxis, :,
                                       sz:sz + patch,
                                       sh:sh + patch,
                                       sw:sw + patch])
        splits = np.concatenate(splits, 0)
        return splits, nzhw

    def combine(self, output, nzhw=None, side_len=None, stride=None, margin=None):
        """Stitch per-patch outputs back into one grid, dropping the margins.

        Args:
            output: sequence (or array) of per-patch results, in the order
                ``split`` produced the patches. Each item is indexed as
                ``(d, h, w, a, b)``, where the first three axes are spatial at
                ``1 / stride`` of the input resolution. (Presumably ``a``
                indexes anchors and ``b`` the per-anchor values; confirm with
                the caller.)
            nzhw: patch-grid size ``[nz, nh, nw]``. Defaults to the grid
                recorded by the most recent ``split`` call on this instance.
            side_len: core edge length in input voxels; defaults to
                ``self.side_len``.
            stride: input voxels per output cell; defaults to ``self.stride``.
            margin: context width in input voxels; defaults to ``self.margin``.

        Returns:
            float32 array of shape ``(nz * s, nh * s, nw * s, a, b)``, where
            ``s = side_len // stride``.

        Raises:
            AssertionError: if ``side_len`` or ``margin`` is not a multiple of
                ``stride``.
            AttributeError: if ``nzhw`` is omitted and ``split`` has not been
                called on this instance.
        """
        if side_len is None:
            side_len = self.side_len
        if stride is None:
            stride = self.stride
        if margin is None:
            margin = self.margin
        if nzhw is None:
            # split() records the grid it produced in self.nzhw.
            nzhw = self.nzhw
        nz, nh, nw = nzhw

        assert side_len % stride == 0, "side_len must be a multiple of stride"
        assert margin % stride == 0, "margin must be a multiple of stride"
        # Floor division keeps these integral. True division would produce
        # floats under Python 3, which are invalid as array sizes and as
        # slice bounds.
        side_len = int(side_len // stride)
        margin = int(margin // stride)

        patches = list(output)
        # Every cell is overwritten below, because the tiles partition the
        # grid exactly; the fill value only marks unwritten cells.
        combined = np.full((nz * side_len,
                            nh * side_len,
                            nw * side_len,
                            patches[0].shape[3],
                            patches[0].shape[4]),
                           -1000000, np.float32)

        idx = 0
        for iz in range(nz):
            for ih in range(nh):
                for iw in range(nw):
                    sz, sh, sw = iz * side_len, ih * side_len, iw * side_len
                    # Keep only the patch core; the margin on each side was
                    # context for the network, not part of this cell.
                    core = patches[idx][margin:margin + side_len,
                                        margin:margin + side_len,
                                        margin:margin + side_len]
                    combined[sz:sz + side_len,
                             sh:sh + side_len,
                             sw:sw + side_len] = core
                    idx += 1
        return combined