Skip to content

Commit 5aef512

Browse files
committed
Merge branch '469-damask-result-protected-false-not-working-for-add_curl-etc' into 'development'
Resolve "damask.Result: protected=False not working for add_curl etc." Closes #469 See merge request damask/DAMASK!1023
2 parents 064d7ab + 76ad336 commit 5aef512

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

python/damask/_result.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1447,8 +1447,7 @@ def gradient(f: DADF5Dataset, size: np.ndarray) -> DADF5Dataset:
14471447
def _add_generic_grid(self,
14481448
func: Callable[..., DADF5Dataset],
14491449
datasets: Dict[str, str],
1450-
args: Dict[str, str] = {},
1451-
constituents = None):
1450+
args: Dict[str,Any]):
14521451
"""
14531452
General function to add data on a regular grid.
14541453
@@ -1460,7 +1459,7 @@ def _add_generic_grid(self,
14601459
datasets : dictionary
14611460
Details of the datasets to be used:
14621461
{arg (name to which the data is passed in func): label (in DADF5 file)}.
1463-
args : dictionary, optional
1462+
args : dictionary
14641463
Arguments parsed to func.
14651464
14661465
"""
@@ -1484,13 +1483,17 @@ def _add_generic_grid(self,
14841483
r = func(**dataset,**args)
14851484
result = grid_filters.ravel(r['data'])
14861485
for x in self._visible[ty[0]+'s']:
1486+
path = '/'.join(['/',increment[0],ty[0],x,field[0]])
14871487
if ty[0] == 'phase':
14881488
result1 = result[at_cell_ph[0][x]]
14891489
if ty[0] == 'homogenization':
14901490
result1 = result[at_cell_ho[x]]
1491-
1492-
path = '/'.join(['/',increment[0],ty[0],x,field[0]])
1493-
h5_dataset = f[path].create_dataset(r['label'],data=result1)
1491+
if not self._protected and '/'.join([path,r['label']]) in f:
1492+
h5_dataset = f['/'.join([path,r['label']])]
1493+
h5_dataset[...] = result1
1494+
h5_dataset.attrs['overwritten'] = True
1495+
else:
1496+
h5_dataset = f[path].create_dataset(r['label'],data=result1)
14941497

14951498
h5_dataset.attrs['created'] = util.time_stamp()
14961499

@@ -1503,7 +1506,7 @@ def _add_generic_grid(self,
15031506
def _add_generic_pointwise(self,
15041507
func: Callable[..., DADF5Dataset],
15051508
datasets: Dict[str, str],
1506-
args: Dict[str, Any] = {}):
1509+
args: Optional[Dict[str, Any]] = None):
15071510
"""
15081511
General function to add pointwise data.
15091512
@@ -1519,6 +1522,7 @@ def _add_generic_pointwise(self,
15191522
Arguments parsed to func.
15201523
15211524
"""
1525+
args = args if args else {}
15221526

15231527
def job_pointwise(group: str,
15241528
callback: Callable[..., DADF5Dataset],

python/tests/test_Result.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,29 @@ def test_getters(self,default):
145145
fields.append(homogenization[f])
146146
assert len(fields) > 0
147147

148+
@pytest.mark.parametrize('protected', [True, False])
149+
@pytest.mark.parametrize('func', [lambda default: default._add_generic_grid,
150+
lambda default: default._add_generic_pointwise])
151+
def test_add_generic_dataset_overwrite(self, default, protected, func):
152+
def add_test_dataset(f,dummy):
153+
return {
154+
'data': f['data'],
155+
'label': f'|{f["label"]}|',
156+
'meta': {
157+
'unit': 0,
158+
'description': 'test data',
159+
'creator': 'add_test_dataset'
160+
}
161+
}
162+
163+
default._protected = protected
164+
func(default)(add_test_dataset, {'f': 'F_e'},{'dummy':'test'})
165+
if protected:
166+
with pytest.raises(ValueError):
167+
func(default)(add_test_dataset, {'f': 'F_e'},{'dummy':'test'})
168+
else:
169+
func(default)(add_test_dataset, {'f': 'F_e'},{'dummy':'test'})
170+
148171
def test_add_invalid(self,default):
149172
default.add_absolute('xxxx')
150173

@@ -162,7 +185,7 @@ def test_add_calculation(self,default,tmp_path,mode):
162185
default.add_calculation('2.0*np.abs(#F#)-1.0','x','-','my notes')
163186
else:
164187
with open(tmp_path/'f.py','w') as f:
165-
f.write("import numpy as np\ndef my_func(field):\n return 2.0*np.abs(field)-1.0\n")
188+
f.write('import numpy as np\ndef my_func(field):\n return 2.0*np.abs(field)-1.0\n')
166189
sys.path.insert(0,str(tmp_path))
167190
import f
168191
default.enable_user_function(f.my_func)

0 commit comments

Comments
 (0)