Building a molecular graph neural network model with Jaqpotpy

In this blog post, we'll create a graph-based classification model that:
- Uses molecular structures (SMILES) as input
- Predicts AMES mutagenicity using a Graph Neural Network
- Gets deployed on Jaqpot's platform
Here's how we'll break down the process:
1. Set up Dependencies and Data
First, we import required libraries and obtain our AMES toxicity train, validation and test data:
import torch
from torch_geometric.loader import DataLoader
from jaqpotpy import Jaqpot
from jaqpotpy.descriptors.graph import SmilesGraphFeaturizer
from jaqpotpy.datasets import SmilesGraphDataset
from jaqpotpy.models.torch_geometric_models.graph_neural_network import (
    GraphConvolutionNetwork,
    pyg_to_onnx,
)
from jaqpotpy.models.trainers.graph_trainers import BinaryGraphModelTrainer
from tdc.single_pred import Tox

# Ames dataset
data = Tox(name="AMES")
data_splits = data.get_split()


def split_to_list(split):
    return data_splits[split]["Drug"].to_list(), data_splits[split]["Y"].to_list()


# Get train, validation, and test sets
train_smiles, train_y = split_to_list("train")
val_smiles, val_y = split_to_list("valid")
test_smiles, test_y = split_to_list("test")
2. Create Molecular Graph Dataset
The next step is to initialize a featurizer that creates node-level (and, optionally, edge-level) features for the graph dataset. In this example, we will use the default node features from Jaqpotpy. You can also specify which features to use, following the RDKit library documentation. Because the dataset is not very heavy on memory, we will also precompute the graph features.
# Initialise the featurizer
featurizer = SmilesGraphFeaturizer(include_edge_features=False)
featurizer.set_default_config()

# Create datasets
train_dataset = SmilesGraphDataset(
    smiles=train_smiles, y=train_y, featurizer=featurizer
)
train_dataset.precompute_featurization()

val_dataset = SmilesGraphDataset(smiles=val_smiles, y=val_y, featurizer=featurizer)
val_dataset.precompute_featurization()

test_dataset = SmilesGraphDataset(smiles=test_smiles, y=test_y, featurizer=featurizer)
test_dataset.precompute_featurization()
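To make the idea of node features and graph connectivity more concrete, here is a minimal, dependency-free sketch of how a molecule can be turned into a graph. It is hand-built for ethanol (SMILES "CCO", heavy atoms only) and uses a hypothetical one-hot encoding of the element symbol. The names `ELEMENTS`, `one_hot` and `build_graph` are illustrative assumptions; this is not the feature set or the internal logic of SmilesGraphFeaturizer.

```python
# Illustrative only: a hand-built graph for ethanol (SMILES "CCO").
# The one-hot vocabulary below is a hypothetical choice, not Jaqpotpy's defaults.
ELEMENTS = ["C", "N", "O"]


def one_hot(symbol, vocabulary):
    """Encode an element symbol as a one-hot vector over the vocabulary."""
    return [1.0 if symbol == item else 0.0 for item in vocabulary]


def build_graph(atoms, bonds):
    """Return (node_features, edge_index) for an undirected molecular graph."""
    node_features = [one_hot(symbol, ELEMENTS) for symbol in atoms]
    sources, targets = [], []
    for a, b in bonds:
        # Store every bond in both directions so messages can flow both ways
        sources += [a, b]
        targets += [b, a]
    return node_features, [sources, targets]


atoms = ["C", "C", "O"]    # heavy atoms of ethanol
bonds = [(0, 1), (1, 2)]   # C-C and C-O bonds
node_features, edge_index = build_graph(atoms, bonds)

print(node_features)
print(edge_index)
```

Each row of `node_features` describes one atom, and each column pair in `edge_index` describes one directed edge. Storing bonds in both directions is a common convention for undirected graphs in PyTorch Geometric style data.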
3. Define the model
Set up the Graph Neural Network model. In our case, we will use a simple graph convolutional network. For simplicity, we will keep the network's default hyperparameters. However, you can freely choose the depth, activation function, and other components of the architecture.
model = GraphConvolutionNetwork(
    input_dim=featurizer.get_num_node_features(),
)
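For readers new to graph convolutions, the sketch below shows the general message-passing idea in plain Python: each node averages its own features with those of its neighbours, applies a linear map and a ReLU, and the node vectors are then pooled into one vector per molecule. This is a simplified mean-aggregation variant written for illustration; the function names and the toy weight matrix are assumptions, and it is not how GraphConvolutionNetwork is implemented internally.

```python
def graph_conv_layer(node_features, edge_index, weight):
    """One simplified message-passing step: mean over neighbours, linear map, ReLU."""
    num_nodes = len(node_features)
    in_dim = len(node_features[0])
    out_dim = len(weight[0])

    # Every node starts with a self-loop, then collects its incoming neighbours
    neighbourhoods = [[i] for i in range(num_nodes)]
    for source, target in zip(edge_index[0], edge_index[1]):
        neighbourhoods[target].append(source)

    updated = []
    for group in neighbourhoods:
        # Average the features of the node and its neighbours
        mean = [
            sum(node_features[j][k] for j in group) / len(group)
            for k in range(in_dim)
        ]
        # Linear transformation followed by ReLU
        updated.append(
            [
                max(0.0, sum(mean[k] * weight[k][o] for k in range(in_dim)))
                for o in range(out_dim)
            ]
        )
    return updated


def mean_pool(node_features):
    """Average node vectors into a single graph-level vector."""
    count = len(node_features)
    return [sum(column) / count for column in zip(*node_features)]


# Toy input: ethanol with one-hot element features over (C, N, O)
node_features = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
weight = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]  # hypothetical 3x2 weight matrix

hidden = graph_conv_layer(node_features, edge_index, weight)
graph_embedding = mean_pool(hidden)
print(hidden)
print(graph_embedding)
```

In a trained network the weight matrix is learned, several such layers are stacked, and the pooled vector is passed to a final layer that outputs a single logit for the binary label.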
4. Train and evaluate the model
As usual in PyTorch, we need to define an optimizer and a loss function for the training procedure. Then we use the BinaryGraphModelTrainer to train the model, evaluate it, and obtain the classification metrics.
# PyTorch configuration
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss = torch.nn.BCEWithLogitsLoss()

# Initialise trainer
trainer = BinaryGraphModelTrainer(
    model=model,
    n_epochs=20,
    optimizer=optimizer,
    loss_fn=loss,
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Train the model
trainer.train(train_loader, val_loader)

# Obtain metrics to evaluate performance
# (named test_loss so it does not shadow the loss function defined above)
test_loss, metrics, conf_matrix = trainer.evaluate(test_loader)
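As a rough illustration of what happens numerically in this step, the sketch below computes binary cross-entropy on raw logits (the quantity that BCEWithLogitsLoss averages by default) and derives a confusion matrix from thresholded logits. The logits and labels are made-up values, and the helper names are assumptions; the exact set of metrics returned by the trainer may differ.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit."""
    # Equivalent to -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def confusion_counts(logits, targets):
    """Count (tp, fp, tn, fn), predicting class 1 when the logit is above 0."""
    tp = fp = tn = fn = 0
    for logit, target in zip(logits, targets):
        predicted = 1 if logit > 0.0 else 0  # logit 0 corresponds to probability 0.5
        if predicted == 1 and target == 1:
            tp += 1
        elif predicted == 1 and target == 0:
            fp += 1
        elif predicted == 0 and target == 0:
            tn += 1
        else:
            fn += 1
    return tp, fp, tn, fn


# Made-up model outputs and labels, for illustration only
logits = [2.0, -1.0, 0.5, -3.0]
targets = [1, 0, 0, 1]

mean_loss = sum(bce_with_logits(x, y) for x, y in zip(logits, targets)) / len(logits)
tp, fp, tn, fn = confusion_counts(logits, targets)
accuracy = (tp + tn) / len(targets)

print(mean_loss)
print((tp, fp, tn, fn), accuracy)
```

Working on logits rather than probabilities is why the model's final layer needs no sigmoid during training: the loss applies it implicitly in a numerically stable way.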
5. Deploy on Jaqpot
The final step is to deploy the model as a web service. Here we follow the standard procedure seen in previous blog posts, with minor modifications. The important thing is that, in order to deploy a model, we first need to convert it to the ONNX format.
onnx_model = pyg_to_onnx(model, featurizer)

# Login as a user
jaqpot = Jaqpot()
jaqpot.login()

jaqpot.deploy_torch_model(
    onnx_model,
    featurizer=featurizer,
    name="Graph Neural Network",
    description="Ames classification",
    target_name="AMES",
    visibility="PRIVATE",
    task="binary_classification",
)
Key takeaways
- You can use SmilesGraphDataset() paired with a SmilesGraphFeaturizer() to create graph neural network features from SMILES strings.
- A GraphConvolutionNetwork can be used for molecular property classification.
- The BinaryGraphModelTrainer() can be used for training, evaluation and prediction.
- The model is deployed after being converted to the ONNX format.
Try on your own
You can customize this example by:
- Using different molecular datasets
- Adjusting the GNN architecture
- Tuning hyperparameters
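As a starting point for hyperparameter tuning, one simple option is a small grid search. The sketch below only enumerates candidate configurations; the search space values are arbitrary examples, and `iter_configs` is a hypothetical helper, not part of Jaqpotpy.

```python
from itertools import product

# Arbitrary example values; adjust to your own budget and dataset
search_space = {
    "lr": [1e-3, 5e-4],
    "batch_size": [32, 64],
    "n_epochs": [20, 40],
}


def iter_configs(space):
    """Yield one dict per combination of hyperparameter values."""
    keys = list(space)
    for values in product(*(space[key] for key in keys)):
        yield dict(zip(keys, values))


configs = list(iter_configs(search_space))
print(len(configs))
print(configs[0])
```

Each configuration could then be plugged into the optimizer, the data loaders and the trainer from step 4, keeping the one that performs best on the validation set and reporting test metrics only once at the end.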