# Skip to contents
library(BSA)
library(here)
library(tidyverse)
library(foreach)
library(GenomicRanges)
library(BSgenome.CneoformansVarGrubiiKN99.NCBI.ASM221672v1)
library(SeqArray)
library(SeqVarTools)
library(Rsamtools)

# ---- Inputs ----

# Genome object exported by the BSgenome data package attached above.
kn99_genome <- BSgenome.CneoformansVarGrubiiKN99.NCBI.ASM221672v1

# Every file in the tables directory, as a named list keyed by the file name
# with its "_freebayes_filtered.tsv" suffix removed.
input_vcf_tables <- list.files(
  "/mnt/scratch/genotype_check/run_21_wgs_results/variants/filtered/tables",
  full.names = TRUE)
# The suffix is escaped and anchored: as a bare regex the "." would match any
# character and the pattern could match in the middle of a file name.
names(input_vcf_tables) <- str_remove(basename(input_vcf_tables),
                                      "_freebayes_filtered\\.tsv$")
input_vcf_tables <- as.list(input_vcf_tables)

# Multi-sample VCF passed to vcf_to_qtlseqr_table() further down.
bsa_vcf <- "/mnt/scratch/bsa/parent_strain_results/variants/filtered/1_freebayes_filtered.vcf.gz"

# Project-relative path of the GDS file (not used in the code visible here).
bsa_gds <- here("data/parent_strains.gds")

# Sample sheet; clean_names() normalises the headers to snake_case.
meta_df <- read_csv('/home/oguzkhan/Downloads/Doering_lab_WGS - 2024_0208 WGS(1).csv') %>%
  janitor::clean_names()

# Split the free-text `coordinates_if_applicable` column into one column per
# marker coordinate. The pipeline assumes a comma-separated layout:
#   - 'G418' or 'NAT' tubes:  "start,end"
#   - 'G418 & NAT' tubes:     "start1,end1,start2,end2,region"
# 'N.A.' is treated as "no coordinates". Every column of the result is
# character because str_trim() is applied to all of them at the end.
coordinate_meta_df <- meta_df %>%
  # Single-marker tubes: keep the coordinate string, blank it for all others
  mutate(marker1 = ifelse(marker %in% c('G418', 'NAT'), coordinates_if_applicable, '')) %>%
  mutate(marker1 = ifelse(marker1 == 'N.A.', '', marker1)) %>%
  separate(marker1, c('marker1_start', 'marker1_end'), sep = ',', remove = FALSE) %>%
  select(-marker1) %>%
  replace_na(list(marker1_start = '', marker1_end = '')) %>%
  # Dual-marker tubes: same treatment, but the string carries five fields
  mutate(marker2 = ifelse(marker %in% c('G418 & NAT'), coordinates_if_applicable, '')) %>%
  mutate(marker2 = ifelse(marker2 == 'N.A.', '', marker2)) %>%
  # extra = "merge" keeps any further commas inside `region`;
  # fill = "right" pads short strings with NA
  separate(marker2, into = c("marker1_start_tmp", "marker1_end_tmp", "marker2_start", "marker2_end", "region"), sep = ",", extra = "merge", fill = "right") %>%
  # For dual-marker tubes the first marker comes from the *_tmp columns
  mutate(marker1_start = ifelse(marker1_start == '', marker1_start_tmp, marker1_start),
         marker1_end = ifelse(marker1_end == '', marker1_end_tmp, marker1_end)) %>%
  select(-c(marker1_start_tmp, marker1_end_tmp)) %>%
  select(tube_number, marker, marker1_start, marker1_end, marker2_start, marker2_end, region) %>%
  # trim whitespace on all columns (across() replaces the superseded mutate_all())
  mutate(across(everything(), str_trim))
  
