Skip to content

Commit bf67c7d

Browse files
committed
fixup! Implement mirror_partition (#6861)
1 parent 5050c40 commit bf67c7d

File tree

4 files changed

+37
-19
lines changed

4 files changed

+37
-19
lines changed

src/azul/plugins/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,8 @@ def count_bundles(self, source: SOURCE_SPEC) -> int:
676676
def count_files(self, source: SOURCE_SPEC) -> int:
677677
"""
678678
The total number of files in the given source. The source's prefix
679-
may be None.
679+
may be None, indicating that the source hasn't been partitioned yet and
680+
that this method should count all files in the source.
680681
"""
681682
raise NotImplementedError
682683

src/azul/plugins/repository/canned/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,16 @@ def count_bundles(self, source: CannedSourceRef) -> int:
176176

177177
def count_files(self, source: SimpleSourceSpec) -> int:
178178
staging_area = self.staging_area(source.name)
179-
prefix = '' if source.prefix is None else source.prefix.common
180-
return sum(
181-
1
182-
for descriptor in staging_area.descriptors.values()
183-
if descriptor.content['sha256'].startswith(prefix)
184-
)
179+
if source.prefix is None:
180+
return len(staging_area.descriptors)
181+
else:
182+
prefix = source.prefix.common
183+
assert not any(map(str.isupper, prefix)), source
184+
return sum(
185+
1
186+
for descriptor in staging_area.descriptors.values()
187+
if descriptor.content['sha256'].lower().startswith(prefix)
188+
)
185189

186190
def list_bundles(self,
187191
source: CannedSourceRef,
@@ -220,14 +224,15 @@ def fetch_bundle(self, bundle_fqid: CannedBundleFQID) -> CannedBundle:
220224
def list_files(self, source: CannedSourceRef, prefix: str) -> list[HCAFile]:
221225
self._assert_source(source)
222226
self._assert_partition(source, prefix)
227+
assert not any(map(str.isupper, prefix)), prefix
223228
staging_area = self.staging_area(source.spec.name)
224229
return [
225230
HCAFile.from_descriptor(descriptor.content,
226231
uuid=file_uuid,
227232
name=descriptor.content['file_name'],
228233
drs_uri=None)
229234
for file_uuid, descriptor in staging_area.descriptors.items()
230-
if descriptor.content['sha256'].startswith(prefix)
235+
if descriptor.content['sha256'].lower().startswith(prefix)
231236
]
232237

233238
def _construct_file_url(self, url: furl, file_name: str) -> furl:

src/azul/plugins/repository/tdr_anvil/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,16 @@ def _batch_uuid(self,
245245
self.bundle_uuid_version)
246246

247247
def count_files(self, source: TDRSourceSpec) -> int:
248+
if source.prefix is None:
249+
prefix = ''
250+
else:
251+
prefix = source.prefix.common
252+
assert not any(map(str.isupper, prefix)), source
248253
query = f'''
249254
SELECT COUNT(*) AS count
250255
FROM {backtick(self._full_table_name(source, 'anvil_file'))}
251-
''' + ('' if source.prefix is None else f'''
252-
WHERE starts_with(file_md5sum, {source.prefix.common!r})
253-
''')
256+
WHERE STARTS_WITH(LOWER(file_md5sum), {prefix!r})
257+
'''
254258
return one(self._run_sql(query))['count']
255259

256260
def count_bundles(self, source: TDRSourceSpec) -> int:
@@ -334,7 +338,7 @@ def list_files(self, source: TDRSourceRef, prefix: str) -> list[AnvilFile]:
334338
batch = self._get_batch(source.spec,
335339
'anvil_file',
336340
prefix,
337-
id_column='file_md5sum')
341+
key_column='file_md5sum')
338342
return [
339343
AnvilFile(uuid=ref.entity_id,
340344
name=row['file_name'],
@@ -562,13 +566,14 @@ def _get_batch(self,
562566
table_name: str,
563567
batch_prefix: str,
564568
*,
565-
id_column: str
569+
key_column: str
566570
) -> Iterable[tuple[EntityReference, BigQueryRow]]:
571+
# Reject only uppercase characters: str.islower() is False for strings without
# any cased character, so it would wrongly fail for digit-only hex prefixes
# (e.g. '12') and for the empty prefix. This matches the check used by the
# other prefix assertions (count_files / list_files).
assert not any(map(str.isupper, batch_prefix)), batch_prefix
567572
columns = self._columns(table_name)
568573
for row in self._run_sql(f'''
569574
SELECT {', '.join(sorted(columns))}
570575
FROM {backtick(self._full_table_name(source, table_name))}
571-
WHERE STARTS_WITH(LOWER({id_column}), {batch_prefix!r})
576+
WHERE STARTS_WITH(LOWER({key_column}), {batch_prefix!r})
572577
'''):
573578
ref = EntityReference(entity_type=table_name, entity_id=row['datarepo_row_id'])
574579
yield ref, row
@@ -579,7 +584,7 @@ def _get_bundle_batch(self,
579584
return self._get_batch(bundle_fqid.source.spec,
580585
bundle_fqid.table_name,
581586
bundle_fqid.batch_prefix,
582-
id_column='datarepo_row_id')
587+
key_column='datarepo_row_id')
583588

584589
def _bundle_entity(self, bundle_fqid: TDRAnvilBundleFQID) -> KeyReference:
585590
source = bundle_fqid.source

src/azul/plugins/repository/tdr_hca/__init__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,13 +243,18 @@ def count_bundles(self, source: TDRSourceSpec) -> int:
243243
return one(rows)['count']
244244

245245
def count_files(self, source: TDRSourceSpec) -> int:
246+
if source.prefix is None:
247+
prefix = ''
248+
else:
249+
prefix = source.prefix.common
250+
assert not any(map(str.isupper, prefix)), source
246251
query = ' UNION ALL '.join(
247252
f'''
248253
SELECT COUNT(*) AS count
249254
FROM {backtick(self._full_table_name(source, entity_type))}
250-
''' + ('' if source.prefix is None else f'''
251-
WHERE STARTS_WITH(JSON_EXTRACT_SCALAR(descriptor, "$.sha256"), {source.prefix.common!r})
252-
''')
255+
WHERE STARTS_WITH(LOWER(JSON_EXTRACT_SCALAR(descriptor, "$.sha256")),
256+
{prefix!r})
257+
'''
253258
for entity_type, entity_cls in api.entity_types.items()
254259
if entity_type.endswith('_file')
255260
)
@@ -278,11 +283,13 @@ def list_bundles(self,
278283
def list_files(self, source: TDRSourceRef, prefix: str) -> list[HCAFile]:
279284
self._assert_source(source)
280285
self._assert_partition(source, prefix)
286+
assert not any(map(str.isupper, prefix)), prefix
281287
rows = self._run_sql(' UNION ALL '.join(
282288
f'''
283289
SELECT {', '.join(TDRHCABundle.data_columns)}
284290
FROM {backtick(self._full_table_name(source.spec, entity_type))}
285-
WHERE STARTS_WITH(JSON_EXTRACT_SCALAR(descriptor, "$.sha256"), {prefix!r})
291+
WHERE STARTS_WITH(LOWER(JSON_EXTRACT_SCALAR(descriptor, "$.sha256")),
292+
{prefix!r})
286293
'''
287294
for entity_type, entity_cls in api.entity_types.items()
288295
if entity_type.endswith('_file')

0 commit comments

Comments
 (0)