# Source code for aiida_quantumespresso.common.workchain.base.restart

# -*- coding: utf-8 -*-
# pylint: disable=inconsistent-return-statements,no-member
"""Base implementation of WorkChain class that implements a simple automated restart mechanism for calculations."""
from __future__ import absolute_import

from aiida import orm
from aiida.common import exceptions
from aiida.common.lang import override
from aiida.engine import ExitCode, CalcJob, WorkChain, ToContext, append_

from aiida_quantumespresso.common.exceptions import UnexpectedCalculationFailure

class BaseRestartWorkChain(WorkChain):
    """Base restart work chain.

    This work chain serves as the starting point for more complex work chains that will be designed to run a
    calculation that might need multiple restarts to come to a successful end. These restarts may be necessary because
    a single calculation run is not sufficient to achieve a fully converged result, or certain errors may be
    encountered which are recoverable.

    This work chain implements the most basic functionality to achieve this goal. It will launch calculations,
    restarting until it is completed successfully or the maximum number of iterations is reached. It can recover from
    errors through error handlers that can be attached dynamically through the ``register_error_handler`` decorator.

    The idea is to sub class this work chain and leverage the generic error handling that is implemented in the few
    outline methods. The minimally required outline would look something like the following::

        cls.setup
        while_(cls.should_run_calculation)(
            cls.run_calculation,
            cls.inspect_calculation,
        )

    Each of these methods can of course be overridden but they should be general enough to fit most calculation
    cycles. The ``run_calculation`` method will take the inputs for the calculation process from the context under the
    key ``inputs``. The user should therefore make sure that before the ``run_calculation`` method is called, that the
    to be used inputs are stored under ``self.ctx.inputs``. One can update the inputs based on the results from a prior
    calculation by calling an outline method just before the ``run_calculation`` step, for example::

        cls.setup
        while_(cls.should_run_calculation)(
            cls.prepare_calculation,
            cls.run_calculation,
            cls.inspect_calculation,
        )

    Where in the ``prepare_calculation`` method, the inputs dictionary at ``self.ctx.inputs`` is updated before the
    next calculation will be run with those inputs.

    The ``_calculation_class`` attribute should be set to the ``CalcJob`` class that should be run in the loop.
    """

    # When True, `results` reports every output node that it attaches.
    _verbose = False
    # The `CalcJob` subclass that is launched in the loop; has to be set by the sub class.
    _calculation_class = None
    # Optional entry point group whose entry points are loaded on construction to register error handlers.
    _error_handler_entry_point = None

    def __init__(self, *args, **kwargs):
        """Construct the work chain, validate `_calculation_class` and load the error handler entry points.

        :raises ValueError: if `_calculation_class` is not set or is not a subclass of `CalcJob`
        """
        super(BaseRestartWorkChain, self).__init__(*args, **kwargs)

        if self._calculation_class is None or not issubclass(self._calculation_class, CalcJob):
            raise ValueError('no valid CalcJob class defined for _calculation_class attribute')

        self._load_error_handlers()

    # NOTE(review): the `def` line of this method was missing from the source that was reviewed (only the `@override`
    # decorator survived). The signature follows `Process.load_instance_state` of aiida-core -- confirm.
    @override
    def load_instance_state(self, saved_state, load_context):
        """Restore the instance from a saved state and load the error handler entry points again."""
        super(BaseRestartWorkChain, self).load_instance_state(saved_state, load_context)
        self._load_error_handlers()

    def _load_error_handlers(self):
        """Load all entry points of the `_error_handler_entry_point` group, if one is defined.

        An entry point that cannot be loaded is logged as a warning and skipped, it does not abort the work chain.
        """
        # Imported locally, like `prepare_process_inputs` in `run_calculation`, to keep this method self-contained
        from aiida.plugins.entry_point import get_entry_point_names, load_entry_point

        # If an error handler entry point is defined, load them. If the plugin cannot be loaded log it and pass
        if self._error_handler_entry_point is None:
            return

        for entry_point_name in get_entry_point_names(self._error_handler_entry_point):
            try:
                load_entry_point(self._error_handler_entry_point, entry_point_name)
            except exceptions.EntryPointError as exception:
                self.logger.warning("failed to load the '%s' entry point for the '%s' error handlers: %s",
                                    entry_point_name, self._error_handler_entry_point, exception)
            else:
                self.logger.info("loaded the '%s' entry point for the '%s' error handlers category",
                                 entry_point_name, self._error_handler_entry_point)

    @classmethod
    def define(cls, spec):
        """Define the `max_iterations` and `clean_workdir` inputs and the exit codes 101 and 102."""
        # yapf: disable
        super(BaseRestartWorkChain, cls).define(spec)
        spec.input('max_iterations', valid_type=orm.Int, default=orm.Int(5),
            help='Maximum number of iterations the work chain will restart the calculation to finish successfully.')
        spec.input('clean_workdir', valid_type=orm.Bool, default=orm.Bool(False),
            help='If True, work directories of all called calculation will be cleaned at the end of execution.')
        spec.exit_code(101, 'ERROR_MAXIMUM_ITERATIONS_EXCEEDED',
            message='The maximum number of iterations was exceeded.')
        spec.exit_code(102, 'ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE',
            message='The calculation failed for an unknown reason, twice in a row.')

    def setup(self):
        """Initialize context variables that are used during the logical flow of the `BaseRestartWorkChain`."""
        self.ctx.calc_name = self._calculation_class.__name__
        self.ctx.unexpected_failure = False
        self.ctx.restart_calc = None
        self.ctx.is_finished = False
        self.ctx.iteration = 0

    def should_run_calculation(self):
        """Return whether a new calculation should be run.

        This is the case as long as the last calculation has not finished successfully and the maximum number of
        restarts has not yet been exceeded.
        """
        return not self.ctx.is_finished and self.ctx.iteration < self.inputs.max_iterations.value

    def run_calculation(self):
        """Run the next calculation, taking the input dictionary from the context at `self.ctx.inputs`.

        :return: `ToContext` instruction that appends the submitted calculation to `self.ctx.calculations`
        :raises AttributeError: if no inputs were stored under `self.ctx.inputs`
        """
        from aiida_quantumespresso.utils.mapping import prepare_process_inputs

        self.ctx.iteration += 1

        try:
            unwrapped_inputs = self.ctx.inputs
        except AttributeError:
            raise AttributeError('no calculation input dictionary was defined in self.ctx.inputs')

        inputs = prepare_process_inputs(self._calculation_class, unwrapped_inputs)
        calculation = self.submit(self._calculation_class, **inputs)

        self.report('launching {}<{}> iteration #{}'.format(self.ctx.calc_name, calculation.pk, self.ctx.iteration))

        # Without this the list `self.ctx.calculations`, that `inspect_calculation` and `results` index by iteration,
        # would never be populated.
        return ToContext(calculations=append_(calculation))

    def inspect_calculation(self):
        """Analyse the results of the previous calculation and call the error handlers when necessary.

        :return: `ExitCode` if the work chain is to be aborted, nothing if the loop should simply continue
        """
        calculation = self.ctx.calculations[self.ctx.iteration - 1]

        # Done: successful completion of last calculation
        if calculation.is_finished_ok:

            # Perform an optional sanity check. If it returns an ExitCode this means an unrecoverable situation was
            # detected and the work chain should be aborted. If it returns False, the sanity check detected a problem
            # but has handled the problem and we should restart the cycle.
            result = self._handle_calculation_sanity_checks(calculation)  # pylint: disable=assignment-from-no-return

            if isinstance(result, ExitCode):
                # No need to reset the unexpected_failure because the work chain will terminate due to the exit code
                self.report('{}<{}> finished successfully, but sanity check detected unrecoverable problem'.format(
                    self.ctx.calc_name, calculation.pk))
                return result

            if result is False:
                # Reset the unexpected_failure since we are restarting the calculation loop
                self.ctx.unexpected_failure = False
                self.report('{}<{}> finished successfully, but sanity check failed, restarting'.format(
                    self.ctx.calc_name, calculation.pk))
                return

            self.report('{}<{}> completed successfully'.format(self.ctx.calc_name, calculation.pk))
            self.ctx.restart_calc = calculation
            self.ctx.is_finished = True
            return

        # Unexpected: calculation was killed or an exception occurred, trigger unexpected failure handling
        if calculation.is_excepted or calculation.is_killed:
            return self._handle_unexpected_failure(calculation)

        # Failed: here the calculation is Finished but has a non-zero exit status, initiate the error handling
        try:
            exit_code = self._handle_calculation_failure(calculation)
        except UnexpectedCalculationFailure as exception:
            exit_code = self._handle_unexpected_failure(calculation, exception)

        return exit_code

    def results(self):
        """Attach the outputs specified in the output specification from the last completed calculation.

        :return: the `ERROR_MAXIMUM_ITERATIONS_EXCEEDED` exit code if the loop ended without the work chain having
            been marked as finished, nothing otherwise
        """
        calculation = self.ctx.calculations[self.ctx.iteration - 1]

        # Check the `is_finished` flag of the work chain rather than `calculation.is_failed`: the loop can also run
        # out of iterations when the last calculation was killed, excepted or did not pass the sanity check, none of
        # which make the calculation "failed", and those cases should not be reported as a success either.
        if not self.ctx.is_finished and self.ctx.iteration >= self.inputs.max_iterations.value:
            # Abort: exceeded maximum number of retries
            self.report('reached the maximum number of iterations {}: last ran {}<{}>'.format(
                self.inputs.max_iterations.value, self.ctx.calc_name, calculation.pk))
            return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED

        self.report('work chain completed after {} iterations'.format(self.ctx.iteration))

        for name, port in self.spec().outputs.items():

            try:
                # `one` raises a `ValueError` unless there is exactly one outgoing link with this label
                node = calculation.get_outgoing(link_label_filter=name).one().node
            except ValueError:
                if port.required:
                    self.report("required output '{}' was not an output of {}<{}>".format(
                        name, self.ctx.calc_name, calculation.pk))
            else:
                self.out(name, node)
                if self._verbose:
                    self.report("attaching the node {}<{}> as '{}'".format(node.__class__.__name__, node.pk, name))

    def on_terminated(self):
        """Clean the working directories of all child calculations if `clean_workdir=True` in the inputs."""
        super(BaseRestartWorkChain, self).on_terminated()

        if self.inputs.clean_workdir.value is False:
            self.report('remote folders will not be cleaned')
            return

        cleaned_calcs = []

        for called_descendant in self.node.called_descendants:
            if isinstance(called_descendant, orm.CalcJobNode):
                try:
                    called_descendant.outputs.remote_folder._clean()  # pylint: disable=protected-access
                    cleaned_calcs.append(str(called_descendant.pk))
                except (IOError, OSError, KeyError):
                    # Cleaning is best effort: a folder that cannot be reached or does not exist is simply skipped
                    pass

        if cleaned_calcs:
            self.report('cleaned remote folders of calculations: {}'.format(' '.join(cleaned_calcs)))

    def _handle_calculation_sanity_checks(self, calculation):
        """Perform a sanity check of a calculation that finished ok.

        Calculations that were marked as successful by the parser may still have produced outputs that do not make
        sense but were not detected by the code and so were not highlighted as warnings or errors. The consistency of
        the outputs can be checked here. If an unrecoverable problem is found, the function should return the
        appropriate exit code to abort the work chain. If the problem can be fixed with a restart calculation, this
        function should adapt the inputs as an error handler would and return False. This will signal to the work
        chain that a new calculation should be started. If None is returned, the work chain assumes that the outputs
        produced by the calculation are good and nothing will be done.

        This base implementation performs no checks and returns None.

        :param calculation: the calculation whose outputs should be checked for consistency
        :return: ExitCode if the work chain is to be aborted, False if a new calculation should be launched
        """

    def _handle_calculation_failure(self, calculation):
        """Call the attached error handlers if any to attempt to correct the cause of the calculation failure.

        The registered error handlers will be called in order based on their priority until a handler returns a report
        that instructs to break. If the last executed error handler defines an exit code, that will be returned to
        instruct the work chain to abort. Otherwise the work chain will continue the cycle.

        :param calculation: the calculation that finished with a non-zero exit status
        :return: ExitCode if the work chain is to be aborted
        :raises UnexpectedCalculationFailure: if no error handlers were registered or no errors were handled.
        """
        is_handled = False
        handler_report = None

        if not hasattr(self, '_error_handlers') or not self._error_handlers:
            raise UnexpectedCalculationFailure('no calculation error handlers were registered')

        # Sort the handlers based on their priority in reverse order
        handlers = sorted(self._error_handlers, key=lambda x: x.priority, reverse=True)

        for handler in handlers:
            handler_report = handler.method(self, calculation)

            # If at least one error is handled, we consider the calculation failure handled.
            if handler_report and handler_report.is_handled:
                self.ctx.unexpected_failure = False
                is_handled = True

            # After certain error handlers, we may want to skip all other error handling
            if handler_report and handler_report.do_break:
                break

        # If none of the executed error handlers reported that they handled an error, the failure reason is unknown
        if not is_handled:
            raise UnexpectedCalculationFailure('calculation failure was not handled')

        # The last called error handler may not necessarily have returned a handler report
        if handler_report:
            return handler_report.exit_code

        return

    def _handle_unexpected_failure(self, calculation, exception=None):
        """Handle an unexpected failure.

        This occurs when a calculation excepted, was killed or finished with a non-zero exit status but no errors were
        handled. If this is the second consecutive unexpected failure the work chain is aborted.

        :param calculation: the calculation that failed in an unexpected way
        :param exception: optional exception or error message to log to the report
        :return: ExitCode if this is the second consecutive unexpected failure
        """
        if exception:
            self.report('{}'.format(exception))

        if self.ctx.unexpected_failure:
            self.report('failure of {}<{}> could not be handled for the second consecutive time'.format(
                self.ctx.calc_name, calculation.pk))
            return self.exit_codes.ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE

        # First unexpected failure in a row: remember it and let the loop try once more
        self.ctx.unexpected_failure = True
        self.report('failure of {}<{}> could not be handled, restarting once more'.format(
            self.ctx.calc_name, calculation.pk))