From 187444c0d2ea0eef341e3792daf8d9859272e7f6 Mon Sep 17 00:00:00 2001 From: Federico Stagni Date: Tue, 4 Feb 2025 11:58:58 +0100 Subject: [PATCH] sweep: #8026 fix: make the setting of inputDataBulk extendable --- .../Client/WorkflowTasks.py | 49 ++++++++++++------- .../Client/test/Test_Client_WorkflowTasks.py | 26 ++++++++++ 2 files changed, 56 insertions(+), 19 deletions(-) diff --git a/src/DIRAC/TransformationSystem/Client/WorkflowTasks.py b/src/DIRAC/TransformationSystem/Client/WorkflowTasks.py index 2f8366da217..5073eb963d7 100644 --- a/src/DIRAC/TransformationSystem/Client/WorkflowTasks.py +++ b/src/DIRAC/TransformationSystem/Client/WorkflowTasks.py @@ -7,7 +7,7 @@ from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername from DIRAC.Core.Security.ProxyInfo import getProxyInfo -from DIRAC.Core.Utilities.DErrno import ETSDATA, ETSUKN +from DIRAC.Core.Utilities.DErrno import ETSUKN from DIRAC.Core.Utilities.List import fromChar from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader from DIRAC.Interfaces.API.Job import Job @@ -83,7 +83,11 @@ def __init__( self.outputDataModule_o = None self.objectLoader = ObjectLoader() - def prepareTransformationTasks(self, transBody, taskDict, owner="", ownerGroup="", bulkSubmissionFlag=False): + self.parametricSequencedKeys = ["JOB_ID", "PRODUCTION_ID", "InputData"] + + def prepareTransformationTasks( + self, transBody, taskDict, owner="", ownerGroup="", bulkSubmissionFlag=False + ): """Prepare tasks, given a taskDict, that is created (with some manipulation) by the DB jobClass is by default "DIRAC.Interfaces.API.Job.Job". An extension of it also works. @@ -191,22 +195,7 @@ def _prepareTasksBulk(self, transBody, taskDict, owner, ownerGroup): method=method, ) - # Handle Input Data - inputData = paramsDict.get("InputData") - if inputData: - if isinstance(inputData, str): - inputData = inputData.replace(" ", "").split(";") - self._logVerbose(f"Setting input data to {inputData}", transID=transID, method=method) - seqDict["InputData"] = inputData - elif paramSeqDict.get("InputData") is not None: - self._logError("Invalid mixture of jobs with and without input data") - return S_ERROR(ETSDATA, "Invalid mixture of jobs with and without input data") - - for paramName, paramValue in paramsDict.items(): - if paramName not in ("InputData", "Site", "TargetSE"): - if paramValue: - self._logVerbose(f"Setting {paramName} to {paramValue}", transID=transID, method=method) - seqDict[paramName] = paramValue + inputData = self._handleInputsBulk(seqDict, paramsDict, transID) outputParameterList = [] if self.outputDataModule: @@ -235,7 +224,7 @@ def _prepareTasksBulk(self, transBody, taskDict, owner, ownerGroup): paramSeqDict.setdefault(pName, []).append(seq) for paramName, paramSeq in paramSeqDict.items(): - if paramName in ["JOB_ID", "PRODUCTION_ID", "InputData"] + outputParameterList: + if paramName in self.parametricSequencedKeys + outputParameterList: res = oJob.setParameterSequence(paramName, paramSeq, addToWorkflow=paramName) else: res = oJob.setParameterSequence(paramName, paramSeq) @@ -399,6 +388,28 @@ def _handleInputs(self, oJob, paramsDict): if not res["OK"]: self._logError(f"Could not set the inputs: {res['Message']}", transID=transID, method="_handleInputs") + def _handleInputsBulk(self, seqDict, paramsDict, transID): + """set job inputs (+ metadata)""" + method = "_handleInputsBulk" + if seqDict: + self._logVerbose(f"Setting job input data to {seqDict}", transID=transID, method=method) + + # Handle Input Data + inputData = paramsDict.get("InputData") + if inputData: + if isinstance(inputData, str): + inputData = inputData.replace(" ", "").split(";") + self._logVerbose(f"Setting input data {inputData} to {seqDict}", transID=transID, method=method) + seqDict["InputData"] = inputData + + for paramName, paramValue in paramsDict.items(): + if paramName not in ("InputData", "Site", "TargetSE"): + if paramValue: + self._logVerbose(f"Setting {paramName} to {paramValue}", transID=transID, method=method) + seqDict[paramName] = paramValue + + return inputData + def _handleRest(self, oJob, paramsDict): """add as JDL parameters all the other parameters that are not for inputs or destination""" transID = paramsDict["TransformationID"] diff --git a/src/DIRAC/TransformationSystem/Client/test/Test_Client_WorkflowTasks.py b/src/DIRAC/TransformationSystem/Client/test/Test_Client_WorkflowTasks.py index 06a7b255e53..20a404a11c4 100644 --- a/src/DIRAC/TransformationSystem/Client/test/Test_Client_WorkflowTasks.py +++ b/src/DIRAC/TransformationSystem/Client/test/Test_Client_WorkflowTasks.py @@ -3,6 +3,7 @@ # pylint: disable=protected-access,missing-docstring,invalid-name from unittest.mock import MagicMock + import pytest from DIRAC import gLogger, S_OK @@ -136,3 +137,28 @@ def test__handleDestination(mocker, paramsDict, expected): mocker.patch("DIRAC.TransformationSystem.Client.TaskManagerPlugin.getSitesForSE", side_effect=ourgetSitesForSE) res = wfTasks._handleDestination(paramsDict) assert sorted(res) == sorted(expected) + + +@pytest.mark.parametrize( + "seqDict, paramsDict, expected", + [ + ({}, {}, None), + ({"Site": "Site1", "JobName": "Job1", "JOB_ID": "00000001"}, {}, None), + ( + {"Site": "Site1", "JobName": "Job1", "JOB_ID": "00000001"}, + {"Site": "Site1", "JobType": "Sprucing", "TransformationID": 1}, + None, + ), + ( + {"Site": "Site1", "JobName": "Job1", "JOB_ID": "00000001"}, + {"Site": "Site1", "JobType": "Sprucing", "TransformationID": 1, "InputData": ["a1", "a2"]}, + ["a1", "a2"], + ), + # ({"a1": "aa1", "a2": "aa2", "a3": "aa3"}, {"b1": "bb1", "b2": "bb2", "b3": "bb3"}, {"b1": "bb1", "b2": "bb2"}, ["a1", "a2"]), + ], +) +def test__handleInputsBulk(mocker, seqDict, paramsDict, expected): + """Test the _handleInputsBulk method WorkflowTasks""" + mocker.patch("DIRAC.TransformationSystem.Client.TaskManagerPlugin.getSitesForSE", side_effect=ourgetSitesForSE) + res = wfTasks._handleInputsBulk(seqDict, paramsDict, transID=1) + assert res == expected