@@ -318,6 +318,7 @@ def __init__(self, table, select_cols, group_cols, index_cols, **kwargs):
         self.table_keywords = kwargs.pop("table_keywords", False)
         self.column_keywords = kwargs.pop("column_keywords", False)
         self.table_proxy = kwargs.pop("table_proxy", False)
+        self.context = kwargs.pop("context", None)
 
         if len(kwargs) > 0:
             raise ValueError(f"Unhandled kwargs: {kwargs}")
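
Note: the new `context` kwarg is popped like the existing reader options and defaults to `None`. A minimal usage sketch follows, assuming the kwarg is forwarded from the public `xds_from_table` entry point to this reader's `__init__`; that forwarding is not shown in this diff and is an assumption:

```python
# Hypothetical sketch: assumes xds_from_table forwards unhandled kwargs
# to the reader's __init__, which pops "context" as shown above.
from daskms import xds_from_table

# context="ms" opts in to the Measurement Set dimension fixups introduced
# by postprocess_dataset below; any other value leaves datasets unchanged.
datasets = xds_from_table("OBSERVATION.MS", context="ms")
```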
@@ -359,8 +360,10 @@ def _single_dataset(self, table_proxy, orders, exemplar_row=0):
         coords = {"ROWID": rowid}
 
         attrs = {DASKMS_PARTITION_KEY: ()}
-
-        return Dataset(variables, coords=coords, attrs=attrs)
+        dataset = Dataset(variables, coords=coords, attrs=attrs)
+        return self.postprocess_dataset(
+            dataset, table_proxy, exemplar_row, orders, self.chunks[0], short_table_name
+        )
 
     def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
         _, t, s = table_path_split(self.canonical_name)
@@ -420,10 +423,64 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
             group_id = [gid.item() for gid in group_id]
             attrs.update(zip(self.group_cols, group_id))
 
-            datasets.append(Dataset(group_var_dims, attrs=attrs, coords=coords))
+            dataset = Dataset(group_var_dims, attrs=attrs, coords=coords)
+            dataset = self.postprocess_dataset(
+                dataset, table_proxy, exemplar_row, order, group_chunks, array_suffix
+            )
+            datasets.append(dataset)
 
         return datasets
 
+    def postprocess_dataset(
+        self, dataset, table_proxy, exemplar_row, order, chunks, array_suffix
+    ):
+        if not self.context or self.context != "ms":
+            return dataset
+
+        # Fixup any non-standard columns
+        # with dimensions like chan and corr
+        try:
+            chan = dataset.sizes["chan"]
+            corr = dataset.sizes["corr"]
+        except KeyError:
+            return dataset
+
+        schema_updates = {}
+
+        for name, var in dataset.data_vars.items():
+            new_dims = list(var.dims[1:])
+
+            unassigned = {"chan", "corr"}
+
+            for dim, dim_name in enumerate(var.dims[1:]):
+                # An automatically assigned dimension name
+                if dim_name == f"{name}-{dim + 1}":
+                    if dataset.sizes[dim_name] == chan and "chan" in unassigned:
+                        new_dims[dim] = "chan"
+                        unassigned.discard("chan")
+                    elif dataset.sizes[dim_name] == corr and "corr" in unassigned:
+                        new_dims[dim] = "corr"
+                        unassigned.discard("corr")
+
+            new_dims = tuple(new_dims)
+            if var.dims[1:] != new_dims:
+                schema_updates[name] = {"dims": new_dims}
+
+        if not schema_updates:
+            return dataset
+
+        return dataset.assign(
+            **_dataset_variable_factory(
+                table_proxy,
+                schema_updates,
+                list(schema_updates.keys()),
+                exemplar_row,
+                order,
+                chunks,
+                array_suffix,
+            )
+        )
+
     def datasets(self):
         table_proxy = self._table_proxy_factory()
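
For reference, a standalone sketch of the size-matching heuristic that `postprocess_dataset` applies, written against plain `xarray` with invented column names and sizes. The real code rebuilds the affected variables through `_dataset_variable_factory` rather than renaming in place; this sketch only demonstrates how automatically assigned dimension names of the form `f"{name}-{dim + 1}"` are matched to `chan` and `corr` by size:

```python
import numpy as np
import xarray as xr

chan, corr = 64, 4  # invented sizes for illustration

ds = xr.Dataset(
    {
        # Standard column: dims already carry the schema names
        "DATA": (("row", "chan", "corr"), np.zeros((10, chan, corr))),
        # Non-standard column: dims were auto-assigned as "BITFLAG-1", "BITFLAG-2"
        "BITFLAG": (("row", "BITFLAG-1", "BITFLAG-2"), np.zeros((10, chan, corr))),
    }
)

renames = {}
unassigned = {"chan", "corr"}  # each target name is claimed at most once

for dim, dim_name in enumerate(ds["BITFLAG"].dims[1:]):
    # Only automatically assigned names of the form f"{column}-{dim + 1}" are touched
    if dim_name == f"BITFLAG-{dim + 1}":
        if ds.sizes[dim_name] == chan and "chan" in unassigned:
            renames[dim_name] = "chan"
            unassigned.discard("chan")
        elif ds.sizes[dim_name] == corr and "corr" in unassigned:
            renames[dim_name] = "corr"
            unassigned.discard("corr")

print(ds.rename(renames)["BITFLAG"].dims)  # ('row', 'chan', 'corr')
```

The `unassigned` set guards the ambiguous case where `chan == corr`: the first matching dimension claims `chan` and discards it, so the next matching dimension falls through to `corr`.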