Source code for fast_forward.interaction_fit


'''
functions for fitting interaction distributions
'''

import numpy as np
from lmfit.models import GaussianModel
from lmfit import Parameters
from MDAnalysis.units import constants
import lmfit
from collections import defaultdict
from vermouth.molecule import Interaction

def _is_part_of_dihedral(angle_atoms, dihedrals):
    """
    Check if an angle is part of a dihedral

    Parameters
    ----------
    angle_atoms: list
        list of atom indices in the angle
    dihedrals: list
        list of dihedrals in the system

    Returns
    -------
    bool
        True if angle is part of a dihedral, False otherwise
    """
    match = False
    for dih in dihedrals:
        match = (
            np.array_equal(angle_atoms, dih[0:3]) or
            np.array_equal(angle_atoms, dih[1:4]) or
            np.array_equal(angle_atoms[::-1], dih[0:3]) or
            np.array_equal(angle_atoms[::-1], dih[1:4])
        )
    return match

def _gaussian_fitter(x, y, initial_center, initial_sigma, initial_amplitude):
    """
    Fit a Gaussian function to an input distribution

    Parameters
    ----------
    x: :class:`~numpy.ndarray`
        x of data
    y: :class:`~numpy.ndarray`
        y of data
    initial_center: dict
        dictionary of values for lmfit to use as starting parameters
    initial_sigma: dict
        dictionary of values for lmfit to use as starting parameters
    initial_amplitude: dict
        dictionary of values for lmfit to use as starting parameters

    Returns
    -------
    gaussian_result: lmfit.ModelResult
        lmfit fit result
    """

    gaussian = GaussianModel()
    params = gaussian.make_params(center=initial_center,
                                  sigma=initial_sigma,
                                  amplitude=initial_amplitude
                                  )
    gaussian_result = gaussian.fit(y, params, x=x)

    return gaussian_result

def _gaussian_generator(x, params):
    """
    Generate a gaussian function from fitted parameters
    """
    mod = GaussianModel(x=x)
    pars = Parameters()
    pars.add("center", params['center'].value)
    pars.add("sigma", params['sigma'].value)
    pars.add("amplitude", params['amplitude'].value)
    fitted_distribution = mod.eval(pars, x=x)
    return fitted_distribution
def _periodic_gaussian_generator(x, c, s, a):
    """
    Generate a gaussian function from fitted parameters across x with periodicity
    """
    terms = 10
    period = 2*np.pi
    y = np.zeros_like(x)
    for k in range(-terms, terms+1):
        y += np.exp(-0.5 * ((x - c + k * period) / s)**2)
    return y * a
