Epistatic transformer for jointly modeling fixed-order specific epistasis and nonspecific epistasis.

a. Overall model architecture. The input amino acid tokens xl are first used to generate position-specific amino acid embeddings Z0. The embeddings are then passed through M layers of modified multi-head attention (MHA), such that the output embeddings contain specific epistasis up to order 2^M. Left: details of a single MHA layer. The hidden states Zm from the previous layer are passed through linear layers to generate the query (Q) and key (K) tensors, while the raw embeddings Z0 are used directly as the value (V) tensor. Attention weights are calculated by taking scaled dot products between Q and K and are used to generate the final output of the layer. This bypass with the raw embeddings Z0, together with the removal of LayerNorm, the softmax operation on the attention weights, and the feedforward layer, allows us to model only interactions among up to 2^M sites when using M MHA layers. b. Automatic hyperparameter search scheme using Optuna [48].
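The modified attention layer described above can be sketched as follows. This is a schematic single-head reimplementation under our reading of the caption, not the authors' code: function and weight names are illustrative, and the key points are that Q and K come from the current hidden state, the value is always the raw embedding Z0, and there is no softmax, LayerNorm, or feedforward sublayer.

```python
import numpy as np

def epistatic_attention(z_m, z0, wq, wk):
    """One modified MHA layer (single-head sketch, hypothetical implementation).

    z_m : (L, d) hidden states from the previous layer
    z0  : (L, d) raw position-specific embeddings, used directly as values
    wq, wk : (d, d) projection matrices for queries and keys
    """
    q = z_m @ wq                              # queries from current hidden state
    k = z_m @ wk                              # keys from current hidden state
    # Scaled dot-product attention WITHOUT softmax: weights stay polynomial
    # in the inputs, which is what bounds the interaction order per layer.
    attn = (q @ k.T) / np.sqrt(q.shape[-1])   # (L, L)
    # Value bypass: attention weights are applied to the raw embeddings z0,
    # not to z_m; no LayerNorm or feedforward sublayer follows.
    return attn @ z0                          # (L, d)
```

Stacking this layer (feeding the output back in as `z_m` while keeping `z0` fixed) mirrors the M-layer architecture in the figure.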

The epistatic transformer captures the genetic architecture of simulated sequence-function relationships.

Data were simulated for a binary fitness landscape with 13 sites, with specific epistasis among up to 8 sites. a. Test R2 can be used to estimate cumulative variance components. The bar plot shows the cumulative variance explained by each order of interaction (proportion of variance due to epistasis up to a given order). The error plot shows the test R2 for additive, pairwise, four-way, and eight-way interaction models using the epistatic transformer fit to 90% of the simulated data. Error bars correspond to one standard deviation across 5 random train-test splits. b. The epistatic transformer recapitulates the true variance components. Bar plots correspond to variance components (proportion of variance explained by interactions of a given order) for the simulated landscape and the pairwise, 4th-order, and 8th-order models. Note that the pairwise model only contains interactions up to second order, and the 4th-order model only contains interactions among up to 4 positions, while the 8th-order model contains interactions among up to 8 positions, with variance components aligning with the ground truth. c. The epistatic transformer model becomes increasingly complex with longer training. The three curves correspond to the variance component decomposition of the full landscape inferred by the 8th-order model at each training epoch. Blue: variance components for additive and pairwise interactions. Orange: variance components for three-way and four-way interactions. Green: variance components for interaction orders higher than four.
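The logic of panel a, reading per-order variance components off the test R2 of nested models, can be sketched as follows. This is a minimal illustration with an input format of our own choosing, not the authors' pipeline: because a model truncated at order k captures all interactions up to that order, its test R2 approximates the cumulative variance explained up to order k, and successive differences recover the per-order components.

```python
def variance_components_from_r2(r2_by_order):
    """Per-order variance components from the test R2 of nested models.

    r2_by_order: dict mapping maximum interaction order -> test R2 of the
    model truncated at that order (hypothetical helper; names are ours).
    Returns a dict mapping each order block to its variance component,
    i.e. successive differences along the cumulative R2 curve.
    """
    orders = sorted(r2_by_order)
    components = {orders[0]: r2_by_order[orders[0]]}
    for lo, hi in zip(orders, orders[1:]):
        components[hi] = r2_by_order[hi] - r2_by_order[lo]
    return components
```

For example, with cumulative test R2 of 0.50 (additive), 0.70 (pairwise), 0.80 (4-way), and 0.85 (8-way), the per-order components are 0.50, 0.20, 0.10, and 0.05.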

Combinatorial protein mutagenesis datasets used in this paper.

Importance of pairwise and higher-order specific epistasis in 10 experimental protein sequence-function datasets.

For each dataset, pairwise, 4th-order, and 8th-order models were fit using the epistatic transformer with 1, 2, and 3 layers of epistatic multi-head attention (MHA), along with an additive model. All models contain a final nonlinear activation function mapping a scalar value to the measurement scale to model non-specific epistasis. Models were fit to training data consisting of 80% of all available data, sampled at random, and evaluated on random test genotypes. In each panel, the number in the upper right corner is the proportion of total variance in the test data explained by all orders of specific epistasis, equal to the difference in R2 between the 8th-order epistatic transformer model and the additive model. The importance of epistatic interactions of different orders is measured by percent epistatic variance, equal to the gain in R2 from fitting an additional layer of MHA, normalized by this total epistatic variance. For example, the percent epistatic variance due to pairwise interactions is the difference in R2 between the pairwise and additive models divided by the total epistatic variance, and the percent epistatic variance due to 3-way and 4-way interactions is the difference in R2 between the 4th-order and pairwise models divided by the same quantity. Error bars represent 1 standard deviation calculated from 3 replicates.
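The percent epistatic variance defined above can be computed directly from the four models' test R2 values. A minimal sketch, assuming the interpretation that the total epistatic variance is R2(8th-order) minus R2(additive); the function name and order labels are illustrative:

```python
def percent_epistatic_variance(r2_add, r2_pair, r2_4, r2_8):
    """Split the total epistatic variance into per-block shares.

    Each gain in test R2 from adding one MHA layer is divided by the total
    epistatic variance, R2(8th-order) - R2(additive). The three shares sum
    to 1 by construction (hypothetical helper, not the authors' code).
    """
    total = r2_8 - r2_add
    return {
        "pairwise": (r2_pair - r2_add) / total,
        "3-4-way": (r2_4 - r2_pair) / total,
        "5-8-way": (r2_8 - r2_4) / total,
    }
```

For instance, with R2 values 0.60, 0.72, 0.76, and 0.78, pairwise interactions account for two thirds of the total epistatic variance of 0.18.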

Improvement in prediction accuracy in higher-order models is due to specific epistasis.

Scatter plots show the observed phenotypes (y) vs. the latent model predictions (ϕ) or the final model predictions (ŷ) for the test genotypes, for the pairwise and 8th-order epistatic transformer models fit to the GRB-1 and AAV2-Capsid datasets. Models were fit to 80% of randomly sampled training data.

Importance of higher-order epistasis in predicting phenotypes for distant genotypes in the AAV2-Capsid (a) and the cgreGFP (b) datasets.

For each dataset, genotypes are binned into discrete distance classes by their mean Hamming distance to the training data, which consists of 20% of randomly sampled genotypes. For both datasets, we retain only distance classes where the additive model has a test R2 > 0.3. Top panel: distribution of mean Hamming distance in the randomly sampled test data. Second panel: distribution of the observed phenotypic values (y) for each distance class. Third panel: test R2 under the additive model and the 8th-order epistatic transformer for genotypes in different distance classes. The gap between the two curves equals the total epistatic variance (the difference in R2 between the two models) for each distance class. Bottom panel: importance of specific pairwise and higher-order epistasis at different distance classes, measured by percent epistatic variance, equal to the gain in R2 from fitting an additional layer of MHA, normalized by the total epistatic variance in that distance class. All metrics were calculated for models fit to one training sample consisting of 20% of randomly sampled genotypes. Error bars represent 1 standard deviation calculated by bootstrapping the test data with 10 replicates.
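The binning step described above can be sketched as follows, assuming integer-encoded sequences; function names and the rounding-to-integer-classes convention are our assumptions, not taken from the original analysis:

```python
import numpy as np

def mean_hamming_to_training(test_seqs, train_seqs):
    """Mean Hamming distance from each test genotype to the training set.

    test_seqs : (n_test, L) integer-encoded sequences
    train_seqs: (n_train, L) integer-encoded sequences
    Returns an (n_test,) array of mean distances (hypothetical helper).
    """
    # Broadcast to (n_test, n_train, L) mismatch indicators, sum over sites
    # to get pairwise Hamming distances, then average over the training set.
    d = (test_seqs[:, None, :] != train_seqs[None, :, :]).sum(axis=-1)
    return d.mean(axis=1)

def bin_by_distance(test_seqs, train_seqs):
    """Assign each test genotype to a discrete distance class by rounding."""
    return np.rint(mean_hamming_to_training(test_seqs, train_seqs)).astype(int)
```

Per-class test R2 for the additive and 8th-order models can then be computed by grouping test genotypes on the returned class labels.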

Higher-order epistasis in a multi-peak fitness landscape consisting of four green fluorescent protein (GFP) orthologs (avGFP, amacGFP, ppluGFP2, cgreGFP).

a. PCA coordinates of the one-hot embeddings of all protein genotypes. Genotypes are tightly clustered around the four wild types (WT), which exhibit varying degrees of sequence divergence. b. Scatter plots of shared mutational effects among the four GFPs, fit using separate additive models with a nonlinear sigmoid activation function. c. Higher-order epistasis allows better generalization to distant regions in sequence space. Models were fit using single and double mutant data for all GFPs. Models were tested on different distance classes, each containing all genotypes at a given Hamming distance to their corresponding WT sequence. Error bars represent 1 standard deviation calculated by resampling a random 90% of the test genotypes with 10 replicates. d. Higher-order epistasis allows better generalization across local peaks. Rows: GFP orthologs used to train the models. Columns: GFP orthologs used to test performance of the trained model. The number at the top right corner of each panel measures the proportion of total variance explained by specific epistasis among up to 8 positions. Error bars correspond to 1 standard deviation across three model replicates with different training and test genotypes.
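The PCA embedding in panel a can be reproduced with a few lines of numpy; this is a generic sketch (one-hot encoding followed by SVD-based PCA), with function and parameter names of our own choosing:

```python
import numpy as np

def onehot_pca(seqs, n_alphabet=20, n_components=2):
    """Project integer-encoded protein sequences onto principal components.

    seqs: (n, L) array of amino acid indices in [0, n_alphabet).
    Returns (n, n_components) PCA coordinates of the one-hot embeddings
    (hypothetical helper, pure-numpy PCA via SVD).
    """
    n, L = seqs.shape
    X = np.zeros((n, L * n_alphabet))
    # Set a 1 at position (site * n_alphabet + amino_acid) for each site.
    X[np.arange(n)[:, None], np.arange(L) * n_alphabet + seqs] = 1.0
    Xc = X - X.mean(axis=0)                       # centre before PCA
    _, _, vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ vt[:n_components].T               # top-PC coordinates
```

On a combinatorial dataset built around a handful of wild-type backgrounds, such a projection yields the tight WT-centred clusters shown in the panel.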