# Measure how much of each read's alignment lies left and right of a
# reference coordinate.
#
# Each CIGAR string is laid out along the reference starting at the read's
# leftmost position. Only reference-consuming operations (M, D, N, =, X) are
# counted; the base at `coordinate` itself belongs to neither side.
#
# positions:  leftmost reference position of each read.
# cigars:     CIGAR strings, parallel to `positions`.
# coordinate: the reference position to split the alignments at.
#
# Returns the per-read named vectors combined with bind_rows(): one row per
# read with columns `length_aln_left` and `length_aln_right`.
parse_cigar_for_coordinate <- function(positions, cigars, coordinate) {
  ref_consuming <- c("M", "D", "N", "=", "X")

  results <- lapply(seq_along(cigars), function(i) {
    cigar <- cigars[i]

    # Split the CIGAR string into parallel vectors of lengths and operations.
    # "=" is listed explicitly: a plain [A-Z] class skips it, which would
    # leave the two vectors out of step.
    lengths <- as.integer(regmatches(cigar, gregexpr("[0-9]+", cigar))[[1]])
    operations <- regmatches(cigar, gregexpr("[MIDNSHP=X]", cigar))[[1]]

    # Reference interval [seg_start, seg_end] covered by each counted operation
    seg_len <- lengths[operations %in% ref_consuming]
    seg_end <- positions[i] + cumsum(seg_len) - 1
    seg_start <- seg_end - seg_len + 1

    # Overlap of every interval with (-Inf, coordinate - 1] and with
    # [coordinate + 1, Inf). Summing over all intervals counts the operations
    # after the one containing the coordinate as well, so both sides are
    # measured the same way.
    left <- pmax(0, pmin(seg_end, coordinate - 1) - seg_start + 1)
    right <- pmax(0, seg_end - pmax(seg_start, coordinate + 1) + 1)

    c(length_aln_left = sum(left), length_aln_right = sum(right))
  })

  bind_rows(results)
}

# Extract the reads overlapping one reference position and keep those whose
# alignment score reaches a threshold.
#
# sample:              label written to the `sample` column of the result.
# bam_file:            path of the BAM file to query.
# region:              "chrom:pos" string naming a single reference position.
# min_alignment_score: minimum value of the AS tag a read needs to be kept.
#                      Reads without an AS tag are kept, because they cannot
#                      be compared against the threshold.
#
# Returns a tibble with one row per retained read: sample, chr, pos, qname,
# cigar, align_qual (AS tag), edit_dist (NM tag), plus the columns added by
# parse_cigar_for_coordinate().
extract_reads_by_alignment_score <- function(sample, bam_file, region, min_alignment_score) {
  # Parse the region
  chrom <- gsub(":.*", "", region)
  pos <- as.numeric(gsub(".*:", "", region))

  # Only the fields used below are requested from the BAM file
  param <- ScanBamParam(
    which = GRanges(chrom, IRanges(start = pos, end = pos)),
    what = c("qname", "pos", "cigar"),
    tag = c("AS", "NM")
  )

  # Read the BAM file
  aln <- scanBam(bam_file, param = param)[[1]]

  # A tag that comes back as NULL would make tibble() drop its column;
  # substitute NA so the result always has the same columns.
  n_reads <- length(aln$qname)
  tag_or_na <- function(values) {
    if (is.null(values)) rep(NA_integer_, n_reads) else values
  }

  match_length_left_right <- parse_cigar_for_coordinate(aln$pos, aln$cigar, pos)

  tibble(sample = sample,
         chr = chrom,
         pos = pos,
         qname = aln$qname,
         cigar = aln$cigar,
         align_qual = tag_or_na(aln$tag$AS),
         edit_dist = tag_or_na(aln$tag$NM)) %>%
    bind_cols(match_length_left_right) %>%
    # Apply the threshold the function is named after
    dplyr::filter(is.na(align_qual) | align_qual >= min_alignment_score)
}

# Example usage: reads covering one position of one marker construct
bam_file <- "/mnt/scratch/genotype_check/run_21_wgs_result_markers/alignment/TH0003_sorted_markdups_tagged.bam"
region <- "TH_C8_NAT_IR6:301"
min_alignment_score <- 0  # Example threshold
reads_df <- extract_reads_by_alignment_score('sample1', bam_file, region, min_alignment_score)

