In this tutorial, we will learn how to use TorchProtein to solve different sequence-based protein property prediction tasks. The tutorial covers five types of tasks: protein-wise property prediction, residue-wise property prediction, contact prediction, protein-protein interaction (PPI) prediction and protein-ligand interaction (PLI) prediction.
- Protein Sequence Representation Model
- Task 1: Protein-wise Property Prediction
- Task 2: Residue-wise Property Prediction
- Task 3: Contact Prediction
- Task 4: Protein-Protein Interaction (PPI) Prediction
- Task 5: Protein-Ligand Interaction (PLI) Prediction
Protein Sequence Representation Model
TorchProtein defines diverse sequence-based models to learn protein sequence representations.
In this tutorial, we use a two-layer 1D CNN as the protein sequence representation model for all considered tasks.
First, let’s define such a model via the models.ProteinCNN module.
from torchdrug import models

# A two-layer 1D CNN over 21-dimensional residue features with max-pooling readout
model = models.ProteinCNN(input_dim=21,
                          hidden_dims=[1024, 1024],
                          kernel_size=5, padding=2, readout="max")
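To sanity-check the encoder, we can run it on a single sequence. Below is a minimal sketch, assuming the standard TorchDrug forward signature model(graph, node_feature) and the graph_feature / residue_feature output keys; the sequence is an arbitrary placeholder, and the batch is switched to the residue view, which is explained in the next section.
import torch
from torchdrug import data
# Encode one placeholder protein and inspect the representation shapes.
protein = data.Protein.from_sequence("MSKGEELFTGVVPILVELDGDVNGHKF")
batch = data.Protein.pack([protein])
batch.view = "residue"  # use residue features as node features
with torch.no_grad():
    output = model(batch, batch.node_feature.float())
print(output["graph_feature"].shape)    # one vector per protein
print(output["residue_feature"].shape)  # one vector per residue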
Task 1: Protein-wise Property Prediction
The first type of task we would like to solve is protein-wise property prediction. We take the Beta-lactamase activity prediction task as an example, which aims to predict mutation effects on the TEM-1 beta-lactamase protein.
Before defining the dataset, we need to first define the transformations we want to perform on proteins.
We consider two transformations. (1) To lower the memory cost of sequence-based models, it is a common practice to truncate overlong protein sequences.
In TorchProtein, we can define a protein truncation transformation by specifying the maximum protein length via the max_length parameter, and whether to truncate from a random residue or the first residue via the random parameter.
(2) In addition, since we want to use residue features as node features for sequence-based models, we also need to define a view-change transformation for proteins.
During dataset construction, we can pass the composition of two transformations as an argument.
from torchdrug import transforms
truncate_transform = transforms.TruncateProtein(max_length=200, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])
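Each transform is a callable over an item dictionary, so the composed transform can also be applied manually. A small sketch with a placeholder sequence (the repetition only serves to exceed 200 residues):
from torchdrug import data
protein = data.Protein.from_sequence("MSKGEELFTGVVPILVELDGDVNGHKF" * 10)
item = transform({"graph": protein})
print(item["graph"].num_residue)  # at most 200 after truncation
print(item["graph"].view)         # "residue"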
We then build the dataset via datasets.BetaLactamase, in which the dataset file will be downloaded automatically.
In this dataset, the label of each sample is a real number indicating the fitness value of the protein.
By turning on the residue_only option, TorchProtein will use data.Protein.from_sequence_fast to load proteins, which is much faster.
We can get the pre-specified training, validation and test splits via the split() method.
from torchdrug import datasets
dataset = datasets.BetaLactamase("~/protein-datasets/", atom_feature=None, bond_feature=None, residue_feature="default", transform=transform)
train_set, valid_set, test_set = dataset.split()
print("The label of first sample: ", dataset[0][dataset.target_fields[0]])
print("train samples: %d, valid samples: %d, test samples: %d" % (len(train_set), len(valid_set), len(test_set)))
The label of first sample: 0.9426838159561157
train samples: 4158, valid samples: 520, test samples: 520
To perform Beta-lactamase activity prediction, we wrap the CNN encoder into the tasks.PropertyPrediction module, which appends a task-specific MLP prediction head on top of the CNN.
from torchdrug import tasks
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="mse", metric=("mae", "rmse", "spearmanr"),
                                normalization=False, num_mlp_layer=2)
Now we can train our model. We set up an optimizer and put everything together into an Engine instance. It takes about 2 minutes to train the model for 10 epochs on this task. We finally evaluate its performance on the validation set.
import torch
from torchdrug import core
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=64)
solver.train(num_epoch=10)
solver.evaluate("valid")
The evaluation result may be similar to the following.
mean absolute error [scaled_effect1]: 0.249482
root mean squared error [scaled_effect1]: 0.304326
spearmanr [scaled_effect1]: 0.44572
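Once trained, the task can also be used for inference on individual samples. A minimal sketch, assuming TorchDrug's graph_collate utility and the task.predict(batch) interface:
from torchdrug import data, utils
task.eval()
batch = data.graph_collate([valid_set[0]])  # collate one sample into a batch
batch = utils.cuda(batch)                   # the solver above runs on GPU 0
with torch.no_grad():
    pred = task.predict(batch)
print(pred)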
Task 2: Residue-wise Property Prediction
The second type of task we consider is residue-wise property prediction. We take the secondary structure prediction task as an example.
We first build the dataset via datasets.SecondaryStructure, in which we use the cb513 test set. For residue-level tasks, we usually keep the whole protein sequence, so we only use the ProteinView transform here.
The target field denotes the secondary structure (coil, strand or helix) of each residue, and the mask field denotes whether each secondary structure label is valid or not. Both fields have the same length as the protein sequence.
dataset = datasets.SecondaryStructure("~/protein-datasets/", atom_feature=None, bond_feature=None, residue_feature="default", transform=protein_view_transform)
train_set, valid_set, test_set = dataset.split(["train", "valid", "cb513"])
print("SS3 label: ", dataset[0]["graph"].target[:10])
print("Valid mask: ", dataset[0]["graph"].mask[:10])
print("train samples: %d, valid samples: %d, test samples: %d" % (len(train_set), len(valid_set), len(test_set)))
SS3 label: tensor([2, 2, 2, 0, 0, 0, 0, 0, 2, 2])
Valid mask: tensor([True, True, True, True, True, True, True, True, True, True])
train samples: 8678, valid samples: 2170, test samples: 513
To perform secondary structure prediction, we wrap the CNN encoder into the tasks.NodePropertyPrediction module, which appends a task-specific MLP prediction head on top of the CNN.
task = tasks.NodePropertyPrediction(model, criterion="ce",
                                    metric=("micro_acc", "macro_acc"),
                                    num_mlp_layer=2, num_class=3)
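Conceptually, per-residue classification uses the valid mask described above to exclude unlabeled residues from the objective. A sketch of the idea, not the library's exact implementation:
import torch.nn.functional as F
# logits: (num_residue, 3), target: (num_residue,), mask: (num_residue,) bool.
# Only residues with mask == True contribute to the cross entropy.
def masked_cross_entropy(logits, target, mask):
    return F.cross_entropy(logits[mask], target[mask])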
We train the model for 5 epochs, taking about 5 minutes, and finally evaluate it on the validation set.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=128)
solver.train(num_epoch=5)
solver.evaluate("valid")
The evaluation result may be similar to the following.
macro_acc: 0.629119
micro_acc: 0.624727
Task 3: Contact Prediction
The third task we would like to solve is to predict whether a pair of residues is in contact in the folded structure, i.e., contact prediction.
We first build the dataset via datasets.ProteinNet. The residue_position field denotes the 3D coordinates of each residue, and the mask field denotes whether each residue position is valid or not. Both fields have the same length as the protein sequence.
dataset = datasets.ProteinNet("~/protein-datasets/", atom_feature=None, bond_feature=None, residue_feature="default", transform=protein_view_transform)
train_set, valid_set, test_set = dataset.split()
print("Residue position: ", dataset[0]["graph"].residue_position[:3])
print("Valid mask: ", dataset[0]["graph"].mask[:3])
print("train samples: %d, valid samples: %d, test samples: %d" % (len(train_set), len(valid_set), len(test_set)))
Residue position:  tensor([[ 2.0940e+00,  2.0000e-03, -1.2420e+00],
        [ 5.1260e+00, -2.0210e+00, -2.3290e+00],
        [ 7.5230e+00,  6.1500e-01, -3.6610e+00]])
Valid mask: tensor([True, True, True])
train samples: 25299, valid samples: 224, test samples: 34
To perform contact prediction, we wrap the CNN encoder into the tasks.ContactPrediction module, which appends a task-specific MLP prediction head on top of the CNN. Two residues whose sequence separation is larger than gap are regarded as being in contact if their Euclidean distance is within threshold.
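In other words, a residue pair (i, j) is a positive pair when |i - j| > gap and the distance between their coordinates is below threshold. A small self-contained sketch of this definition:
import torch
# residue_position: (num_residue, 3) coordinates, e.g. graph.residue_position.
def contact_map(residue_position, threshold=8.0, gap=6):
    dist = torch.cdist(residue_position, residue_position)  # pairwise Euclidean distances
    index = torch.arange(len(residue_position))
    separation = (index[None, :] - index[:, None]).abs()    # sequence separation |i - j|
    return (separation > gap) & (dist < threshold)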
Different from previous tasks, the maximum truncation length max_length is now defined in the task, since truncation behaves differently on the test set for contact prediction: to save memory, each test sequence is split into several blocks according to max_length.
task = tasks.ContactPrediction(model, max_length=500, random_truncate=True, threshold=8.0, gap=6,
                               criterion="bce", metric=("accuracy", "prec@L5", "prec@5"), num_mlp_layer=2)
Since the raw training set contains many samples and only a small batch size can be used for this task, we train on a 1000-sample subset of the raw training set. We train the model for 1 epoch, taking about 4 minutes, and finally evaluate it on the validation set.
from torch.utils import data as torch_data
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
sub_train_set = torch_data.random_split(train_set, [1000, len(train_set) - 1000])[0]
solver = core.Engine(task, sub_train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=1)
solver.train(num_epoch=1)
solver.evaluate("valid")
The evaluation result may be similar to the following.
accuracy: 0.973398
prec@5: 0.123214
prec@L5: 0.0894395
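Here prec@L5 typically denotes the precision among the top L/5 highest-scoring residue pairs, where L is the protein length, and prec@5 the precision among the top 5 pairs. A hedged sketch of the L/5 variant:
import torch
# score, truth: (L, L) predicted contact scores and true contact map.
def precision_at_L5(score, truth):
    L = score.shape[0]
    k = max(L // 5, 1)
    top = score.flatten().topk(k).indices  # the L/5 highest-scoring pairs
    return truth.flatten()[top].float().mean()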
Task 4: Protein-Protein Interaction (PPI) Prediction
The fourth task we consider is to predict the binding affinity of two interacting proteins, i.e., performing PPI affinity prediction.
We first build the dataset via datasets.PPIAffinity, in which each sample is a pair of proteins associated with a continuous binding affinity label. Since we now need to transform both proteins, we specify the keys argument in the transformation functions.
truncate_transform_ = transforms.TruncateProtein(max_length=200, keys=("graph1", "graph2"))
protein_view_transform_ = transforms.ProteinView(view="residue", keys=("graph1", "graph2"))
transform_ = transforms.Compose([truncate_transform_, protein_view_transform_])
dataset = datasets.PPIAffinity("~/protein-datasets/", atom_feature=None, bond_feature=None, residue_feature="default", transform=transform_)
train_set, valid_set, test_set = dataset.split()
print("The label of first sample: ", dataset[0][dataset.target_fields[0]])
print("train samples: %d, valid samples: %d, test samples: %d" % (len(train_set), len(valid_set), len(test_set)))
The label of first sample: -12.2937
train samples: 2421, valid samples: 203, test samples: 326
To perform PPI affinity prediction, we wrap the CNN encoder into the tasks.InteractionPrediction module, which appends a task-specific MLP prediction head on top of the CNN.
task = tasks.InteractionPrediction(model, task=dataset.tasks,
                                   criterion="mse", metric=("mae", "rmse", "spearmanr"),
                                   normalization=False, num_mlp_layer=2)
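Under the hood, each protein is encoded independently and the two protein-level features are combined before the MLP head. A conceptual sketch that assumes concatenation as the combination (not necessarily the library's exact code):
import torch
# graph1, graph2: packed proteins; mlp: the task-specific prediction head.
def pair_prediction(model, mlp, graph1, graph2):
    feature1 = model(graph1, graph1.node_feature.float())["graph_feature"]
    feature2 = model(graph2, graph2.node_feature.float())["graph_feature"]
    return mlp(torch.cat([feature1, feature2], dim=-1))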
We train the model for 10 epochs, taking about 2 minutes, and finally evaluate it on the validation set.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=64)
solver.train(num_epoch=10)
solver.evaluate("valid")
The evaluation result may be similar to the following.
mean absolute error [interaction]: 2.0611
root mean squared error [interaction]: 2.6741
spearmanr [interaction]: 0.502957
Task 5: Protein-Ligand Interaction (PLI) Prediction
The fifth task we consider is to predict the binding affinity of a protein and a small molecule (i.e., a ligand). We take PLI prediction on BindingDB as an example.
We first build the dataset via datasets.BindingDB, in which each sample is a protein-ligand pair associated with a continuous binding affinity label. We use the holdout_test set for testing.
truncate_transform_ = transforms.TruncateProtein(max_length=200, keys="graph1")
protein_view_transform_ = transforms.ProteinView(view="residue", keys="graph1")
transform_ = transforms.Compose([truncate_transform_, protein_view_transform_])
dataset = datasets.BindingDB("~/protein-datasets/", atom_feature=None, bond_feature=None, residue_feature="default", transform=transform_)
train_set, valid_set, test_set = dataset.split(["train", "valid", "holdout_test"])
print("The label of first sample: ", dataset[0][dataset.target_fields[0]])
print("train samples: %d, valid samples: %d, test samples: %d" % (len(train_set), len(valid_set), len(test_set)))
The label of first sample: 5.823908740944319
train samples: 7900, valid samples: 878, test samples: 5230
To perform PLI prediction, we require an additional ligand graph encoder to extract the representations of ligands.
We define a 4-layer Graph Isomorphism Network (GIN) as the ligand graph encoder.
After that, we wrap the CNN encoder and the GIN encoder into the tasks.InteractionPrediction module, which appends a task-specific MLP prediction head on top of the CNN and GIN.
# Ligand encoder over the dataset's 66-dimensional atom features
model2 = models.GIN(input_dim=66,
                    hidden_dims=[256, 256, 256, 256],
                    batch_norm=True, short_cut=True, concat_hidden=True)
task = tasks.InteractionPrediction(model, model2=model2, task=dataset.tasks,
                                   criterion="mse", metric=("mae", "rmse", "spearmanr"),
                                   normalization=False, num_mlp_layer=2)
We train the model for 5 epochs, taking about 3 minutes, and finally evaluate it on the validation set.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=16)
solver.train(num_epoch=5)
solver.evaluate("valid")
The evaluation result may be similar to the following.
mean absolute error [affinity]: 0.916763
root mean squared error [affinity]: 1.22093
spearmanr [affinity]: 0.60658
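As a final check, here is a hedged end-to-end inference sketch for a new protein-ligand pair. It assumes the standard data.Protein.from_sequence and data.Molecule.from_smiles constructors and the task.predict(batch) interface; the sequence and SMILES are placeholders, and the ligand's atom features must match the 66-dimensional features the dataset used.
import torch
from torchdrug import data, utils
protein = data.Protein.from_sequence("MSKGEELFTGVVPILVELDGDVNGHKF")  # placeholder sequence
ligand = data.Molecule.from_smiles("CCOC(=O)C1=CC=CC=C1")            # placeholder ligand
item = transform_({"graph1": protein, "graph2": ligand})             # residue view for the protein only
batch = data.graph_collate([item])
batch = utils.cuda(batch)  # the solver above runs on GPU 0
task.eval()
with torch.no_grad():
    print(task.predict(batch))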