Visualizing Decision Trees

A 2016 paper by Wicker and Cooper, describing a molecular descriptor designed to capture molecular flexibility, popped up on Twitter this week.  This paper reminded me of the power of a simple decision tree.  Decision trees can often provide an efficient way of looking at the relationship between molecular descriptors and experimental data.   They can also provide a means of understanding the relationship between sets of experiments, particularly with pharmacokinetic data.

In this spirit, I thought I'd put together a quick post showing how to build and visualize a decision tree.  This post will also show off a couple of useful Python libraries that I've recently integrated into my workflow. 

In their paper, Wicker and Cooper use a set of 40,541 commercially available molecules, from the ZINC database, to establish a relationship between molecular flexibility and the ability of a molecule to crystallize.  The dataset is divided into two subsets.
  • “observed to crystallize” - molecules that occur in both ZINC and the Cambridge Structural Database (CSD)
  • “not observed to crystallize” - molecules found in ZINC but not in the CSD
The assumption here is that the molecules not found in the CSD were difficult to crystallize.  There could, of course, be a number of other reasons the “not observed to crystallize” set was not in the CSD, but let's just go with it. 
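In code, this labeling scheme is just a set-membership test. The sketch below assumes each molecule carries an identifier shared between the two databases, such as a canonical SMILES or InChIKey. The function name and the IDs are hypothetical and used only for illustration; this is not how Wicker and Cooper built their dataset.

```python
def label_molecules(zinc_ids, csd_ids):
    """Hypothetical helper: label each ZINC molecule 1 ("observed to
    crystallize") if it also appears in the CSD, otherwise 0."""
    csd_set = set(csd_ids)  # a set gives fast membership tests
    return {mol_id: int(mol_id in csd_set) for mol_id in zinc_ids}

# Made-up identifiers for illustration only
zinc = ["mol_A", "mol_B", "mol_C"]
csd = ["mol_B", "mol_X"]  # mol_X is in the CSD but not in ZINC, so it is ignored
print(label_molecules(zinc, csd))  # {'mol_A': 0, 'mol_B': 1, 'mol_C': 0}
```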

Time to write some code.  Let's start by importing the necessary Python libraries. 
import pandas as pd
import janitor
from sklearn import tree
from dtreeviz.trees import dtreeviz  # explicit import instead of a wildcard
from sklearn.metrics import matthews_corrcoef

Here's a quick rundown on the libraries.
  • pandas - the go-to library for anything having to do with tabular data
  • janitor - the amazing PyJanitor library for cleaning data, more on this below
  • sklearn - the scikit-learn library for all things machine learning 
  • dtreeviz - cool visualizations of decision trees, more below
We start by reading the training set. 
train_df = pd.read_csv("train_desc_with_names.csv")
Let's take a look at the names of the columns in the dataframe.
train_df.columns
Index(['Unnamed: 0', 'MolWt', 'HeavyAtomMolWt', 'NumRadicalElectrons',
       'NumValenceElectrons', 'BalabanJ', 'BertzCT', 'Chi0', 'Chi0n', 'Chi0v',
       ...
       'R6', 'R7', 'R8', 'R9', 'RG10', 'AFRC', 'BFRC', 'nConf20', 'Label',
       'Label names'],
      dtype='object', length=223)
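Note that some of the column names contain spaces or punctuation (e.g., 'Label names' and 'Unnamed: 0'), which can be inconvenient. Fortunately, the PyJanitor library has functions for all sorts of data cleaning. The "clean_names" function replaces spaces and punctuation in the column names with underscores, and case_type="preserve" keeps the original capitalization. We then use "remove_columns" to drop the 'Unnamed_0' column (typically a leftover row index from when the CSV was written) and the 'Label_names' column. Both steps take advantage of the very cool method chaining capability provided by PyJanitor.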
train_df = (train_df.clean_names(case_type="preserve")
            .remove_columns(["Unnamed_0","Label_names"]))
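Here is a simplified approximation of what "clean_names" does to these particular columns, written in plain Python. It is an illustrative assumption, not PyJanitor's actual implementation, which also handles case conversion, accented characters, and other edge cases.

```python
import re

def clean_column_name(name):
    """Rough stand-in for clean_names(case_type="preserve"): replace each run
    of non-word characters (spaces, colons, etc.) with a single underscore."""
    return re.sub(r"\W+", "_", name)

columns = ["Unnamed: 0", "MolWt", "Label", "Label names"]
print([clean_column_name(c) for c in columns])
# ['Unnamed_0', 'MolWt', 'Label', 'Label_names']
```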
Next, we can extract the x variables (the descriptors) and the y variable (the response, stored in the Label column) from the dataframe.
train_X = train_df.remove_columns(["Label"])
train_y = train_df.Label
Time to build the model. Let's instantiate a DecisionTreeClassifier, limiting the tree to a maximum depth of 2 so that it stays small enough to interpret at a glance.
cls = tree.DecisionTreeClassifier(max_depth=2) 
Next, we can train the classifier.
cls.fit(train_X,train_y)
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=2, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')
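The repr above shows that the classifier uses the default criterion='gini'. At each node, the tree tries candidate thresholds on the descriptors and keeps the split that most reduces the Gini impurity, 1 − Σ p_k², where p_k is the fraction of molecules of class k in the node. The plain-Python sketch below runs that search for a single feature on made-up toy data. scikit-learn repeats the search over every descriptor, recurses until a stopping rule such as max_depth is reached, and is heavily optimized, so treat this as an illustration of the idea rather than the library's code.

```python
def gini(labels):
    """Gini impurity of a node: 1 - sum over classes of p_k**2."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for y in labels:
        counts[y] = counts.get(y, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

def best_split(values, labels):
    """Scan thresholds on one feature; return (threshold, weighted child impurity)."""
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best = (None, float("inf"))
    for i in range(1, n):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold can separate identical values
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2  # midpoint between neighbors
        left = [y for _, y in pairs[:i]]   # samples with value <= threshold
        right = [y for _, y in pairs[i:]]  # samples with value > threshold
        score = (len(left) * gini(left) + len(right) * gini(right)) / n
        if score < best[1]:
            best = (threshold, score)
    return best

# Toy data: one descriptor that cleanly separates the two classes
values = [1, 2, 3, 10, 11, 12]
labels = [0, 0, 0, 1, 1, 1]
print(gini(labels))                # 0.5
print(best_split(values, labels))  # (6.5, 0.0)
```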

Now that we have the model, we can use the dtreeviz package to create a visualization. The scikit-learn package has a method for displaying decision trees (see Figure 6 in Wicker and Cooper for an example), but I haven't found it to be particularly useful. I like the visualization in dtreeviz because, at each node, it shows the distribution of the descriptor being split alongside the split itself.
feature_names = list(train_X.columns)
# dtreeviz 1.x-style call; newer dtreeviz releases provide dtreeviz.model(...) instead
viz = dtreeviz(cls, train_X, train_y, feature_names=feature_names,
               target_name="Crystallize", class_names=["No", "Yes"], scale=2)
viz

In the plot above, each decision node shows a histogram of the descriptor used for the split, with the split point marked on the x-axis. The class distributions in the terminal (leaf) nodes are shown as pie charts.
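Because max_depth=2, the fitted model is at most two levels of if/else tests, and each leaf predicts the majority class of the training molecules that landed in it, which is the largest slice of its pie chart. The sketch below walks a hypothetical tree. The feature names, thresholds, and counts are made up for illustration and are not the values learned from the Wicker and Cooper data. It follows scikit-learn's convention of sending samples with feature <= threshold to the left child.

```python
# Hypothetical depth-2 tree: internal nodes test feature <= threshold,
# leaves store class counts as [count_No, count_Yes].
TREE = {
    "feature": "feature_a", "threshold": 5.0,
    "left": {"feature": "feature_b", "threshold": 1.5,
             "left": {"counts": [10, 90]},
             "right": {"counts": [60, 40]}},
    "right": {"counts": [85, 15]},
}
CLASS_NAMES = ["No", "Yes"]

def predict_one(node, sample):
    """Walk from the root to a leaf and return the leaf's majority class."""
    while "counts" not in node:
        go_left = sample[node["feature"]] <= node["threshold"]
        node = node["left"] if go_left else node["right"]
    counts = node["counts"]
    return CLASS_NAMES[counts.index(max(counts))]

print(predict_one(TREE, {"feature_a": 3.0, "feature_b": 1.0}))  # Yes
print(predict_one(TREE, {"feature_a": 7.0, "feature_b": 1.0}))  # No
```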

The tree appears to separate the two classes well on the training data, so let's see how it performs on the test set. First, we'll read the test data and clean it with PyJanitor, exactly as we did for the training set.
test_df = pd.read_csv("test_desc_with_names.csv")
test_df = (test_df.clean_names(case_type="preserve")
        .remove_columns(["Unnamed_0","Label_names"]))
test_X = test_df.remove_columns(["Label"])
test_y = test_df.Label
Now we can check the Matthews Correlation Coefficient for the test set.
pred = cls.predict(test_X)
matthews_corrcoef(test_y, pred)
0.7781038157491144
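For a binary classifier, the Matthews correlation coefficient is computed from the four cells of the confusion matrix: MCC = (TP·TN − FP·FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). It ranges from −1 (total disagreement) through 0 (no better than chance) to +1 (perfect prediction), and it stays informative when one class is much more common than the other. The plain-Python sketch below performs the same calculation for 0/1 labels. Like scikit-learn's matthews_corrcoef, it returns 0 when the denominator is zero.

```python
import math

def mcc(y_true, y_pred, positive=1):
    """Matthews correlation coefficient for binary labels."""
    tp = tn = fp = fn = 0
    for t, p in zip(y_true, y_pred):
        if p == positive:
            if t == positive:
                tp += 1
            else:
                fp += 1
        else:
            if t == positive:
                fn += 1
            else:
                tn += 1
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0:
        return 0.0  # undefined case (e.g., only one class predicted), reported as 0
    return (tp * tn - fp * fn) / denom

print(mcc([1, 1, 0, 0], [1, 1, 0, 0]))  # 1.0
print(mcc([1, 0, 1, 0], [1, 1, 0, 0]))  # 0.0
print(mcc([1, 1, 0, 0], [0, 0, 1, 1]))  # -1.0
```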
Not bad for a very simple model.  While it's true that there are a lot of sophisticated machine learning models out there, the humble decision tree can often provide an easy way of creating interpretable models.   The dtreeviz package is a nice add-on that enables you to quickly see what the decision tree is doing.  Hopefully, you'll find this useful and integrate it into your data analysis workflow.  

As usual, the code for this post is available in a Jupyter notebook on GitHub.  If you want to run the code without installing software, there is a version on Binder.