# BAM files of the marker alignment run. `pattern` is a regular expression,
# not a glob, so the extension is written as an escaped, anchored regex.
bam_list <- list.files("/mnt/scratch/genotype_check/run_21_wgs_result_markers/alignment",
                       pattern = "\\.bam$", full.names = TRUE)
# The tube number is the file name minus its fixed suffix (escaped and
# anchored for the same reason).
bam_df <- tibble(bam = bam_list,
                 tube_number = str_remove(basename(bam_list),
                                          "_sorted_markdups_tagged\\.bam$"))

# One row per (tube, BAM, coordinate) that has both a BAM file and a
# non-empty coordinate string.
marker_check_df <- coordinate_meta_df %>%
  left_join(bam_df, by = "tube_number") %>%
  select(tube_number, bam, marker1_start, marker1_end, marker2_start, marker2_end) %>%
  pivot_longer(-c(tube_number, bam)) %>%
  drop_na() %>%
  filter(value != '') %>%
  select(-name) %>%
  dplyr::rename(coordinate = value)

# Row-wise adapter for pmap(): announces the tube being processed, then
# extracts the reads over `coordinate` from `bam` with a score threshold of 0.
# Extra columns of the row are accepted through `...` and ignored.
get_reads_over_region_per_sample <- function(tube_number, bam, coordinate, ...) {
  progress_note <- sprintf('working on %s', tube_number)
  message(progress_note)

  extract_reads_by_alignment_score(
    sample = tube_number,
    bam_file = bam,
    region = coordinate,
    min_alignment_score = 0
  )
}

# Run the read extraction once per row of marker_check_df; pmap() matches the
# columns to the function's arguments by name.
res <- marker_check_df %>%
  pmap(get_reads_over_region_per_sample)

# All reads in one table, grouped by sample
res_out <- group_by(bind_rows(res), sample)

res_out_names <- group_keys(res_out)[["sample"]]

res_out_list <- group_split(res_out)

# map2(res_out_list, res_out_names,
#      ~write_csv(.x,
#                 file.path("/mnt/scratch/genotype_check/run_21_wgs_results/marker_check",
#                           paste0(.y,'.csv'))))

# Read count per sample and position, shown only for samples that have
# exactly 1 or 3 positions with reads (n() here is the per-sample row count).
res_out %>%
  ungroup() %>%
  dplyr::count(sample, chr, pos) %>%
  group_by(sample) %>%
  filter(n() %in% c(1,3))

# Read one per-sample variant table and normalise its column names.
#
# path: path of a tab-separated variant table.
#
# Steps: the un-prefixed `DP` column becomes `depth`; everything up to and
# including the first "." is then stripped from every column name, and the
# resulting `DP`, `AD` and `GT` columns are dropped; `AO`/`RO` become
# `alt_depth`/`ref_depth` and `CHROM` becomes `CHR`.
# NOTE(review): the stripped prefix is presumably the sample name — confirm
# against the table headers.
read_in_vcf_tables <- function(path) {
  read_tsv(path) %>%
    dplyr::rename(depth = DP) %>%
    # `.cols` is the real argument name; the misspelt `cols` was swallowed by
    # `...` and only worked because everything() is the default.
    dplyr::rename_with(.fn = ~ gsub("^(.*?\\.)", "", .x), .cols = everything()) %>%
    select(-c(DP, AD, GT)) %>%
    dplyr::rename(alt_depth = AO,
                  ref_depth = RO,
                  CHR = CHROM)
}

# Stack the per-sample variant tables; the list names become the `sample` column
input_vcf_tables_df <- input_vcf_tables %>%
  map(read_in_vcf_tables) %>%
  bind_rows(.id = "sample")

# Convert the parent-strain VCF into a long table, using the session temp
# directory as the working directory handed to vcf_to_qtlseqr_table()
tmp_dir <- tempdir()
bsa_df <- vcf_to_qtlseqr_table(
  bsa_vcf,
  tmp_dir,
  parent_ref_sample = "KN99a",
  parent_alt_sample = "C8",
  parent_filter = FALSE,
  overwrite = TRUE
)

