Interactive Plots with Chemical Structures

It’s often useful to associate chemical structures with a set of points on a scatter plot. While this is easy to do with commercial software like Spotfire or Vortex, I haven’t found an easy way to integrate an interactive plot like this into a Python script. In this post, I’ll cover how I generated an interactive scatter plot with about a page of Python code, most of which was boilerplate. I pulled this off by integrating Dash, a Python library for interactive dashboards from the nice folks who brought you the Plotly plotting library, with the RDKit. For a quick view of what the application does, check out the movie below. For a higher quality version, try this YouTube link. Actually, just grab the code from GitHub and run it; the application looks a lot better in real life than it does in the video.


At this point, this is more of a Saturday afternoon hack than a complete application. I did this as a proof of concept to convince myself that it was possible. I wanted something that was pure Python and didn't require me to mess around with a lot of JavaScript. I also wanted something that could be easily installed and would be responsive. My ultimate goal is to integrate plots like this into a Jupyter notebook and make them part of my standard Cheminformatics workflow. Anyhow, let's dive in. I'm going to go into a lot of code and detail here. If you just want to grab the app, it's here on GitHub.

We start by importing a bunch of Python libraries. The ones at the top are what we need for Plotly and Dash, the next set provides Pandas and the necessary RDKit functionality, and the last two enable us to display the structure images in a web page. We’ll look at these more closely below.
import dash
import dash_core_components as dcc
import dash_html_components as html
import plotly.graph_objs as go
from dash.dependencies import Input, Output
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Draw import MolsToGridImage
import base64
from io import BytesIO
Next, we read in the output from a t-distributed stochastic neighbor embedding (t-SNE) calculation that we created in my last post. This plot shows the chemical space occupied by a set of ERK2 inhibitors from the DUD-E database. We have a set of 79 active molecules and 4554 decoys. For more information on the dataset, see my last post. We can use the "query" function in Pandas to create separate dataframes with the active and decoy molecules.
df = pd.read_csv("tsne.csv")
active = df.query("is_active == 1")
decoy = df.query("is_active == 0")
Next, we generate a scatter plot using some code that I adapted from one of the Dash examples. There are a couple of things to note here. I used go.Scattergl rather than go.Scatter. The difference is that go.Scattergl uses WebGL to render the scatter plot and is much faster. I tried go.Scatter and it was so slow that it wasn’t usable. Note that we generate a go.Scattergl object for each series (decoy, active) that we extracted from the dataframe above. In the “layout” part of the graph definition, we turn off tooltips when hovering over a point and set the mouse dragging mode to select points.
graph_component = dcc.Graph(
    id='tsne',
    config={'displayModeBar': False},
    figure={
        'data': [
            go.Scattergl(
                x=decoy.X,
                y=decoy.Y,
                mode='markers',
                opacity=0.7,
                marker={
                    'size': 5,
                    'color': 'orange',
                    'line': {'width': 0.5, 'color': 'white'}
                },
                name="Decoy"
            ),
            go.Scattergl(
                x=active.X,
                y=active.Y,
                mode='markers',
                opacity=0.7,
                marker={
                    'size': 10,
                    'color': 'blue',
                    'line': {'width': 0.5, 'color': 'white'}
                },
                name="Active"
            )
        ],
        'layout': go.Layout(
            height=400,
            xaxis={'title': 'X'},
            yaxis={'title': 'Y'},
            margin={'l': 40, 'b': 40, 't': 10, 'r': 10},
            legend={'x': 1, 'y': 1},
            hovermode=False,
            dragmode='select'
        )
    }
)
We then define another component that will hold the structure image, which will be generated by the RDKit MolsToGridImage function. Note that each of these components is given an id ("tsne" and "structure-image"). This will be important later.
image_component = html.Img(id="structure-image")
With the components defined, we can create a Dash application object. I haven’t spent time experimenting with style sheets yet, so I simply pass in an external one.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
The application's layout is initialized with HTML divs containing the two objects we created above.  
app.layout = html.Div([
    html.Div([graph_component]),
    html.Div([image_component])
])

So far this is pretty boring. Now for the fun part. We will define a callback that is called whenever points are selected in the plot. The callback starts with a decorator that defines its inputs and outputs. The first argument to Input and Output is the id of the component that provides the input or accepts the output. The second argument is the field in that component that provides the input or accepts the output. In other words, we will take the input from the "selectedData" field in "tsne", process this, and put the output into a field called "src" in "structure-image".
@app.callback(
    Output('structure-image', 'src'),
    [Input('tsne', 'selectedData')])
Finally, we define the function that the decorator wraps, which creates the structure image. We start by defining the maximum number of structures to display in the image and the number of structures per row. In order to optimize performance, I set the maximum number of structures to display to 12.
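
To make the arithmetic concrete, here is a minimal sketch of how those two settings determine how many structures are drawn and how many rows they need. The helper name grid_shape is hypothetical and is not part of the app; it only illustrates the cap and the row count.

```python
import math


def grid_shape(n_selected, max_structs=12, structs_per_row=6):
    """Return (structures shown, rows needed) for a selection of n_selected points.

    Hypothetical helper for illustration; the defaults mirror the values
    used in the callback.
    """
    # Never draw more than max_structs structures, however many are selected.
    shown = min(n_selected, max_structs)
    # A partially filled last row still counts as a row.
    rows = math.ceil(shown / structs_per_row)
    return shown, rows


for n in (0, 4, 7, 30):
    print(n, grid_shape(n))
```

With the defaults, any selection of 12 or more points is capped at two rows of six structures.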

In order to pass the image to the web page, we need to encode it as a base64 string. We need an empty image to display when nothing is selected, so we define a base64 string that encodes an image with one white pixel and put this into the variable "empty_plot".
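
One way to produce such a placeholder without any imaging library is to assemble a 1x1 PNG by hand. This is a sketch using only the standard library, not necessarily the string used in the app; the function names png_chunk and white_pixel_data_uri are made up for this example.

```python
import base64
import struct
import zlib


def png_chunk(tag: bytes, payload: bytes) -> bytes:
    """Build one PNG chunk: length, tag, payload, then a CRC over tag + payload."""
    crc = zlib.crc32(tag + payload) & 0xFFFFFFFF
    return struct.pack(">I", len(payload)) + tag + payload + struct.pack(">I", crc)


def white_pixel_data_uri() -> str:
    """Return a data URI for a 1x1 white, 8-bit RGB PNG."""
    signature = b"\x89PNG\r\n\x1a\n"
    # Width 1, height 1, bit depth 8, color type 2 (RGB),
    # then compression, filter, and interlace methods all set to 0.
    header = struct.pack(">IIBBBBB", 1, 1, 8, 2, 0, 0, 0)
    # One scanline: filter byte 0 followed by a single white RGB pixel.
    scanline = b"\x00\xff\xff\xff"
    png = (
        signature
        + png_chunk(b"IHDR", header)
        + png_chunk(b"IDAT", zlib.compress(scanline))
        + png_chunk(b"IEND", b"")
    )
    return "data:image/png;base64," + base64.b64encode(png).decode("ascii")


empty_plot = white_pixel_data_uri()
print(empty_plot)
```

The resulting string can be assigned directly to the "src" field of an image component.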

The part of the function that does the real work starts by checking whether anything has been selected. The selections are passed in the variable "selectedData", which is a dictionary with multiple fields. The field "points" is a list of the selected points. We can get the indices of the selected points from the field "pointIndex" in each entry in the list. Each entry also carries a "curveNumber" field that identifies the trace (decoy or active) the point belongs to, and "pointIndex" is the position of the point within that trace.
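
As a small illustration of that structure, the sketch below groups the selected indices by trace. The payload is a hand-written example in the shape described above, not real output from the app, and indices_by_trace is a hypothetical helper.

```python
from collections import defaultdict

# Hypothetical payload shaped like the "selectedData" dictionary.
example_selection = {
    "points": [
        {"curveNumber": 0, "pointIndex": 17, "x": 1.5, "y": -2.0},
        {"curveNumber": 1, "pointIndex": 3, "x": 0.2, "y": 4.1},
        {"curveNumber": 0, "pointIndex": 42, "x": 1.7, "y": -1.8},
    ]
}


def indices_by_trace(selected_data):
    """Map each trace number to the list of point indices selected in it."""
    if not selected_data:
        # Nothing has been selected yet.
        return {}
    grouped = defaultdict(list)
    for point in selected_data.get("points", []):
        grouped[point["curveNumber"]].append(point["pointIndex"])
    return dict(grouped)


print(indices_by_trace(example_selection))  # {0: [17, 42], 1: [3]}
```

Keeping the trace number alongside the index matters because index 3 in the decoy trace and index 3 in the active trace refer to different molecules.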

Once we have the list of selected indices, we can extract the information we need from the dataframe with the Pandas "iloc" method. At this point it's pretty simple to generate the image with the RDKit MolsToGridImage function. As a final step, we encode the image as base64 and return it in a form that can be displayed in a web page.
def display_selected_data(selectedData):
    max_structs = 12
    structs_per_row = 6
    # Placeholder returned when nothing is selected (the base64 image string goes here)
    empty_plot = ""
    if not selectedData or len(selectedData['points']) == 0:
        return empty_plot
    # pointIndex is relative to each trace, so use curveNumber to pick the
    # matching dataframe (trace 0 is decoy, trace 1 is active)
    traces = [decoy, active]
    points = selectedData['points'][:max_structs]
    rows = [traces[p['curveNumber']].iloc[p['pointIndex']] for p in points]
    mol_list = [Chem.MolFromSmiles(row['SMILES']) for row in rows]
    legend_list = [str(row['Name']) + " " + str(row['is_active']) for row in rows]
    img = MolsToGridImage(mol_list, molsPerRow=structs_per_row, legends=legend_list)
    # Save as PNG so the bytes match the MIME type in the data URI
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue())
    return 'data:image/png;base64,{}'.format(encoded_image.decode())

Finally, we create a main function to run the program. We use the socket library to get the IP address, then launch the server. We can then open a web browser and go to the web address printed by the server.

if __name__ == '__main__':
    import socket
    hostname = socket.gethostname()
    IPAddr = socket.gethostbyname(hostname)
    app.run_server(debug=True, host=IPAddr)
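
On some machines the hostname does not resolve to an address, in which case socket.gethostbyname raises an error. A minimal sketch of a more forgiving lookup, assuming that falling back to the loopback address is acceptable for local use (resolve_host is a hypothetical helper, not part of the app):

```python
import socket


def resolve_host(fallback="127.0.0.1"):
    """Return this machine's IPv4 address, or the fallback if lookup fails."""
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError:
        # socket.gaierror and socket.herror are both subclasses of OSError.
        return fallback


print(resolve_host())
```

The returned string could be passed as the host argument in place of IPAddr.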

As mentioned above, this was more of a proof of concept than a full-featured application.  I'm planning to continue to develop this into something that can be used in Jupyter notebooks.  It's also my hope that others will fork the project and turn it into something useful.  If you do, please let me know.