[docs] class InteractionFitter: """ Class to fit interactions """ interactions_dict: defaultdict # <-- Add this line def __init__(self, precision, temperature, constraint_converter, max_dihedrals, dihedral_scaling): ''' Parameters ---------- precision: int precision to round values to for writing temperature: int temperature of interaction distribution in boltzmann inversion constraint_converter: int threshold above which to convert bonds to constraints max_dihedrals: int maximum number of dihedrals to fit proper dihedrals with ''' self.__dihedrals = None self.precision = precision self.temperature = temperature self.kb = constants["Boltzmann_constant"] self.constraint_converter = constraint_converter self.max_dihedrals = max_dihedrals self.dihedral_scaling = dihedral_scaling # this will store the interactions self.interactions_dict = defaultdict(list) self.fit_parameters = defaultdict(dict) self.plot_parameters = defaultdict(dict) def _bonds_fitter(self, data, group_name): """ Fit bonds Parameters ---------- data: :class:`~numpy.ndarray` histogram of bond data group_name: str names of the atoms involved in the interaction joined by a "_" """ x, y = data.T gaussian_fit = _gaussian_fitter(x, y, initial_center=dict(value=x[y.argmax()]), initial_sigma=dict(value=x.std()), initial_amplitude=dict(value=y.max()) ) initial_center = gaussian_fit.params["center"].value initial_sigma = gaussian_fit.params["sigma"].value # need this here because mdanalysis read gromacs coords in angstroms but need in nm. center = np.round(initial_center / 10, self.precision) sigma = np.round((self.kb * self.temperature) / ((initial_sigma / 10) ** 2), self.precision) self.fit_parameters['bonds'][group_name] = [center, sigma] self.plot_parameters['bonds'][group_name] = {'x': x, 'Distribution': y, 'Fitted': _gaussian_generator(x, gaussian_fit.params)} def _angles_fitter(self, data, group_name): """ Fit angles Parameters ---------- data: :class:`~numpy.ndarray` histogram of bond data group_name: str names of the atoms involved in the interaction joined by a "_" """ x, y = data.T gaussian_fit = _gaussian_fitter(x, y, initial_center=dict(value=x[y.argmax()], min=x[y.argmax()] - 20, max=x[y.argmax()] + 20), initial_sigma=dict(value=x.std(), min=x.std() / 4, max=x.std() * 1.5), initial_amplitude=dict(value=y.max(), min=0) ) initial_center = gaussian_fit.params["center"].value initial_sigma = gaussian_fit.params["sigma"].value center = np.round(initial_center, self.precision) sin_term = np.sin(np.deg2rad(float(center))) ** 2 var = np.deg2rad(initial_sigma) ** 2 sigma = np.round((self.kb * self.temperature) / (sin_term * var), self.precision) self.fit_parameters['angles'][group_name] = [center, sigma] self.plot_parameters['angles'][group_name] = {'x': x, 'Distribution': y, 'Fitted': _gaussian_generator(x, gaussian_fit.params)} # Fitting function for proper dihedrals def _proper_dihedral_model_function(self, params, x): """Computes the sum of cosines with the given parameters.""" y = np.zeros_like(x) num_terms = len(params) // 3 # Each term has k, n, x0 for i in range(1, num_terms): k = params[f'k{i}'] n = int(params[f'n{i}'].value) # Force n to be an integer x0 = params[f'x0_{i}'] y += k * (1 + np.cos(n * x - x0)) return y # Residual function for lmfit def _residuals(self, params, x, data): ''' Residual function for fitting proper dihedrals Parameters ---------- params: lmfit.Parameters Parameters object for fit x: :class:`~numpy.ndarray` x variable for dihedral data data: :class:`~numpy.ndarray` probability data for dihedral distribution Returns ------- model - data: residuals for fitting function to optimise ''' return self._proper_dihedral_model_function(params, x) - data def _dihedrals_fitter(self, data, group_name): ''' Fitter for dihedrals. Will try to fit both proper and improper dihedrals, deciding which to return based on the Akaike information criterion of the two fits Parameters ---------- data: :class:`~numpy.ndarray` histogram of bond data group_name: str names of the atoms involved in the interaction joined by a "_" ''' x = np.linspace(-np.pi, np.pi, 360) y = data.T[1] # take care of periodic effects for improper dihedrals x_gauss = np.linspace(-2*np.pi, 2*np.pi, 720) y_gauss = np.tile(y, 2) # first try fitting a gaussian to the data in case we have an improper dihedral gaussian_result = _gaussian_fitter(x_gauss[120:-120], y_gauss[120:-120], initial_center=dict(value=x[y.argmax()], min=-np.pi, max=np.pi), initial_sigma=dict(value=1, min=0, max=np.pi/3), initial_amplitude=dict(value=y.max()) ) # now do fitting for proper dihedrals # Iterate over different numbers of terms to find the optimal one best_aic = np.inf best_params = None single_params = None for num_terms in range(1, self.max_dihedrals + 1): params = lmfit.Parameters() for i in range(num_terms): params.add(f'k{i}', value=1.0) # Initial guess params.add(f'n{i}', value=i, vary=False) # Fixed integer frequency params.add(f'x0_{i}', value=0.0, min=-np.pi, max=np.pi) # Phase shift # Perform fitting minimizer = lmfit.Minimizer(self._residuals, params, fcn_args=(x, y)) result = minimizer.minimize() # Compute AIC (lower is better) aic = result.aic # Keep track of the best model if aic < best_aic: best_aic = aic best_params = result.params # Save the parameters for a single periodic function if num_terms == 2: single_params = result.params num_terms = len(best_params) // 3 # Each term has k, n, and x0 condition0 = best_aic < gaussian_result.aic condition1 = np.isclose(gaussian_result.params['sigma'].value, gaussian_result.params['sigma'].max) # compare the aic values to determine which type of dihedral we have # also make sure we don't have a very wide gaussian, where a single periodic function will suffice if condition0 or condition1: if not condition0 and condition1: num_terms = 2 best_params = single_params pars_out = [] for i in range(1, num_terms): x0 = best_params[f'x0_{i}'].value n = int(best_params[f'n{i}'].value) # Ensure n is integer pars_out.append([best_params[f'k{i}'].value, x0, n]) self.fit_parameters['dihedrals'][group_name] = {i: j for i, j in enumerate(pars_out)} self.plot_parameters['dihedrals'][group_name] = {'x': np.degrees(x), 'Distribution': y, 'Fitted': self._proper_dihedral_model_function(best_params, x)} else: # transform the centre back into the correct domain after fitting to account for periodicity. c0 = (gaussian_result.params['center'].value + (2*np.pi)) % (2*np.pi) - np.pi center = np.round(c0, self.precision) sigma = np.round((self.kb * self.temperature) / ((gaussian_result.params['sigma']) ** 2), self.precision) self.fit_parameters['dihedrals'][group_name] = [center, sigma] x_plot = np.degrees(((x+np.pi) % (2*np.pi)) - np.pi) fitted_improper_plot = _periodic_gaussian_generator(x, c0, gaussian_result.params['sigma'].value, gaussian_result.params['amplitude'].value) self.plot_parameters['dihedrals'][group_name] = {'x': x_plot[np.argsort(x_plot)], 'Distribution': y, 'Fitted': fitted_improper_plot} def _virtual_sites2_handler(self, data, group_name): self.fit_parameters['virtual_sites2'][group_name] = {'params': [data[0][0]]} def _virtual_sites2fd_handler(self, data, group_name): self.fit_parameters['virtual_sites2fd'][group_name] = {'params': [data[0][0]]} def _virtual_sites3_handler(self, data, group_name): self.fit_parameters['virtual_sites3'][group_name] = {'params': [data[0][0], data[0][1]]} def _virtual_sites3out_handler(self, data, group_name): self.fit_parameters['virtual_sites3out'][group_name] = {'params': [data[0][0], data[0][1], data[0][2]]} def _virtual_sites3fd_handler(self, data, group_name): self.fit_parameters['virtual_sites3fd'][group_name] = {'params': [data[0][0], data[0][1]]} def _virtual_sitesn_handler(self, data, group_name): self.fit_parameters['virtual_sitesn'][group_name] = None @property def dihedrals(self): return getattr(self, "__dihedrals", []) @dihedrals.setter def dihedrals(self, interaction_groups): self.__dihedrals = [dih for key in interaction_groups['dihedrals'] for dih in interaction_groups['dihedrals'][key]]
[docs] def fit_to_gmx(self, inter_type, group_name, atoms, vs_constructors): if inter_type == 'bonds': parameters = self.fit_parameters['bonds'][group_name] center, sigma = parameters for ag in atoms: if any(x in vs_constructors for x in ag): self.interactions_dict['bonds'].append(Interaction(atoms=list(ag), parameters=[1, center, sigma], meta={"comment": group_name})) else: if sigma < self.constraint_converter: self.interactions_dict['bonds'].append(Interaction(atoms=list(ag), parameters=[1, center, sigma], meta={"comment": group_name})) else: self.interactions_dict['bonds'].append(Interaction(atoms=list(ag), parameters=[1, center, 10000], meta={"ifdef": "FLEXIBLE", "comment": group_name})) self.interactions_dict['constraints'].append(Interaction(atoms=list(ag), parameters=[1, center], meta={"ifndef": "FLEXIBLE", "comment": group_name, "fc": sigma})) elif inter_type == 'constraints': parameters = self.fit_parameters['bonds'][group_name] center, sigma = parameters for ag in atoms: self.interactions_dict['constraints'].append(Interaction(atoms=list(ag), parameters=[1, center], meta={"comment": group_name})) elif inter_type == 'angles': parameters = self.fit_parameters['angles'][group_name] center, sigma = parameters # empirically derived. if sigma too big, angles get very unstable. sigma = min(sigma, 150) if _is_part_of_dihedral(atoms[0], self.dihedrals): # only assign type 10 if part of a dihedral and theta_0 < 160 if float(center) < 160: # empirically derived. For theta_0 > 160, significant ptl energy for type 10 at equilibrium, so enforce type 1. func_type_out = 10 else: print((f"WARNING: Angle {group_name} is part of a dihedral with equilibrium angle {center:.1f}°. " f"For theta_0 > 160°, the system may have high potential even energy at equilibrium. " f"This might cause instabilities.")) func_type_out = 10 else: func_type_out = 1 for ag in atoms: self.interactions_dict['angles'].append(Interaction(atoms=list(ag), parameters=[func_type_out, center, sigma], meta={"comment": group_name})) elif inter_type == 'dihedrals': parameters = self.fit_parameters['dihedrals'][group_name] if isinstance(parameters, list): center, sigma = parameters center = np.round(np.degrees(center), self.precision) for ag in atoms: self.interactions_dict['dihedrals'].append(Interaction(atoms=list(ag), parameters=[2, center, sigma], meta={"comment": group_name})) else: for ag in atoms: for i in parameters.values(): # factors derived from the fitting directly have negligible effects (~10^-3/4), # scaling them helps increase the strength of dihedral in the final interaction k = - i[0] * self.dihedral_scaling x0_deg = np.degrees(i[1]) n = i[2] self.interactions_dict['dihedrals'].append(Interaction(atoms=ag, parameters=[9, # function type np.round(x0_deg, self.precision), # center np.round(k, self.precision), # force constant int(n) # multiplicity ], meta={"comment": group_name, "group": group_name})) elif inter_type == 'virtual_sites2': parameters = self.fit_parameters['virtual_sites2'][group_name]['params'] for ag in atoms: self.interactions_dict['virtual_sites2'].append(Interaction(atoms=ag, parameters=[1, np.round(parameters[0], self.precision) ], meta={"comment": group_name} )) elif inter_type == 'virtual_sites2fd': parameters = self.fit_parameters['virtual_sites2fd'][group_name]['params'] for ag in atoms: self.interactions_dict['virtual_sites2'].append(Interaction(atoms=ag, parameters=[2, np.round(parameters[0], self.precision), ], meta={"comment": group_name} )) elif inter_type == 'virtual_sites3': parameters = self.fit_parameters['virtual_sites3'][group_name]['params'] for ag in atoms: self.interactions_dict['virtual_sites3'].append(Interaction(atoms=ag, parameters=[1, np.round(parameters[0], self.precision), np.round(parameters[1], self.precision), ], meta={"comment": group_name} )) elif inter_type == 'virtual_sites3fd': parameters = self.fit_parameters['virtual_sites3fd'][group_name]['params'] for ag in atoms: self.interactions_dict['virtual_sites3'].append(Interaction(atoms=ag, parameters=[2, np.round(parameters[0], self.precision), np.round(parameters[1], self.precision), ], meta={"comment": group_name} )) elif inter_type == 'virtual_sites3out': parameters = self.fit_parameters['virtual_sites3out'][group_name]['params'] for ag in atoms: self.interactions_dict['virtual_sites3'].append(Interaction(atoms=ag, parameters=[4, np.round(parameters[0], self.precision), np.round(parameters[1], self.precision), np.round(parameters[2], self.precision), ], meta={"comment": group_name} )) elif inter_type == 'virtual_sitesn': pars = [1] + [i+1 for i in atoms[0][1:]] self.interactions_dict['virtual_sitesn'].append(Interaction(atoms=[atoms[0][0]], parameters=pars, meta={"comment": group_name}))
[docs] def fit_interaction(self, data, atoms, group_name, inter_type, vs_constructors=[]): """ Fit an interaction for a group of atoms, and assign the fitted parameters to gromacs variables in self.interactions_dict Parameters ---------- data: :class:`~numpy.ndarray` histogram of input data atoms: list (lists of) atom indices involved in the given interaction group_name: str name of interaction group inter_type: str name of interaction type being analysed vs_constructors: list indices of atoms which are virtual sites. Cannot construct constraints from virtual sites, so these will be overwritten if found. """ func_dict = {'bonds': self._bonds_fitter, 'constraints': self._bonds_fitter, 'angles': self._angles_fitter, 'dihedrals': self._dihedrals_fitter, 'virtual_sites2': self._virtual_sites2_handler, 'virtual_sites2fd': self._virtual_sites2fd_handler, 'virtual_sites3': self._virtual_sites3_handler, 'virtual_sites3fd': self._virtual_sites3fd_handler, 'virtual_sites3out': self._virtual_sites3out_handler, 'virtual_sitesn': self._virtual_sitesn_handler } func_dict[inter_type](data, group_name) self.fit_to_gmx(inter_type, group_name, atoms, vs_constructors)