# Keep the two parent samples, keep the columns used downstream, and relabel
# the samples as 'KN99a' and 'C8'
parent_df <- bsa_df %>%
  filter(sample == 'KN99a2' | sample == 'C8Merged') %>%
  select(CHR, POS, ALT1, RealDepth, Alternative1, QUAL, sample,
         genotype, Alt1_percentage, Reference, Ref_percentage) %>%
  mutate(sample = if_else(sample == 'KN99a2', 'KN99a', 'C8'))

# Build the per-variant table of one parent strain.
#
# parent_df: long table of parent calls with a `sample` column.
# parent:    value of `sample` to keep ('KN99a' or 'C8').
# flag_col:  name of the all-TRUE marker column appended to the result.
#
# Rows with RealDepth below 10 are dropped and the result is sorted by
# decreasing Alt1_percentage.
# NOTE(review): the 0.9 cut-off assumes Alt1_percentage is a fraction in
# [0, 1] rather than a percentage — confirm against vcf_to_qtlseqr_table().
summarize_parent_variants <- function(parent_df, parent, flag_col) {
  parent_df %>%
    filter(sample == !!parent) %>%
    pivot_wider(names_from = sample,
                values_from = c(RealDepth, Alternative1,
                                Reference, QUAL, genotype,
                                Alt1_percentage, Ref_percentage)) %>%
    # pivot_wider() suffixes every value column with the sample name; strip it
    rename_with(.fn = function(nm) gsub(paste0("_", parent, "$"), "", nm),
                .cols = everything()) %>%
    dplyr::rename(parent_genotype = genotype) %>%
    mutate(parent_genotype_strict = ifelse(Alt1_percentage >= .9, 1, 0),
           confidently_reference =
             parent_genotype == 0 & parent_genotype_strict == 0) %>%
    filter(RealDepth >= 10) %>%
    # Qualified on purpose: IRanges, attached after the tidyverse through
    # GenomicRanges, exports its own `desc`
    arrange(dplyr::desc(Alt1_percentage)) %>%
    mutate("{flag_col}" := TRUE)
}

kn99a_df <- summarize_parent_variants(parent_df, "KN99a", "kn99a")

c8_df <- summarize_parent_variants(parent_df, "C8", "c8")

# Classify each variant of a sample by the parent strain(s) that carry it and
# summarise the classes.
#
# A variant is matched to a parent on (CHR, POS, ALT) against the rows of the
# global tables `kn99a_df` / `c8_df` with parent_genotype_strict == 1, and is
# labelled 'both', 'kn99a', 'C8' or 'denovo'.
#
# vcf_df:       variant table with columns CHR, POS and ALT.
# region_chr, region_start, region_end:
#               optional region; when `region_start` is given the summary is
#               restricted to variants with CHR == region_chr and
#               region_start <= POS <= region_end.
#
# Returns one row per var_origin with
#   genome_wide_percent, genome_wide_parent_total, total_variants
# or, for a region,
#   region_percent, region_parent_total, region_total_variants.
summarize_genomic_origin <- function(vcf_df, region_chr = NULL, region_start = NULL, region_end = NULL) {
  join_keys <- c("CHR", "POS", "ALT")

  # Strict ALT calls of one parent as a lookup table. distinct() keeps the
  # join below from multiplying sample rows if a key occurs more than once.
  parent_alt_calls <- function(parent_tbl, flag_col) {
    parent_tbl %>%
      filter(parent_genotype_strict == 1) %>%
      select(CHR, POS, ALT = ALT1, all_of(flag_col)) %>%
      distinct()
  }

  annotated <- vcf_df %>%
    left_join(parent_alt_calls(kn99a_df, "kn99a"), by = join_keys) %>%
    left_join(parent_alt_calls(c8_df, "c8"), by = join_keys) %>%
    replace_na(list(kn99a = FALSE, c8 = FALSE)) %>%
    mutate(var_origin = case_when((kn99a & c8) ~ 'both',
                                  (kn99a & !c8) ~ 'kn99a',
                                  (!kn99a & c8) ~ 'C8',
                                  .default = 'denovo'))

  if (!is.null(region_start)) {
    if (is.null(region_chr) || is.null(region_end)) {
      stop("region_chr and region_end are required when region_start is given",
           call. = FALSE)
    }
    # Bounds parsed from text arrive as character; comparing them with a
    # numeric POS would be a string comparison.
    region_start <- as.numeric(region_start)
    region_end <- as.numeric(region_end)

    in_region <- annotated %>%
      ungroup() %>%
      filter(CHR == .env$region_chr,
             POS >= .env$region_start,
             POS <= .env$region_end)
    n_region <- nrow(in_region)

    in_region %>%
      group_by(var_origin) %>%
      reframe(region_percent = n() / n_region * 100,
              region_parent_total = n(),
              region_total_variants = n_region)
  } else {
    n_total <- nrow(annotated)

    annotated %>%
      group_by(var_origin) %>%
      reframe(genome_wide_percent = n() / n_total * 100,
              genome_wide_parent_total = n(),
              total_variants = n_total)
  }
}

