
Commit

Fix gpt annot functions [skip ci]
bschilder committed Mar 11, 2024
1 parent 3a3d953 commit 02be201
Showing 7 changed files with 185 additions and 45 deletions.
69 changes: 54 additions & 15 deletions R/gpt_annot_check.R
@@ -6,6 +6,10 @@
#' generated by \link[HPOExplorer]{search_hpo}.
#' These will be used as the ground truth when trying to identify
#' true positive annotations.
#' @param response_map A named list mapping raw responses onto standardised values.
#' @param pos_values Responses treated as positive.
#' @param neg_values Responses treated as negative.
#' @param verbose Print messages.
#' @returns Named list
#'
@@ -16,12 +20,15 @@
#' checks <- gpt_annot_check()
gpt_annot_check <- function(annot = gpt_annot_read(),
query_hits = search_hpo(),
response_map = list(
no="never",
yes="always"
),
pos_values=c("sometimes","often","always"),
neg_values = c("never","rarely"),
verbose = TRUE
){

# path="~/Downloads/gpt_hpo_annotations.csv"
pheno_count <- NULL;

#### Proportion of HPO_IDs annotated before/after chatGPT ####
# hpo <- get_hpo()
# prior_ids <- unique(HPOExplorer::hpo_modifiers$hpo_id)
@@ -30,15 +37,33 @@ gpt_annot_check <- function(annot = gpt_annot_read(),
# length(prior_ids)/length(hpo@terms)
# length(new_ids)/length(hpo@terms)
#### Check annotation consistency ####
nm <- names(annot)[!grepl("hpo_name|justification|hpo_id",names(annot),
nm <- names(annot)[!grepl("hpo_name|justification|hpo_id|hpo_name|pheno_count",names(annot),
ignore.case = TRUE)]
counts <- table(tolower(unlist(annot[,nm,with=FALSE])), useNA = "always")
neg_values <- c("never","no")
opts <- unlist(sapply(annot[,nm,with=FALSE], unique)) |> unique()
response_counts <- table(tolower(unlist(annot[,nm,with=FALSE])), useNA = "always")
# opts <- unlist(sapply(annot[,nm,with=FALSE], unique)) |> unique()

#### Standardise responses ####
for(n in nm){
annot[,(n):=tolower(get(n))]
annot[,(n):=ifelse(get(n) %in% names(response_map),response_map[[get(n)]], get(n)), by=.I]
annot[,(n):=ifelse(get(n) %in% c(pos_values,neg_values),get(n),NA), by=.I]
}
response_counts_standard <- table(tolower(unlist(annot[,nm,with=FALSE])), useNA = "always")

#### Compute number of non-negative answers within each column####
annot_mean <- annot[pheno_count>1][,lapply(.SD,function(x){
mean(!tolower(x) %in% neg_values)
}),.SDcols=nm,by="hpo_name"]
if(nrow(annot[pheno_count>1])==0){
messager("No duplicate phenotypes to check consistency for.")
annot_mean <- NULL
}else {
#### Check for relaxed consistency (only distinguish between positive and negative) ####
annot_mean <- annot[pheno_count>1][,lapply(.SD,function(x){
mean(!na.omit(x) %in% neg_values)
}),.SDcols=nm,by="hpo_name"]
#### Check for stringent consistency ####
annot_stringent_mean <- annot[pheno_count>1][,lapply(.SD,function(x){
data.table::uniqueN(x)==1
}),.SDcols=nm,by="hpo_name"]
}
#### Check ontology classifications #####
annot_check <- lapply(seq(nrow(annot)), function(i){
r <- annot[i,]
@@ -48,7 +73,7 @@
names(query_hits)),
function(x){
if(r$hpo_id %in% query_hits[[x]]){
!tolower(r[,x,with=FALSE][[1]]) %in% neg_values
tolower(r[,x,with=FALSE][[1]]) %in% pos_values
} else {
NA
}
@@ -57,8 +82,16 @@
}) |> data.table::rbindlist()

#### Compute consistency within each column ####
annot_consist <- sapply(annot_mean[,-1],
function(x)sum(x%in%c(0,1)/nrow(annot_mean)))
if(!is.null(annot_mean)){
consistency_count <- sapply(annot_mean[,-c("hpo_name")],
function(x)nrow(annot_mean))
consistency_rate <- sapply(annot_mean[,-c("hpo_name")],
function(x)sum(x%in%c(0,1))/nrow(annot_mean))
consistency_stringent_rate <- sapply(annot_stringent_mean[,-c("hpo_name")],
function(x)sum(x)/nrow(annot_mean))
} else {
consistency_count <- consistency_rate <- consistency_stringent_rate <- NULL
}
### Proportion of rows where annotation is not NA
checkable_rate <- sapply(
annot_check[,names(query_hits),with=FALSE],
@@ -78,12 +111,18 @@ gpt_annot_check <- function(annot = gpt_annot_read(),
checks <- list(
annot=annot,
annot_mean=annot_mean,
annot_consist=annot_consist,
consistency_count=consistency_count,
consistency_rate=consistency_rate,
consistency_stringent_count=consistency_count,
consistency_stringent_rate=consistency_stringent_rate,
annot_check=annot_check,
checkable_rate=checkable_rate,
checkable_count=checkable_count,
true_pos_count=checkable_count,
true_pos_rate=true_pos_rate,
false_neg_rate=false_neg_rate
false_neg_rate=false_neg_rate,
response_counts=response_counts,
response_counts_standard=response_counts_standard
)
#### Plot ####
checks[["plot"]] <- gpt_annot_check_plot(checks=checks)
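For reference, here is a minimal, self-contained sketch of the response-standardisation and relaxed-consistency logic added to gpt_annot_check() above. The toy annotation table, its `congenital_onset` column, and its values are invented for illustration; only `response_map`, `pos_values`, and `neg_values` mirror the new function arguments, and the row-wise `by=.I` mapping in the commit is replaced here by an equivalent vectorised helper.

```r
library(data.table)

## Toy annotations: each phenotype appears twice, mimicking pheno_count > 1
annot <- data.table(
  hpo_name         = c("Seizure", "Seizure", "Scoliosis", "Scoliosis"),
  pheno_count      = 2L,
  congenital_onset = c("No", "never", "sometimes", "often")  # hypothetical column
)
response_map <- list(no = "never", yes = "always")
pos_values   <- c("sometimes", "often", "always")
neg_values   <- c("never", "rarely")
nm <- "congenital_onset"

## Standardise responses: lowercase, map synonyms, set unrecognised values to NA
standardise <- function(v){
  v   <- tolower(v)
  hit <- v %in% names(response_map)
  v[hit] <- unlist(response_map[v[hit]])
  v[!v %in% c(pos_values, neg_values)] <- NA_character_
  v
}
annot[, (nm) := lapply(.SD, standardise), .SDcols = nm]

## Relaxed consistency: per duplicated phenotype, share of non-negative answers
annot_mean <- annot[pheno_count > 1,
                    lapply(.SD, function(x) mean(!na.omit(x) %in% neg_values)),
                    .SDcols = nm, by = "hpo_name"]

## A column is consistent for a phenotype when that share is exactly 0 or 1
consistency_rate <- sapply(
  annot_mean[, setdiff(names(annot_mean), "hpo_name"), with = FALSE],
  function(x) sum(x %in% c(0, 1)) / nrow(annot_mean)
)
consistency_rate  # congenital_onset = 1: both toy phenotypes were answered consistently
```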
118 changes: 101 additions & 17 deletions R/gpt_annot_check_plot.R
@@ -1,29 +1,113 @@
gpt_annot_check_plot <- function(checks,
items = c("annot_consist",
"checkable_count",
#"checkable_rate",
"true_pos_rate")){
items = c("consistency_count","consistency_rate",
"consistency_stringent_count","consistency_stringent_rate",
# "checkable_count","checkable_rate",
"true_pos_count","true_pos_rate"),
metric_types=c("Rate","Count")[1],
scales = "free"){
requireNamespace("ggplot2")

annotation <- value <- metric <- NULL;
annotation <- metric <- metric_category <- metric_type <- NULL;
check_df <- lapply(checks[items],
data.table::as.data.table, keep.rownames = TRUE) |>
data.table::rbindlist(idcol = "metric") |>
data.table::setnames(c("metric","annotation","value"))
check_df$metric <- factor(check_df$metric, levels = items, ordered = TRUE)

ggplot2::ggplot(check_df[annotation!="pheno_count"],
ggplot2::aes(x=annotation, y=value, fill=metric)) +
ggplot2::geom_bar(stat="identity") +
ggplot2::facet_grid(rows="metric",
scales = "free_y",) +
check_df$metric <- factor(check_df$metric,
levels = items,
ordered = TRUE)
check_df$annotation <- factor(check_df$annotation,
levels = unique(check_df$annotation),
ordered = TRUE)
# check_df[,row_count:=.N[!is.na(value)], by=c("metric","annotation")]
check_df[,metric_type:=ifelse(grepl("count",metric),
"Count",ifelse(grepl("rate",metric),"Rate",NA))]
check_df[,metric_category:=gsub("_count|_rate","",metric)]
check_df$metric_type <- factor(check_df$metric_type, ordered = TRUE)
check_df <- check_df[annotation!="pheno_count"]
combos <- expand.grid(
metric_category=check_df$metric_category,
metric_type=check_df$metric_type,
annotation=check_df$annotation
) |> unique() |> `rownames<-`(NULL)
check_df2 <- merge(check_df,
combos,
by=c("metric_category","metric_type","annotation"),
all.y=TRUE)
# check_df[,label:=paste0("n=",value[metric_type=="Count"]),
# by=c("metric_category","annotation")]
check_df[,n:=value[metric_type=="Count"],
by=c("metric_category","annotation")]

plt <- ggplot2::ggplot(check_df[metric_type %in% metric_types],
ggplot2::aes(x=annotation, y=value,
fill=n,
label=round(value,2))) +
ggplot2::geom_bar(stat="identity", show.legend = TRUE) +
# ggplot2::scale_fill_viridis_d(drop = FALSE, end = .8 )+
ggplot2::scale_fill_viridis_c()+
ggplot2::scale_x_discrete(drop = scales %in% c("free","free_x")) +
ggplot2::facet_grid(facets=metric_type~metric_category,
scales = scales) +
ggplot2::theme_bw() +
ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1),
strip.background = ggplot2::element_rect(fill = "grey20"),
strip.text = ggplot2::element_text(color = "white")) +
ggplot2::labs(title = "GPT annotation validation",
x = "Annotation column",
y = "Value",
fill = "Metric")
strip.background = ggplot2::element_rect(fill = "grey20"),
strip.text = ggplot2::element_text(color = "white")) +
ggplot2::labs(x = "Annotation",
y = "Value",
fill = "N phenotypes") +
ggplot2::geom_label(fill='white')
# ggplot2::geom_text(angle=90, size=3, hjust=1.2)
# plt

if(scales %in% c("free","free_x")){
plt <- plt+
ggplot2::geom_vline(data = check_df2[is.na(value)],
ggplot2::aes(xintercept=annotation),
# label="N/A",
color="red", linetype="solid", linewidth=3, alpha=0.25) +
ggplot2::geom_text(data = check_df2[is.na(value)][,label:="N/A"],
ggplot2::aes(x=annotation, label=label, y=2),
position=ggplot2::position_fill(vjust = 0.1),
color="red", alpha=0.5, angle=90)
}
# count_df <- check_df[metric_type=="Count",,drop=FALSE]
# rate_df <- check_df[metric_type=="Rate",,drop=FALSE][annotation!="pheno_count"]
# plts <- list()
# make_plot <- function(df,
# y_lab=NULL){
# p <- ggplot2::ggplot(df,
# ggplot2::aes(x=annotation, y=value,
# fill=metric_type)) +
# ggplot2::geom_bar(stat="identity", show.legend = FALSE) +
# ggplot2::scale_fill_viridis_d(drop=FALSE, end = .8 )+
# ggplot2::scale_x_discrete(drop = FALSE) +
# ggplot2::facet_grid(rows="metric") +
# ggplot2::theme_bw() +
# ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1),
# strip.background = ggplot2::element_rect(fill = "grey20"),
# strip.text = ggplot2::element_text(color = "white")) +
# ggplot2::labs(x = "Annotation",
# y = y_lab,
# fill = "Metric")
# }
# if(nrow(count_df)>0){
# plts[["count"]] <- make_plot(count_df,y_lab = "Count")
# }
# if(nrow(rate_df)>0){
# plts[["rate"]] <- make_plot(rate_df,y_lab = "Rate")
# }
# if(length(plts)>1){
# plts[[1]] <- plts[[1]] +
# ggplot2::theme(axis.text.x = ggplot2::element_blank())
# }
#
# plt <- patchwork::wrap_plots(plts, ncol = 1,
# guides = "collect") +
# patchwork::plot_layout(axis_titles = "collect",
# axes = "collect")
return(list(
data=check_df,
plot=plt
))
}
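A rough sketch of how the revised gpt_annot_check_plot() reshapes the named check vectors into long format and facets rates against their matching counts. The metric values below are invented and the plotting call is trimmed (no factor ordering or N/A overlays); it is not the full function.

```r
library(data.table)
library(ggplot2)

## Invented check metrics: one named vector per metric, keyed by annotation column
checks <- list(
  consistency_rate  = c(congenital_onset = 0.90, blindness = 1.00),
  consistency_count = c(congenital_onset = 42,   blindness = 42),
  true_pos_rate     = c(congenital_onset = 0.80, blindness = 0.70),
  true_pos_count    = c(congenital_onset = 35,   blindness = 30)
)
items <- names(checks)

## Long format: one row per (metric, annotation) pair
check_df <- lapply(checks[items], as.data.table, keep.rownames = TRUE) |>
  rbindlist(idcol = "metric") |>
  setnames(c("metric", "annotation", "value"))

## Split each metric name into a category and a type ("Rate" vs "Count")
check_df[, metric_type     := ifelse(grepl("count", metric), "Count", "Rate")]
check_df[, metric_category := gsub("_count|_rate", "", metric)]
## Carry the matching count alongside each rate so it can drive the fill scale
check_df[, n := value[metric_type == "Count"], by = c("metric_category", "annotation")]

ggplot(check_df[metric_type == "Rate"],
       aes(x = annotation, y = value, fill = n, label = round(value, 2))) +
  geom_bar(stat = "identity") +
  scale_fill_viridis_c() +
  facet_grid(metric_type ~ metric_category, scales = "free") +
  geom_label(fill = "white") +
  theme_bw()
```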
14 changes: 7 additions & 7 deletions R/gpt_annot_codify.R
@@ -40,9 +40,7 @@ gpt_annot_codify <- function(annot = gpt_annot_read(),
reset_tiers_dict=FALSE,
filters=list()
){
# res <- gpt_annot_check(path="~/Downloads/gpt_hpo_annotations.csv")
# annot <- res$annot
severity_score_gpt <- congenital_onset <- hpo_name <- hpo_id <- NULL;
severity_score_gpt <- hpo_name <- NULL;

d <- data.table::copy(annot)
if(isTRUE(reset_tiers_dict)) tiers_dict <- lapply(tiers_dict,function(x){1})
Expand Down Expand Up @@ -70,12 +68,14 @@ gpt_annot_codify <- function(annot = gpt_annot_read(),
unlist(code_dict[tolower(x)])}),.SDcols = cols, by=c("hpo_id","hpo_name")]
d_weighted <- data.table::as.data.table(
lapply(stats::setNames(cols,cols),
function(co){d_coded[[co]]*
((max(unlist(tiers_dict))+1)-tiers_dict[[co]]) })
)[,hpo_id:=d_coded$hpo_id][,severity_score_gpt:=(
function(co){
d_coded[[co]]*
((max(unlist(tiers_dict))+1)-tiers_dict[[co]])
})
)[,hpo_name:=d_coded$hpo_name][,severity_score_gpt:=(
rowSums(.SD,na.rm = TRUE)/max_score*100),
.SDcols=cols][d_coded[,-cols,with=FALSE],
on="hpo_id"] |>
on="hpo_name"] |>
data.table::setorderv("severity_score_gpt",-1, na.last = TRUE) |>
unique()
#### Order hpo_names by severity_score_gpt #####
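The hunk above switches the weighted-score join in gpt_annot_codify() from `hpo_id` to `hpo_name`. As a reminder of what that score is, here is a hedged sketch of the tier-weighting arithmetic; the `code_dict`/`tiers_dict` values, annotation column names, and the `max_score` definition are illustrative assumptions rather than the package's actual dictionaries.

```r
library(data.table)

## Illustrative dictionaries: response codes and importance tiers (lower tier = weightier)
code_dict  <- c(never = 0, rarely = 1, sometimes = 2, often = 3, always = 4)
tiers_dict <- list(death = 1, impaired_mobility = 2, blindness = 3)
cols <- names(tiers_dict)

## Toy coded annotations keyed by hpo_name, matching the updated join key
d_coded <- data.table(
  hpo_name          = c("Phenotype A", "Phenotype B"),
  death             = c(4, 0),
  impaired_mobility = c(2, 1),
  blindness         = c(0, 4)
)

## Weight each column by (max tier + 1) - tier, matching the expression in the diff
weights <- (max(unlist(tiers_dict)) + 1) - unlist(tiers_dict)
d_weighted <- data.table::copy(d_coded)
for (co in cols) d_weighted[, (co) := get(co) * weights[[co]]]

## Severity score as a percentage of the maximum attainable weighted sum (assumed definition)
max_score <- sum(max(code_dict) * weights)
d_weighted[, severity_score_gpt := rowSums(.SD, na.rm = TRUE) / max_score * 100,
           .SDcols = cols]
d_weighted[order(-severity_score_gpt)]
```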
10 changes: 6 additions & 4 deletions R/gpt_annot_read.R
Expand Up @@ -3,6 +3,7 @@
#' Read in phenotype annotations generated by GPT and
#' do some initial preprocessing (e.g. adding HPO IDs).
#' @inheritParams main
#' @inheritParams make_
#' @param save_path Path to annotations CSV file.
#' If the file does not exist, the data will be downloaded from GitHub.
#' @param force_new If \code{TRUE}, the data will be downloaded from GitHub
@@ -13,11 +14,13 @@
#'
#' @export
#' @examples
#' annot <- gpt_annot_read()
#' gpt_annot <- gpt_annot_read()
gpt_annot_read <- function(save_path=file.path(
KGExplorer::cache_dir(package = "HPOExplorer"),
"gpt4_hpo_annotations.csv"
),

phenotype_to_genes = load_phenotype_to_genes(),
force_new=FALSE,
hpo=get_hpo(),
include_nogenes=TRUE,
@@ -37,9 +40,8 @@ include_nogenes=TRUE,
data.table::setnames(d,"phenotype","hpo_name")
d <- add_hpo_id(d, hpo = hpo)
#### Check phenotype names ####
annot <- load_phenotype_to_genes(verbose = verbose)
d <- merge(d,
unique(annot[,c("hpo_id","hpo_name")]),
unique(phenotype_to_genes[,c("hpo_id","hpo_name")]),
all.x = TRUE,
by=c("hpo_name","hpo_id"))
d <- data.frame(d)
@@ -57,7 +59,7 @@ include_nogenes=TRUE,
"phenotypes.",
v=verbose)
# phenotype_miss_rate <-
# length(d$phenotype[!d$phenotype %in% annot$hpo_name]) /
# length(d$phenotype[!d$phenotype %in% phenotype_to_genes$hpo_name]) /
# length(d$phenotype)
return(d)
}
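A small sketch of the revised lookup step in gpt_annot_read(), where `phenotype_to_genes` is now passed in as an argument instead of being loaded inside the function. All table contents below are toy values.

```r
library(data.table)

## Toy GPT annotations and a toy phenotype_to_genes lookup
d <- data.table(
  hpo_name = c("Seizure", "Scoliosis", "Not a real phenotype"),
  hpo_id   = c("HP:0001250", "HP:0002650", NA)
)
phenotype_to_genes <- data.table(
  hpo_id      = c("HP:0001250", "HP:0001250", "HP:0002650"),
  hpo_name    = c("Seizure", "Seizure", "Scoliosis"),
  gene_symbol = c("SCN1A", "SCN1B", "FBN1")
)

## Left join on both keys keeps every GPT row, even unrecognised phenotypes
d <- merge(d,
           unique(phenotype_to_genes[, c("hpo_id", "hpo_name")]),
           all.x = TRUE,
           by = c("hpo_name", "hpo_id"))

## Rough phenotype-name miss rate, as hinted at in the commented-out code
mean(!d$hpo_name %in% phenotype_to_genes$hpo_name)
```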
2 changes: 1 addition & 1 deletion R/plot_evidence.R
@@ -70,7 +70,7 @@ plot_evidence <- function(metric="evidence_score_sum",
evidence_score_sum_min=min(get(metric),na.rm=TRUE),
evidence_score_sum_max=max(get(metric),na.rm=TRUE),
evidence_score_sum_mean=mean(get(metric),na.rm=TRUE),
evidence_score_sum_sd=sd(get(metric),na.rm=TRUE)),
evidence_score_sum_sd=stats::sd(get(metric),na.rm=TRUE)),
by=c("hpo_id","gene_symbol")]|>
data.table::setorderv(metric_sd, -1,na.last = TRUE)
h5 <- plot_hist(gcc_phenotype_agg_sd[!is.na(get(metric_sd))],
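The only change in plot_evidence.R swaps `sd()` for the namespaced `stats::sd()`. A quick hedged sketch of the grouped summary it sits in, with invented evidence scores:

```r
library(data.table)

## Invented evidence scores, summarised per phenotype-gene pair
dt <- data.table(
  hpo_id             = c("HP:0001250", "HP:0001250", "HP:0002650"),
  gene_symbol        = c("SCN1A", "SCN1A", "FBN1"),
  evidence_score_sum = c(3, 5, 4)
)
agg <- dt[, .(evidence_score_sum_min  = min(evidence_score_sum,  na.rm = TRUE),
              evidence_score_sum_max  = max(evidence_score_sum,  na.rm = TRUE),
              evidence_score_sum_mean = mean(evidence_score_sum, na.rm = TRUE),
              ## Explicit namespacing avoids an R CMD check NOTE in package code
              evidence_score_sum_sd   = stats::sd(evidence_score_sum, na.rm = TRUE)),
          by = c("hpo_id", "gene_symbol")]
agg
```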
10 changes: 10 additions & 0 deletions man/gpt_annot_check.Rd


7 changes: 6 additions & 1 deletion man/gpt_annot_read.Rd

