diff --git a/genomic_address_service/classes/assign.py b/genomic_address_service/classes/assign.py index 4657d43..236e56c 100644 --- a/genomic_address_service/classes/assign.py +++ b/genomic_address_service/classes/assign.py @@ -28,7 +28,6 @@ def __init__(self,dist_file,membership_file,threshold_map,linkage_method,address self.assignments = {} self.nomenclature_cluster_tracker = {} self.query_ids = set() - if not linkage_method in self.avail_methods: self.status = False self.error_msgs.append(f'Provided {linkage_method} is not one of the accepted {self.avail_methods}') @@ -131,6 +130,15 @@ def process_memberships(self): lookup[code].append(id) self.memberships_lookup = lookup + def add_memberships_lookup(self,sample_id, address): + self.memberships_dict[sample_id] = ".".join([str(x) for x in address]) + for idx in range(0,len(address)): + code = ".".join([str(x) for x in address[0:idx+1]]) + if not code in self.memberships_lookup: + self.memberships_lookup[code] = list() + self.memberships_lookup[code].append(sample_id) + + def get_dist_summary(self,dists): min_dist = min(dists) ave_dist = mean(dists) @@ -182,46 +190,45 @@ def assign(self, n_records=1000,delim="\t"): query_addr = [None] * num_ranks if qid in self.memberships_dict: continue - for rid in dists[qid]: if rid == qid or rid not in self.memberships_dict: continue pairwise_dist = dists[qid][rid] thresh_idx = self.get_threshold_idx(pairwise_dist) thresh_value = self.thresholds[thresh_idx] - #save unnecessary work if thresh_value >= pairwise_dist: ref_address = self.memberships_dict[rid].split('.')[0:thresh_idx+1] alen = len(ref_address) for i in range(0,len(ref_address)): addr = ".".join(ref_address[0:alen-i]) + if addr not in self.memberships_lookup: continue addr_members = self.memberships_lookup[addr] addr_dists = [] for id in addr_members: - addr_dists.append(dists[qid][id]) + if id in dists[qid]: + addr_dists.append(dists[qid][id]) if len(addr_dists) == 0: continue summary = self.get_dist_summary(addr_dists) - is_eligible = True if self.linkage_method == 'complete' and summary['max'] > thresh_value: is_eligible = False elif self.linkage_method == 'average' and summary['mean'] > thresh_value: is_eligible = False - if is_eligible: for idx,value in enumerate(addr.split('.')): query_addr[idx] = value break - + thresh_value = self.thresholds[thresh_idx-(i+1)] + for idx,value in enumerate(query_addr): if value is None: query_addr[idx] = self.nomenclature_cluster_tracker[rank_ids[idx]] self.nomenclature_cluster_tracker[rank_ids[idx]]+=1 - break - self.memberships_dict[qid] = ".".join([str(x) for x in query_addr]) \ No newline at end of file + self.add_memberships_lookup(qid, query_addr) + #self.memberships_dict[qid] = ".".join([str(x) for x in query_addr]) \ No newline at end of file diff --git a/genomic_address_service/classes/reader.py b/genomic_address_service/classes/reader.py index 80c0e1f..e633d38 100644 --- a/genomic_address_service/classes/reader.py +++ b/genomic_address_service/classes/reader.py @@ -30,6 +30,8 @@ def guess_file_type(self,f): return file_type def guess_dist_type(self, fpath, ftype, delim="\t"): + header = [] + num_rows = 0 if ftype == 'text': header = get_file_header(fpath).split(delim) num_rows = get_file_length(fpath) @@ -53,7 +55,9 @@ def read_pd(self): continue qid = line[0] rid = line[1] + d = float(line[2]) + #print(f'{qid} {rid} {d}') if qid not in self.record_ids and len(self.dists) >= self.n_records: self.sort_distances() yield self.dists @@ -62,15 +66,10 @@ def read_pd(self): if qid not in self.record_ids: self.record_ids.add(qid) self.dists[qid] = {} - if self.filter: - if self.min_dist is not None: - if d < self.min_dist: - continue - if self.max_dist is not None: - if d > self.max_dist: - continue self.dists[qid][rid] = d self.sort_distances() + + yield self.dists def sort_distances(self): @@ -95,13 +94,6 @@ def read_matrix(self): for i in range(0,len(values)): rid = self.header[i] d = values[i] - if self.filter: - if self.min_dist is not None: - if d < self.min_dist: - continue - if self.max_dist is not None: - if d > self.max_dist: - continue self.dists[qid][rid] = d self.sort_distances() @@ -115,6 +107,7 @@ def read_data(self): self.header = next(self.file_handle).split(self.delim) elif ftype == 'parquet': self.file_handle = ParquetFile(self.fpath) + if ftype == 'text' and dist_type == 'pd': for chunk in self.read_pd(): if chunk is not None: @@ -127,7 +120,8 @@ def read_data(self): yield chunk if chunk is None: chunk = self.dists - yield chunk + + return chunk