# Variant tables of every sample whose name does not start with "DGK",
# grouped by sample (still a single grouped table despite the `_list` suffix)
input_vcf_tables_df_list <- input_vcf_tables_df %>%
  filter(!startsWith(sample, "DGK")) %>%
  group_by(sample) %>%
  droplevels()

# Genome-wide origin summary per sample, named by sample
wgs_var_summaries <- input_vcf_tables_df_list %>%
  group_split() %>%
  map(summarize_genomic_origin)
names(wgs_var_summaries) <- pull(group_keys(input_vcf_tables_df_list), sample)

wgs_var_summaries_df <- bind_rows(wgs_var_summaries, .id = "sample")

# Samples in which a single-parent class ('kn99a' or 'C8') accounts for less
# than 20 percent of the variants
unusual_samples <- wgs_var_summaries_df %>%
  filter(!(var_origin %in% c('denovo', 'both')) & genome_wide_percent < 20) %>%
  pull(sample)

# map2(wgs_var_summaries, group_keys(input_vcf_tables_df_list)$sample,
#      ~write_csv(.x, file.path('/mnt/scratch/genotype_check/run_21_wgs_results/wgs_variant_summary',
#                           paste0(.y, '.csv'))))

# Inspect the unusual samples next to their marker annotation
wgs_var_summaries_df %>%
  filter(sample %in% unusual_samples) %>%
  left_join(select(coordinate_meta_df, tube_number, marker),
            by = c('sample' = 'tube_number')) %>%
  View()

# Region of interest per sample, for the tubes that have one. The "chr:start-end"
# string is split into `chr`, `start` and `end`; all three stay character.
region_df <- coordinate_meta_df %>%
  drop_na(region) %>%
  select(sample = tube_number, region) %>%
  separate(region, into = c('chr', 'coord'), sep = ':') %>%
  separate(coord, into = c('start', 'end'), sep = '-')

# Look up the region(s) of interest recorded in the global `region_df` for
# the sample(s) present in a variant table.
#
# vcf_df: variant table with a `sample` column.
#
# Returns the matching rows of `region_df` (zero rows when the sample has no
# region).
region_stats <- function(vcf_df) {
  sample_ids <- vcf_df %>% pull(sample) %>% unique()

  # The local is deliberately not called `sample`: inside filter() that name
  # resolves to the region_df column, so `sample == sample` compared the
  # column with itself and kept every row.
  region <- region_df %>%
    filter(sample %in% !!sample_ids)

  # summarize_genomic_origin(vcf_df,
  #                          region_chr=region$chr,
  #                          region_start=region$start,
  #                          region_end=region$end)

  region
}

# Region lookup for each per-sample variant table
region_wgs_var_summaries <- input_vcf_tables_df_list %>%
  group_split() %>%
  map(region_stats)