from openmm.app import *
from openmm import *
from openmm.unit import picosecond, second, nanometer, kelvin, millisecond, angstrom
from sys import stdout
import time

# Load the PDB structure
pdb = PDBFile("openmm_help/step4_input.pdb")
# Alternative input (Amber topology/coordinates), currently disabled:
# prmtop_file = AmberPrmtopFile("amber/glycine.prmtop")
# inpcrd_file = AmberInpcrdFile("amber/glycine.inpcrd")

# Specify the force field
forcefield = ForceField("charmm36m.xml")

# Remove any water present in the input structure and re-solvate in a fresh
# box.  The box edge lengths carry explicit units instead of relying on bare
# numbers being interpreted as nanometers.
modeller = Modeller(pdb.topology, pdb.positions)
modeller.deleteWater()
modeller.addSolvent(forcefield, boxSize=Vec3(5, 5, 5) * nanometer)


# Combine the molecular topology and the force field
system = forcefield.createSystem(
    modeller.topology,
    nonbondedMethod=PME,
    # The cutoff override below is disabled, so createSystem's default applies.
    # nonbondedCutoff=0.5 * angstrom,
    constraints=HBonds,
)

# Run length and output settings
time_step = 0.004 * picosecond
total_time = 10 * picosecond
pdb_frames = 25  # used to derive the trajectory write interval

# Number of integration steps.  Strip the units explicitly and round to the
# nearest integer: truncating with int() can silently drop a step when the
# floating-point ratio lands just below a whole number.
steps = round(
    total_time.value_in_unit(picosecond) / time_step.value_in_unit(picosecond)
)

# Select the compute platform by name.
# NOTE(review): passing a name string to Platform.getPlatform presumably needs
# a recent OpenMM release (older ones expose getPlatformByName) - confirm
# against the installed version.
platform = Platform.getPlatform("CUDA")

# Integrator arguments, in order: temperature, friction coefficient, timestep.
bath_temperature = 260 * kelvin
friction_coefficient = 1 / picosecond
integrator = LangevinMiddleIntegrator(
    bath_temperature, friction_coefficient, time_step
)

# Tie topology, system, integrator and platform together, then load the
# starting coordinates from the modeller into the context.
simulation = Simulation(modeller.topology, system, integrator, platform)
simulation.context.setPositions(modeller.positions)

def _print_potential_energy():
    """Print the potential energy of the simulation's current state."""
    state = simulation.context.getState(getEnergy=True)
    print(f"Current energy is {state.getPotentialEnergy()}")


# Local energy minimization, logging the potential energy before and after
_print_potential_energy()
print("Minimizing energy...")
simulation.minimizeEnergy(maxIterations=100)
print("Energy minimized!")
_print_potential_energy()

# Write the trajectory to a file called "output.pdb".  The interval is clamped
# to at least one step so that a run with fewer than pdb_frames steps does not
# produce a zero reporting interval.
simulation.reporters.append(
    PDBReporter("output.pdb", max(1, steps // pdb_frames))
)

# Report information to the screen as the simulation runs, at an interval of
# one hundredth of the run (clamped to at least one step for the same reason).
simulation.reporters.append(
    StateDataReporter(
        stdout,
        max(1, steps // 100),
        step=True,
        totalSteps=steps,
        progress=True,
        potentialEnergy=True,
        temperature=True,
        density=True,
    )
)

# Time the run with a monotonic clock so the elapsed figure is not distorted
# by system clock adjustments during the simulation.
start_time = time.perf_counter()
print(f"Starting simulation, {steps} steps")
simulation.step(steps)
print(f"Simulation finished in {time.perf_counter() - start_time} seconds")