@@ -11,7 +11,7 @@ import {
11
11
CheckResourcesCoverageSchema ,
12
12
GenerationStatus ,
13
13
generationStatusSchema ,
14
- GetGenerationDataInput ,
14
+ GetGenerationDataSchema ,
15
15
GetGenerationResourcesInput ,
16
16
} from '~/server/schema/generation.schema' ;
17
17
@@ -234,28 +234,39 @@ export const getGenerationData = async ({
234
234
query,
235
235
user,
236
236
} : {
237
- query : GetGenerationDataInput ;
237
+ query : GetGenerationDataSchema ;
238
238
user ?: SessionUser ;
239
239
} ) : Promise < GenerationData > => {
240
240
switch ( query . type ) {
241
241
case 'image' :
242
242
case 'video' :
243
- return await getMediaGenerationData ( { id : query . id , user } ) ;
243
+ return await getMediaGenerationData ( { id : query . id , user, generation : query . generation } ) ;
244
244
case 'modelVersion' :
245
- return await getModelVersionGenerationData ( { versionIds : [ query . id ] , user } ) ;
245
+ return await getModelVersionGenerationData ( {
246
+ versionIds : [ query . id ] ,
247
+ user,
248
+ generation : query . generation ,
249
+ } ) ;
246
250
case 'modelVersions' :
247
- return await getModelVersionGenerationData ( { versionIds : query . ids , user } ) ;
251
+ return await getModelVersionGenerationData ( {
252
+ versionIds : query . ids ,
253
+ user,
254
+ generation : query . generation ,
255
+ } ) ;
248
256
default :
249
257
throw new Error ( 'unsupported generation data type' ) ;
250
258
}
251
259
} ;
252
260
261
+ type ResourceType = 'generation' | 'all' ;
253
262
async function getMediaGenerationData ( {
254
263
id,
255
264
user,
265
+ generation,
256
266
} : {
257
267
id : number ;
258
268
user ?: SessionUser ;
269
+ generation : boolean ;
259
270
} ) : Promise < GenerationData > {
260
271
const media = await dbRead . image . findUnique ( {
261
272
where : { id } ,
@@ -294,7 +305,8 @@ async function getMediaGenerationData({
294
305
const versionIds = [
295
306
...new Set ( imageResources . map ( ( x ) => x . modelVersionId ) . filter ( isDefined ) ) ,
296
307
] ;
297
- const resources = await getGenerationResourceData ( { ids : versionIds , user } ) . then ( ( data ) =>
308
+ const fn = generation ? getGenerationResourceData : getResourceData ;
309
+ const resources = await fn ( { ids : versionIds , user } ) . then ( ( data ) =>
298
310
data . map ( ( item ) => {
299
311
const imageResource = imageResources . find ( ( x ) => x . modelVersionId === item . id ) ;
300
312
return {
@@ -379,15 +391,18 @@ async function getMediaGenerationData({
379
391
const getModelVersionGenerationData = async ( {
380
392
versionIds,
381
393
user,
394
+ generation,
382
395
} : {
383
396
versionIds : number [ ] ;
384
397
user ?: SessionUser ;
398
+ generation : boolean ;
385
399
} ) => {
386
400
if ( ! versionIds . length ) throw new Error ( 'missing version ids' ) ;
387
- const resources = await getGenerationResourceData ( { ids : versionIds , user } ) ;
401
+ const fn = generation ? getGenerationResourceData : getResourceData ;
402
+ const resources = await fn ( { ids : versionIds , user } ) ;
388
403
const checkpoint = resources . find ( ( x ) => x . baseModel === 'Checkpoint' ) ;
389
404
if ( checkpoint ?. vaeId ) {
390
- const [ vae ] = await getGenerationResourceData ( { ids : [ checkpoint . vaeId ] , user } ) ;
405
+ const [ vae ] = await fn ( { ids : [ checkpoint . vaeId ] , user } ) ;
391
406
if ( vae ) resources . push ( { ...vae , vaeId : undefined } ) ;
392
407
}
393
408
0 commit comments