Skip to content

Commit 34ce030

Browse files
committed
split TaskDoc into slim version and flexible version
1 parent 217b16c commit 34ce030

File tree

4 files changed

+1217
-96
lines changed

4 files changed

+1217
-96
lines changed

emmet-core/emmet/core/tasks.py

Lines changed: 90 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
import json
66
import logging
77
import re
8-
from collections.abc import Mapping
98
from datetime import datetime
109
from pathlib import Path
11-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
10+
from typing import TYPE_CHECKING, Any, Optional
1211

1312
import numpy as np
1413
from monty.json import MontyDecoder
@@ -29,7 +28,7 @@
2928

3029
from emmet.core import ARROW_COMPATIBLE
3130
from emmet.core.common import convert_datetime
32-
from emmet.core.mpid import MPID
31+
from emmet.core.mpid import MPID, AlphaID
3332
from emmet.core.structure import StructureMetadata
3433
from emmet.core.utils import jsanitize, type_override, utcnow
3534
from emmet.core.vasp.calc_types import (
@@ -43,6 +42,7 @@
4342
from emmet.core.vasp.calculation import (
4443
Calculation,
4544
CalculationInput,
45+
CalculationOutput,
4646
PotcarSpec,
4747
RunStatistics,
4848
VaspObject,
@@ -88,7 +88,6 @@ class OrigInputs(CalculationInput):
8888

8989

9090
class OutputDoc(BaseModel):
91-
# structure: Optional[Structure] = Field(
9291
structure: Optional[AnnotatedStructure] = Field(
9392
None,
9493
title="Output Structure",
@@ -350,121 +349,135 @@ def from_vasp_calc_docs(
350349
)
351350

352351

353-
@type_override(
354-
{
355-
"additional_json": str,
356-
"transformations": str,
357-
"vasp_objects": str,
358-
}
359-
)
360-
class TaskDoc(StructureMetadata, extra="allow", use_enum_values=True):
352+
class ProductionTaskDoc(StructureMetadata):
361353
"""Calculation-level details about VASP calculations that power Materials Project."""
362354

363-
tags: list[str] | None = Field(
364-
None, title="tag", description="Metadata tagged to a given task."
365-
)
366-
dir_name: str | None = Field(None, description="The directory for this VASP task")
367-
state: TaskState | None = Field(None, description="State of this calculation")
368-
369-
calcs_reversed: list[Calculation] | None = Field(
355+
batch_id: Optional[str] = Field(
370356
None,
371-
title="Calcs reversed data",
372-
description="Detailed data for each VASP calculation contributing to the task document.",
357+
description="Identifier for this calculation; should provide rough information about the calculation origin and purpose.",
373358
)
374-
375-
# structure: Structure | None = Field(
376-
structure: AnnotatedStructure | None = Field(
377-
None, description="Final output structure from the task"
359+
calc_type: Optional[CalcType] = Field(
360+
None, description="The functional and task type used in the calculation."
378361
)
379-
380-
task_type: TaskType | CalcType | None = Field(
381-
None, description="The type of calculation."
362+
completed_at: Optional[datetime] = Field(
363+
None, description="Timestamp for when this task was completed"
382364
)
383-
384-
run_type: RunType | None = Field(
385-
None, description="The functional used in the calculation."
365+
dir_name: Optional[str] = Field(
366+
None, description="The directory for this VASP task"
386367
)
387-
388-
calc_type: CalcType | None = Field(
389-
None, description="The functional and task type used in the calculation."
368+
icsd_id: Optional[str | int] = Field(
369+
None, description="Inorganic Crystal Structure Database id of the structure"
390370
)
391-
392-
task_id: MPID | str | None = Field(
371+
input: CalculationInput | None = Field(
393372
None,
394-
description="The (task) ID of this calculation, used as a universal reference across property documents."
395-
"This comes in the form: mp-******.",
373+
description="The input structure used to generate the current task document.",
396374
)
397-
398-
orig_inputs: OrigInputs | None = Field(
399-
None,
400-
description="The exact set of input parameters used to generate the current task document.",
375+
last_updated: Optional[datetime] = Field(
376+
utcnow(),
377+
description="Timestamp for the most recent calculation for this task document",
401378
)
402-
403-
input: InputDoc | None = Field(
379+
orig_inputs: CalculationInput | None = Field(
404380
None,
405-
description="The input structure used to generate the current task document.",
381+
description="The exact set of input parameters used to generate the current task document.",
406382
)
407-
408-
output: OutputDoc | None = Field(
383+
output: CalculationOutput | None = Field(
409384
None,
410385
description="The exact set of output parameters used to generate the current task document.",
411386
)
412-
413-
included_objects: list[VaspObject] | None = Field(
414-
None, description="List of VASP objects included with this task document"
387+
run_type: Optional[RunType] = Field(
388+
None, description="The functional used in the calculation."
415389
)
416-
vasp_objects: dict[VaspObject, Any] | None = Field(
417-
None, description="Vasp objects associated with this task"
390+
structure: Optional[AnnotatedStructure] = Field(
391+
None, description="Final output structure from the task"
418392
)
419-
entry: ComputedEntry | None = Field(
420-
None, description="The ComputedEntry from the task doc"
393+
tags: list[str] | None = Field(
394+
None, title="tag", description="Metadata tagged to a given task."
421395
)
422-
task_label: str | None = Field(None, description="A description of the task")
423-
author: str | None = Field(
424-
None, description="Author extracted from transformations"
396+
task_id: AlphaID | MPID | str | None = Field(
397+
None,
398+
description="The (task) ID of this calculation, used as a universal reference across property documents."
399+
"This comes in the form: mp-******.",
425400
)
426-
icsd_id: int | None = Field(
427-
None, description="Inorganic Crystal Structure Database id of the structure"
401+
task_type: Optional[TaskType | CalcType] = Field(
402+
None, description="The type of calculation."
428403
)
429404
transformations: Any | None = Field(
430405
None,
431406
description="Information on the structural transformations, parsed from a "
432407
"transformations.json file",
433408
)
434-
additional_json: dict[str, Any] | None = Field(
435-
None, description="Additional json loaded from the calculation directory"
409+
vasp_objects: Optional[dict[VaspObject, Any]] = Field(
410+
None, description="Vasp objects associated with this task"
436411
)
437412

438-
custodian: list[CustodianDoc] | None = Field(
439-
None,
440-
title="Calcs reversed data",
441-
description="Detailed custodian data for each VASP calculation contributing to the task document.",
442-
)
413+
@model_validator(mode="before")
@classmethod
def set_prod_model_pre_fields(cls, values: Any) -> Any:
    """Normalize ``last_updated`` and check ``batch_id`` before field validation.

    Args:
        values: Raw input handed to the model; it is read with ``.get`` and
            written by key, so a dict-like object is expected.

    Returns:
        The same ``values`` object, with ``last_updated`` replaced by the
        result of ``convert_datetime`` (``utcnow()`` is used when the key
        is absent).

    Raises:
        ValueError: If ``batch_id`` is set and contains any character that
            is neither alphanumeric nor one of ``-`` / ``_``.
    """
    values["last_updated"] = convert_datetime(
        cls, values.get("last_updated", utcnow())
    )

    if (batch_id := values.get("batch_id")) is not None:
        invalid_chars = {
            char
            for char in batch_id
            if (not char.isalnum()) and (char not in {"-", "_"})
        }
        if invalid_chars:
            # Sort so the message is stable across runs (set order is not).
            raise ValueError(
                f"Invalid characters in batch_id: {' '.join(sorted(invalid_chars))}"
            )

    # A "before" validator must hand its input back to pydantic; without this
    # return the model would be validated against None instead of `values`.
    return values
431+
443432

444-
analysis: AnalysisDoc | None = Field(
433+
@type_override(
434+
{
435+
"additional_json": str,
436+
"transformations": str,
437+
"vasp_objects": str,
438+
}
439+
)
440+
class TaskDoc(ProductionTaskDoc, extra="allow", use_enum_values=True):
441+
"""Flexible wrapper around ProductionTaskDoc"""
442+
443+
additional_json: Optional[dict[str, Any]] = Field(
444+
None, description="Additional json loaded from the calculation directory"
445+
)
446+
analysis: Optional[AnalysisDoc] = Field(
445447
None,
446448
title="Calculation Analysis",
447449
description="Some analysis of calculation data after collection.",
448450
)
449-
450-
last_updated: datetime = Field(
451-
default_factory=utcnow,
452-
description="Timestamp for the most recent calculation for this task document",
451+
author: Optional[str] = Field(
452+
None, description="Author extracted from transformations"
453453
)
454-
455-
completed_at: datetime | None = Field(
456-
None, description="Timestamp for when this task was completed"
454+
calcs_reversed: Optional[list[Calculation]] = Field(
455+
None,
456+
title="Calcs reversed data",
457+
description="Detailed data for each VASP calculation contributing to the task document.",
457458
)
458-
459-
batch_id: str | None = Field(
459+
custodian: Optional[list[CustodianDoc]] = Field(
460460
None,
461-
description="Identifier for this calculation; should provide rough information about the calculation origin and purpose.",
461+
title="Calcs reversed data",
462+
description="Detailed custodian data for each VASP calculation contributing to the task document.",
463+
)
464+
entry: Optional[ComputedEntry] = Field(
465+
None, description="The ComputedEntry from the task doc"
466+
)
467+
included_objects: Optional[list[VaspObject]] = Field(
468+
None, description="List of VASP objects included with this task document"
469+
)
470+
output: Optional[OutputDoc] = Field(
471+
None,
472+
description="The exact set of output parameters used to generate the current task document.",
462473
)
463474

464475
run_stats: dict[str, RunStatistics] | None = Field(
465476
None,
466477
description="Summary of runtime statistics for each calculation in this task",
467478
)
479+
state: Optional[TaskState] = Field(None, description="State of this calculation")
480+
task_label: Optional[str] = Field(None, description="A description of the task")
468481

469482
# Note that private fields are needed because TaskDoc permits extra info
470483
# added to the model, unlike TaskDocument. Because of this, when pydantic looks up
@@ -554,25 +567,6 @@ def deserialize_entry(cls, entry):
554567
@classmethod
555568
def set_model_pre_fields(cls, values: Any) -> Any:
556569
"""Ensure all important model fields are set and refreshed."""
557-
558-
# Make sure that the datetime field is properly formatted
559-
# (Unclear when this is not the case, please leave comment if observed)
560-
values["last_updated"] = convert_datetime(
561-
cls, values.get("last_updated", utcnow())
562-
)
563-
564-
# Ensure batch_id includes only valid characters
565-
if (batch_id := values.get("batch_id")) is not None:
566-
invalid_chars = set(
567-
char
568-
for char in batch_id
569-
if (not char.isalnum()) and (char not in {"-", "_"})
570-
)
571-
if len(invalid_chars) > 0:
572-
raise ValueError(
573-
f"Invalid characters in batch_id: {' '.join(invalid_chars)}"
574-
)
575-
576570
# Always refresh task_type, calc_type, run_type
577571
# if attributes containing input sets are available.
578572
# See, e.g. https://github.com/materialsproject/emmet/issues/960

0 commit comments

Comments
 (0)