# Source code for fast_forward.interaction_plots


'''
Functions for plotting fitted interaction distributions
'''

import matplotlib.pyplot as plt
import numpy as np
import pickle
from .interaction_distribution import INTERACTIONS

# x-axis label (quantity and unit) for each supported interaction type.
X_LABELS = {
    'bonds': 'Distance [Å]',
    'angles': 'Angle [°]',
    'dihedrals': 'Angle [°]',
    'distances': 'Distance [Å]',
}

def _plotter(data, atom_list, inter_type, ax):
    """
    Plot every y-series of one interaction distribution on ``ax``.

    Parameters
    ----------
    data : dict
        Must contain ``"x"`` (shared x-axis values); every other key is
        drawn as a separate curve labelled with that key.
    atom_list
        Group identifier (a key of the per-type dict in ``fit_data``),
        used in the panel title.
    inter_type : str
        Interaction type. ``INTERACTIONS[inter_type]`` must provide a
        ``'bins'`` sequence (its first/last values set the x-limits) and
        ``X_LABELS[inter_type]`` supplies the x-axis label.
    ax : :class:`matplotlib.axes._axes.Axes`
        Axes object to draw on.
    """
    cols = ['#6970E0', '#E06B69']
    series_keys = [key for key in data if key != 'x']
    for idx, key in enumerate(series_keys):
        # Cycle through the palette so more than two curves do not raise IndexError.
        ax.plot(data['x'],
                data[key],
                c=cols[idx % len(cols)],
                label=key)
    ax.legend()
    ax.set_title(f'{atom_list} {inter_type}')
    bins = INTERACTIONS[inter_type].get('bins')
    ax.set_xlim(bins[0], bins[-1])
    ax.set_xlabel(X_LABELS[inter_type])


def _plotter_distance_distribution(data, ax, y_lower_threshold: float = 0.01,
                                   x_pad: float = 0.25):
    """
    Plot distance distribution curves and automatically zoom in on
    regions where the signal exceeds a given fraction of its maximum.

    Parameters
    ----------
    data : dict
        Dictionary containing the distance distribution data.
        Must include a key ``"x"`` for the shared x-axis values.
        All other keys correspond to y-data series to be plotted and
        must be array-like and of the same length as ``data["x"]``.
    ax : :class:`matplotlib.axes._axes.Axes`
        Axes object on which the curves will be plotted.
    y_lower_threshold : float, optional
        Fraction of each curve's maximum used to determine the
        region of interest. Only x-values where ``y > y.max() * threshold``
        are considered when computing the zoomed x-axis limits.
        Default is ``0.01``
    x_pad : float, optional
        Margin added on both sides of the zoomed x-range.
        Default is ``0.25``

    Notes
    -----
    The zoom spans from the first to the last significant point over all
    curves. If no curve exceeds its threshold (e.g. all-zero data), the full
    range of ``data["x"]`` is shown instead. Y-axis ticks are always hidden.
    """
    cols = ['#6970E0', '#E06B69']
    # asarray lets plain lists be compared element-wise against the threshold.
    x = np.asarray(data['x'])
    # +/-inf instead of fixed sentinels: works for any distance range and
    # makes "nothing significant found" detectable below.
    x_min, x_max = np.inf, -np.inf

    series_keys = [key for key in data if key != 'x']
    for idx, key in enumerate(series_keys):
        y = np.asarray(data[key])
        # Cycle through the palette so more than two curves do not raise IndexError.
        ax.plot(x, y, c=cols[idx % len(cols)], label=key)
        if y.size == 0:
            continue
        significant = np.flatnonzero(y > np.max(y) * y_lower_threshold)
        if significant.size > 0:
            # flatnonzero returns ascending indices: first/last bound the region.
            x_min = min(x_min, x[significant[0]])
            x_max = max(x_max, x[significant[-1]])

    ax.yaxis.set_ticks([])
    if not np.isfinite(x_min):
        # No curve crossed its threshold: fall back to the full x-range
        # rather than producing an inverted axis.
        if x.size == 0:
            return
        x_min, x_max = np.min(x), np.max(x)
    ax.set_xlim(x_min - x_pad, x_max + x_pad)

