PyTorch Breeding
ChewC
from chewc.core import *
In short, this will be a GPU-enabled stochastic simulation for breeding programs, with an emphasis on cost-benefit analysis of novel breeding tools and on providing a suitable interface for RL agents.
Each action will also carry a cost, so that long-term breeding budgets can be managed. In addition, we will model theoretical tools in the plant breeder’s toolbox, for example:
a treatment which increases crossover rates
a treatment which reduces flowering time
a treatment which enables gene drive at select loci
Each treatment will have a monetary cost, which should ultimately help guide its implementation in real-world breeding programs.
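The budgeting idea above could be sketched roughly as follows. This is a hypothetical illustration only: `Treatment`, `affordable`, and the dollar figures are invented for this example and are not part of the chewc API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Treatment:
    """A hypothetical breeding treatment with a per-use cost (illustrative only)."""
    name: str
    cost: float  # cost per application in dollars (made-up numbers below)


def affordable(treatments, budget):
    """Walk the treatments in order and keep each one that still fits the budget."""
    chosen = []
    remaining = budget
    for t in treatments:
        if t.cost <= remaining:
            chosen.append(t)
            remaining -= t.cost
    return chosen, remaining


# The three treatments named in the list above, with placeholder costs.
toolbox = [
    Treatment("increase crossover rate", 500.0),
    Treatment("reduce flowering time", 300.0),
    Treatment("gene drive at select loci", 1000.0),
]

chosen, remaining = affordable(toolbox, budget=1000.0)
print([t.name for t in chosen], remaining)
```

A greedy, in-order pass is used here only to keep the sketch short; an RL agent would presumably learn which treatments are worth their cost rather than follow a fixed rule.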
Install
pip install chewc
How to use
First, define the genome of your crop
import torch
import matplotlib.pyplot as plt
# Genome dimensions
ploidy = 2
number_chromosomes = 10
loci_per_chromosome = 100
genetic_map = create_random_genetic_map(number_chromosomes, loci_per_chromosome)
crop_genome = Genome(ploidy, number_chromosomes, loci_per_chromosome, genetic_map)
# Founder population
n_founders = 500
founder_pop = create_random_founder_pop(crop_genome, n_founders)
simparam = SimParam
simparam.founder_pop = founder_pop
simparam.genome = crop_genome
# Add a single additive trait
qtl_loci = 20
qtl_map = select_qtl_loci(qtl_loci, simparam.genome)
ta = TraitA(qtl_map, simparam, 0, 1)
ta.sample_initial_effects()
ta.scale_genetic_effects()
ta.calculate_intercept()
years = 50
current_pop = founder_pop
pmean = []
pvar = []
for year in range(years):
    # Phenotype the current population and keep the TOPK best individuals as parents
    TOPK = 2
    new_pop = []
    pheno = ta.phenotype(current_pop, h2=0.14)
    topk = torch.topk(pheno, TOPK).indices
    # Produce 200 offspring, each from two distinct parents drawn from the selected set
    for _ in range(200):
        sampled_indices = torch.multinomial(torch.ones(topk.size(0)), 2, replacement=False)
        sampled_parents = topk[sampled_indices]
        m, f = current_pop[sampled_parents[0]], current_pop[sampled_parents[1]]
        new_pop.append(make_cross(simparam.genome.genetic_map, m, f))
    # The offspring become the next generation
    current_pop = torch.stack(new_pop)
    # Record the mean and variance of the genetic values for this generation
    pmean.append(ta.calculate_genetic_values(current_pop).mean())
    pvar.append(ta.calculate_genetic_values(current_pop).var())
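# Normalize each series by its maximum so both fit on one axis
pmean_normalized = torch.tensor(pmean) / max(pmean)
pvar_normalized = torch.tensor(pvar) / max(pvar)
plt.scatter(range(len(pmean_normalized)), pmean_normalized, label="genetic mean (normalized)")
plt.scatter(range(len(pvar_normalized)), pvar_normalized, label="genetic variance (normalized)")
plt.legend()
plt.show()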
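To make the selection step in the loop above easier to follow, here is a minimal sketch of one cycle written in plain PyTorch, without chewc. The dosage coding, the random effects, and the way environmental noise is scaled from `h2` are assumptions made for illustration; they are not necessarily how `TraitA.phenotype` works internally.

```python
import torch

torch.manual_seed(0)

n_individuals, n_qtl, h2 = 200, 20, 0.14

# Toy diploid genotypes coded as allele dosage (0, 1, or 2) at each QTL.
dosage = torch.randint(0, 3, (n_individuals, n_qtl)).float()
effects = torch.randn(n_qtl)  # assumed additive effect per QTL

# Additive model: genetic value is the dosage-weighted sum of effects.
genetic_values = dosage @ effects

# Choose the environmental variance so that var_g / (var_g + var_e) equals h2.
var_g = genetic_values.var()
var_e = var_g * (1 - h2) / h2
phenotypes = genetic_values + torch.randn(n_individuals) * var_e.sqrt()

# Truncation selection: indices of the top_k highest phenotypes.
top_k = 10
selected = torch.topk(phenotypes, top_k).indices

# Draw two distinct parents uniformly at random from the selected set.
pair = selected[torch.multinomial(torch.ones(top_k), 2, replacement=False)]
print(pair.tolist())
```

With a low heritability such as 0.14, most of the phenotypic variance is noise, so the individuals with the best phenotypes are often not those with the best genetic values. That is one reason response to selection is slow in a simulation like the one above.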