def make_distribution_plot(fit_data, save_plot_data=None, axarr=None, name='distribution_plots'):
    '''
    Plot every interaction distribution in ``fit_data`` on a grid and save it.

    One panel is drawn per (interaction type, group) pair via ``_plotter``.
    Without ``axarr`` a new figure with five columns (4x4 inches per panel)
    is created. Panels left over after plotting are removed, and the figure
    is saved as ``{name}.png``.

    Parameters
    ----------
    fit_data: dict
        Nested as {interaction_type: {group_name: data}}. Each ``data`` is
        passed to ``_plotter``, which reads an ``'x'`` key and draws every
        other key as a curve.
        NOTE(review): an earlier docstring described ``data`` as
        {'data': distribution, 'fitted_params': list(params)}; that layout
        has no ``'x'`` key, so confirm what callers actually pass.
    save_plot_data: bool, optional
        if True, pickle ``fit_data`` to ``plot_data.p`` in the current
        working directory
    axarr: :class:`matplotlib.axes._axes.Axes` or array/list of them, optional
        axes to plot the distributions on; must provide at least one Axes per
        interaction. If None, a new figure is created.
    name: str
        name of the output file without extension (default: distribution_plots)
    '''
    total_interactions = sum(len(groups) for groups in fit_data.values())

    if axarr is None:
        ncols = 5
        nrows = -(total_interactions // -ncols)  # ceiling division
        if nrows == 0:
            print("No interactions to plot!")
            return
        fig, axarr = plt.subplots(nrows=nrows, ncols=ncols,
                                  figsize=(ncols * 4, nrows * 4))

    # Normalise to an ndarray so a single Axes, a list or an array all work
    # (``not axarr`` on a multi-element array raised ValueError before).
    axarr = np.atleast_1d(np.asarray(axarr, dtype=object))
    flat_axes = axarr.flatten()
    # Caller-supplied axes carry their own figure; previously ``fig`` was
    # left unbound in that case.
    fig = flat_axes[0].figure

    count = 0
    for interaction_type, groups in fit_data.items():
        for atom_list, data in groups.items():
            _plotter(data, atom_list, interaction_type, flat_axes[count])
            count += 1

    # y-label only on the left-most column (or the first panel for a 1D layout).
    first_column = axarr[:, 0] if axarr.ndim == 2 else flat_axes[:1]
    for ax in first_column:
        ax.set_ylabel('Probability Density')

    if save_plot_data:
        with open('plot_data.p', 'wb') as handle:
            pickle.dump(fit_data, handle)

    # remove unused axes
    for ax in flat_axes[count:]:
        fig.delaxes(ax)
    # need to make room for the title
    fig.subplots_adjust(hspace=0.3)
    fig.savefig(f'{name}.png', bbox_inches='tight')
def make_matrix_plot(matrix, atom_names, axarr=None, name='score_matrix'):
    '''
    Draw ``matrix`` as an annotated heat map and save it to ``{name}.png``.

    Cells are coloured with the ``coolwarm`` colormap on a fixed 0-1 scale
    and annotated with their value to two decimals. A colour bar labelled
    "Score" is attached.

    Parameters
    ----------
    matrix: :class:`~numpy.ndarray`
        Square 2D array of scores; rows and columns follow ``atom_names``
    atom_names: list
        List of atom names used as tick labels for both rows and columns
    axarr: :class:`matplotlib.axes._axes.Axes`, optional
        single Axes to draw the matrix on. If None, a new figure is created.
    name: str
        name of the output file without extension (default: score_matrix)
    '''
    if axarr is None:
        nrows = 1
        ncols = 1
        fig, axarr = plt.subplots(nrows=nrows, ncols=ncols,
                                  figsize=(ncols * 5 + 1, nrows * 5))
    else:
        # A caller-supplied Axes brings its own figure (``fig`` was unbound before).
        fig = axarr.figure

    cax = axarr.imshow(matrix, cmap='coolwarm', vmin=0, vmax=1)
    ticks = np.arange(len(atom_names))
    axarr.set_xticks(ticks, labels=atom_names)
    axarr.set_yticks(ticks, labels=atom_names)
    # imshow puts column j at x=j and row i at y=i, so annotate at (j, i).
    for (i, j), value in np.ndenumerate(matrix):
        axarr.text(j, i, f'{value:.2f}', ha="center", va="center",
                   color="w", fontsize=8)

    # Attach the colour bar to the figure that owns ``axarr``, not whatever
    # figure pyplot currently considers active.
    cbar = fig.colorbar(cax, ax=axarr)
    cbar.set_label('Score', rotation=270, labelpad=15)
    cbar.ax.tick_params(labelsize=10)
    axarr.set_title('Distance Score Matrix')
    # need to make room for the title
    fig.subplots_adjust(hspace=0.3)
    fig.savefig(f'{name}.png', bbox_inches='tight')
def make_distances_distribution_plot(plot_data, atom_names, save_plot_data=False, axarr=None,
                                     name='distance_distribution_plots'):
    '''
    Plot pairwise distance distributions as an upper-triangular grid and save it.

    Panel ``(i, j - 1)`` shows the distribution between ``atom_names[i]`` and
    ``atom_names[j]`` for ``i < j``. Lower-triangle panels are removed. Row
    labels, column titles and a shared legend are added, and the figure is
    saved as ``{name}.png``.

    Parameters
    ----------
    plot_data: dict
        Must contain a ``'distances'`` dict keyed by
        ``f'{atom_names[i]}_{atom_names[j]}'``. Each value is passed to
        ``_plotter_distance_distribution`` (a dict with ``'x'`` plus one
        entry per curve). Pairs missing from it leave an empty panel.
    atom_names: list
        Atom names labelling the rows and columns of the grid
    save_plot_data: bool
        if True, pickle ``plot_data`` to ``plot_data_distribution.p`` in the
        current working directory
    axarr: array of :class:`matplotlib.axes._axes.Axes`, optional
        (natoms-1) x (natoms-1) grid of axes to plot on. If None, a new
        figure is created.
    name: str
        name of the output file without extension
        (default: distance_distribution_plots)
    '''
    natoms = len(atom_names)
    if natoms < 2:
        print("Need at least two atoms to plot distance distributions!")
        return

    if axarr is None:
        # squeeze=False keeps axarr 2D even for a single pair (natoms == 2).
        fig, axarr = plt.subplots(natoms - 1, natoms - 1,
                                  figsize=(natoms * 2, natoms),
                                  gridspec_kw={'wspace': 0.05, 'hspace': 0.4},
                                  squeeze=False)
    else:
        # Accept a single Axes, nested lists or arrays; take the figure from
        # the axes themselves (``fig`` was unbound here before).
        axarr = np.atleast_2d(np.asarray(axarr, dtype=object))
        fig = axarr.flat[0].figure

    for i in range(natoms - 1):
        for j in range(1, natoms):
            ax = axarr[i, j - 1]
            if i < j:  # plot only upper triangle of the matrix
                # Equivalent to the old '<resid1>_<name1>_<resid2>_<name2>'
                # key, but also works when a name has other than one '_'.
                atoms_key = f'{atom_names[i]}_{atom_names[j]}'
                if atoms_key in plot_data['distances']:
                    _plotter_distance_distribution(plot_data['distances'][atoms_key], ax)
            else:
                fig.delaxes(ax)  # remove lower triangle of the matrix

    # add labels to the axes
    for i in range(natoms - 1):
        axarr[0, i].set_title(atom_names[i + 1], fontsize=14)
        axarr[i, i].set_ylabel(atom_names[i], fontsize=14)
        axarr[i, i].set_xlabel(X_LABELS['distances'])

    # Legend on the last diagonal panel, shifted left into the slot freed by
    # the deleted lower-triangle panel next to it.
    axarr[natoms - 2, natoms - 2].legend(loc='upper left', fontsize=10,
                                         bbox_to_anchor=(-1, 0.75), frameon=False)
    fig.suptitle('Distance Distribution Plots', fontsize=16)

    if save_plot_data:
        with open('plot_data_distribution.p', 'wb') as handle:
            pickle.dump(plot_data, handle)
    fig.savefig(f'{name}.png', bbox_inches='tight')