# Copyright 2021-2022 VMware, Inc.
# All rights reserved. -- VMware Confidential

""" This is the abstracted layer of task management including the base classes
    and common algorithm to run a group of tasks/workflows in parallel.
"""

from copy import deepcopy

import abc
import logging
import sys
import time

from .Constants import *
from .Utils import createNotification

# abc.ABC exists from Python 3.4 on. The scons build is still on Python 2,
# so there an equivalent base class is created from the ABCMeta metaclass.
ABC = (abc.ABC if sys.version_info >= (3, 4)
       else abc.ABCMeta('ABC', (), {}))

log = logging.getLogger(__name__)

# States shared by every runnable.
STATE_PENDING = 'PENDING'
STATE_RUNNING = 'RUNNING'
STATE_SUCCEEDED = 'SUCCEEDED'
STATE_FAILED = 'FAILED'
STATE_TIMEDOUT = 'TIMEDOUT'
STATE_MAX_RETRY = 'MAX_RETRY_REACHED'

# States used by a Workflow.
STATE_EARLY_SUCCEEDED = 'EARLY_SUCCEEDED'
STATE_EARLY_FAILED = 'EARLY_FAILED'
STATE_NEXT_PHASE = 'NEXT_PHASE'

# States used by a WorkflowPhase; each one tells the owning Workflow which
# workflow state to move to.
STATE_TO_EARLY_SUCCEEDED = 'TO_EARLY_SUCCEEDED'
STATE_TO_EARLY_FAILED = 'TO_EARLY_FAILED'
STATE_TO_NEXT_PHASE = 'TO_NEXT_PHASE'

class ImageRunnable(ABC):
   """ ImageRunnable is the base class for tasks. It handles the general logic
       for state transitions, as well as the common logic for retry and
       timeout.

       A subclass describes its state machine with the class attributes
       below and implements _start(), which start() calls.
   """

   # Default timeout; subclasses can override it with better estimation.
   TIMEOUT = 300 # seconds

   # Default retry; subclasses can override it.
   MAX_RETRY = 1

   # Default sleep time in seconds, returned by getPollInterval().
   POLL_INTERVAL = 10

   # A dict that contains the state relation graph. For each key/value pair,
   # the key is a state name, the value is a list of the next states that
   # can be transferred to. For ending states, the value is None.
   stateTransitionGraph = dict()

   # The state transition function map. The key is a state name, and the value
   # is a function, called with the runnable as its only argument, that moves
   # the current state to other states when a condition is satisfied.
   stateTransitionFuncMap = dict()

   # The initial state, usually STATE_PENDING.
   initialState = STATE_PENDING

   # The long run states that run asynchronously: updateState() stops
   # transitioning when one of them is reached and resumes on the next poll.
   longRunStates = list()

   # The state notification function map.
   stateNotificationMap = dict()

   def reset(self):
      """ Clear the try counter and the start/end times, so that the next
          start() opens a new timeout window. The current state is kept.
      """
      self._triedNum = 0
      self._startTime = None
      self._endTime = None

   def __init__(self, name, entity, parentTask=None, maxRetry=None,
                timeout=None):
      """ Constructor of ImageRunnable.

          name: The task name.
          entity: The DPU IP.
          parentTask: The parent task if any; notifications are passed to
                      its updateNotifications() method when it has one.
          maxRetry: The maximum retry count; MAX_RETRY when None. The total
                    number of tries is maxRetry + 1.
          timeout: The timeout period in seconds; TIMEOUT when not given
                   (any falsy value selects TIMEOUT).
      """
      self._name = name
      self._entity = entity
      self._parentTask = parentTask
      retries = maxRetry if maxRetry is not None else self.MAX_RETRY
      # Total number of tries: the first one plus the retries.
      self._maxTry = retries + 1
      self._state = self.__class__.initialState
      self.reset()
      self._errorNotifications = []
      self._timeout = timeout if timeout else self.TIMEOUT
      self._notifications = dict()
      self._lastUpdateTime = time.time()

   def updateParentTaskNotification(self, msgId, args, type_):
      """ API to create notification and update parent task.

          msgId: The notification message ID.
          args: The notification message arguments.
          type_: The notification type, e.g. INFO or ERROR.

          Non-ERROR notifications are forwarded to the parent task right
          away. ERROR notifications are buffered; the buffer is flushed to
          the parent task by the first call made while this runnable is in
          an ending state other than STATE_SUCCEEDED. Nothing is forwarded
          when there is no parent task with an updateNotifications() method.
      """
      notif = createNotification(msgId, msgId, args, args, type_)
      canUpdate = (self._parentTask and
                   hasattr(self._parentTask, 'updateNotifications'))
      if type_ == ERROR:
         # Collect error notifications.
         self._errorNotifications.append(notif)
      elif canUpdate:
         self._parentTask.updateNotifications([notif])

      if (canUpdate and self.atEndState() and self._state != STATE_SUCCEEDED):
         # Only add error notification when task does not succeed.
         self._parentTask.updateNotifications(self._errorNotifications)
         self._errorNotifications = []

   def getPollInterval(self):
      """ Return the number of seconds to wait between two polls. """
      return self.POLL_INTERVAL

   def start(self):
      """ Start the runnable object: set start time and ending time; then
          call the private subclass start function.

          The start/end times are only set when not set yet (first try or
          after reset()), so retries share the same timeout window.
      """

      # Set up start time and end time.
      if self._startTime is None:
         self._startTime = time.time()
         self._endTime = self._startTime + self._timeout

      # Process retry.
      self._triedNum += 1

      log.info('Starting runnable %s with %s', self._name, self._entity)
      self._start()
      if self._state == STATE_RUNNING:
         self.updateParentTaskNotification(
            TaskStarted, [self._name, self._entity], INFO)

   def processSucceeded(self, modifyState=True):
      """ Set state to be succeeded when needed and notify.

          modifyState: When False, keep the current state and only send the
                       TaskSucceeded notification.
      """
      if modifyState:
         self._state = STATE_SUCCEEDED
      self.updateParentTaskNotification(TaskSucceeded,
                                        [self._name, self._entity],
                                        INFO)

   def processFailed(self):
      """ Handle a failed try.

          With retry enabled, move to STATE_MAX_RETRY once all tries are
          used, otherwise back to STATE_PENDING for another try. Without
          retry the state is left untouched, so a runnable in STATE_FAILED
          ends there (see atEndState()).
      """
      if self._maxTry > 1:
         if self._triedNum >= self._maxTry:
            self._state = STATE_MAX_RETRY
            self.updateParentTaskNotification(
               TaskMaxRetry,
               [self._name, self._entity, str(self._maxTry)],
               INFO)
            log.error('Runnable (%s %s) reached maximum retry', self._name,
                      self._entity)
            return

         self.updateParentTaskNotification(
            TaskRetry, [self._name, self._entity], INFO)
         self._state = STATE_PENDING

   def updateState(self):
      """ Check the progress; change state if the expected event happened.
          Time out if ending time is reached.

          Short-run states are chained within one call to avoid waiting a
          whole poll interval. The chain stops at an ending state, at a long
          run state, or when a transition function leaves the state as is.
          A runnable that is not in an ending state afterwards is moved to
          STATE_TIMEDOUT once its ending time has passed.
      """
      cls = self.__class__
      # The base class declares 'longRunStates'; still honour the singular
      # 'longRunState' spelling if a subclass defines it.
      longRunStates = getattr(cls, 'longRunState', cls.longRunStates)

      while not self.atEndState():
         oldState = self._state
         # Only the lookup may miss; exceptions raised by the transition
         # function itself are not mistaken for a missing function.
         transitionFunc = cls.stateTransitionFuncMap.get(oldState)
         if transitionFunc:
            transitionFunc(self)
         else:
            self.updateParentTaskNotification(
               TaskStateTransitionError, [self._name, self._entity], ERROR)
            log.error('No transition function provided for state %s',
                      oldState)

         if self._state != oldState:
            log.debug('Runnable (%s %s) moves from %s to %s', self._name,
                      self._entity, oldState, self._state)

         if self._state in longRunStates or self._state == oldState:
            break

      # A runnable that already ended keeps its final state; one that was
      # never started has no ending time to compare against.
      if (not self.atEndState() and self._endTime is not None and
          time.time() >= self._endTime):
         self._state = STATE_TIMEDOUT
         self.updateParentTaskNotification(
            TaskTimeout, [self._name, self._entity], INFO)
         log.error('Runnable (%s %s) timedout', self._name, self._entity)

   def atEndState(self):
      """ Return True if the runnable is in an ending state: a state whose
          entry in stateTransitionGraph is None, or STATE_FAILED when retry
          is disabled.
      """
      if self.__class__.stateTransitionGraph[self._state] is None:
         return True
      return self._maxTry == 1 and self._state == STATE_FAILED

   def isSuccess(self):
      """ Check the runnable succeeded or not.
      """
      return self._state == STATE_SUCCEEDED

class RunnableGroup(object):
   """ A class runs a group of ImageRunnable in parallel.

       Each polling round calls updateState() on every unfinished runnable,
       collects the ones that reached an ending state, then sleeps for the
       poll interval of the first runnable still running.
   """

   def __init__(self, runnables):
      """ The constructor.

          runnables: An iterable of ImageRunnables. A runnable that is no
                     longer in its initial state is treated as failed
                     before start and is not run.
      """
      self._runnables = []
      self._succeededRunnables = []
      self._failedRunnables = []

      for runnable in runnables:
         if runnable._state == runnable.__class__.initialState:
            self._runnables.append(runnable)
         else:
            log.error('Runnable (%s %s) failed before start',
                      runnable._name, runnable._entity)
            self._failedRunnables.append(runnable)

      self._runnableNum = len(self._runnables) + len(self._failedRunnables)
      # Runnables that failed before start are already finished; counting
      # them keeps the finished count consistent with the total.
      self._finishedNum = len(self._failedRunnables)

   def run(self):
      """ The common algorithm to run a group of ImageRunnables in parallel.

          Returns once every runnable has reached an ending state.
      """
      while self._runnables:
         for runnable in self._runnables:
            runnable.updateState()

         stillRunning = []
         for runnable in self._runnables:
            if not runnable.atEndState():
               stillRunning.append(runnable)
               continue

            self._finishedNum += 1
            log.info('Runnable (%s %s) finished with state %s',
                     runnable._name, runnable._entity, runnable._state)
            if runnable.isSuccess():
               self._succeededRunnables.append(runnable)
            else:
               self._failedRunnables.append(runnable)
         self._runnables = stillRunning

         if self._runnables:
            time.sleep(self._runnables[0].getPollInterval())

   def succeeded(self):
      """ Return True if all ImageRunnable succeeded; otherwise, False.
      """
      return len(self._succeededRunnables) == self._runnableNum

class WorkflowPhase(ImageRunnable):
   """ Base class for workflow phase.
   """

   # Don't retry each phase but retry the workflow.
   MAX_RETRY = 0

   @classmethod
   def patchStateTransitionGraph(cls, stateTransitionGraph):
      """ Helper method to enhance the state transition graph when
          ImageRunnable is used as a workflow phase.

          Returns a patched deep copy; the given graph is not modified.
          In the copy, STATE_SUCCEEDED can move to STATE_TO_EARLY_SUCCEEDED,
          STATE_TO_EARLY_FAILED or STATE_TO_NEXT_PHASE, and STATE_FAILED can
          additionally move to STATE_TO_EARLY_FAILED.
      """
      stateTransitionGraph = deepcopy(stateTransitionGraph)
      stateTransitionGraph[STATE_SUCCEEDED] = \
         [STATE_TO_EARLY_SUCCEEDED, STATE_TO_EARLY_FAILED, STATE_TO_NEXT_PHASE]
      # STATE_FAILED may be absent or an ending state (None) in the input.
      failedNextStates = list(stateTransitionGraph.get(STATE_FAILED) or [])
      if STATE_TO_EARLY_FAILED not in failedNextStates:
         failedNextStates.append(STATE_TO_EARLY_FAILED)
      stateTransitionGraph[STATE_FAILED] = failedNextStates
      return stateTransitionGraph

class Workflow(ImageRunnable):
   """ A workflow is formed by a sequence of ImageRunnables (the workflow
       phases). These ImageRunnables run in order.

       On retry, the workflow resumes from the phase after the last
       succeeded phase rather than from the beginning.
   """

   def __init__(self, workflowPhases, name, entity, parentTask=None,
                maxRetry=None, timeout=None):
      """ Constructor of Workflow.

          workflowPhases: The workflow phases, in running order.
          name: The workflow name.
          entity: The target of the workflow.
          parentTask: The parent task of the workflow.
          maxRetry: The maximum retry count.
          timeout: The timeout period.
      """
      super(Workflow, self).__init__(name, entity, parentTask, maxRetry,
                                     timeout)
      self._workflowPhases = workflowPhases
      self._currentPhaseIndex = 0
      # The phase being run; set by _startPhase().
      self._currentPhase = None
      # Index of the last phase that reached STATE_TO_NEXT_PHASE; -1 if none.
      self._lastSucceeded = -1

   def _isLastPhase(self):
      """ Return True if the current phase is the last workflow phase.
      """
      return self._currentPhaseIndex >= len(self._workflowPhases) - 1

   def _processState(self):
      """ Adjust workflow state based on the current workflow phase state.
      """
      phaseState = self._currentPhase._state
      if phaseState == STATE_RUNNING:
         self._state = STATE_RUNNING
      elif phaseState == STATE_TO_EARLY_FAILED:
         self._state = STATE_EARLY_FAILED
      elif phaseState == STATE_TO_EARLY_SUCCEEDED:
         if self._state != STATE_EARLY_SUCCEEDED:
            self._state = STATE_EARLY_SUCCEEDED
            # Report success without overwriting the early-succeeded state.
            self.processSucceeded(False)
      elif phaseState == STATE_TO_NEXT_PHASE:
         self._lastSucceeded = self._currentPhaseIndex
         if self._isLastPhase():
            # No phase left to move to: the whole workflow is done.
            self.processSucceeded()
         else:
            self._state = STATE_NEXT_PHASE
      elif phaseState == STATE_FAILED:
         # Mark the workflow failed first so that, without retry, it ends in
         # STATE_FAILED instead of staying in its previous state.
         self._state = STATE_FAILED
         self.processFailed()
         if self._maxTry > 1 and self._triedNum < self._maxTry:
            # Retry from the phase after the last succeeded one.
            self._resetworkflow()
            self._state = STATE_NEXT_PHASE
            self._triedNum += 1
      elif phaseState == STATE_SUCCEEDED and self._isLastPhase():
         self.processSucceeded()

   def _startPhase(self):
      """ Start the current workflow phase.
      """
      self._currentPhase = self._workflowPhases[self._currentPhaseIndex]
      self._currentPhase.start()
      self._processState()

   def _start(self):
      """ Private method to start workflow as ImageRunnable. Called by
          ImageRunnable start method.
      """
      self._currentPhaseIndex = 0
      self._startPhase()

   def _resetworkflow(self):
      """ Reset the workflow phases to be retried (every phase after the
          last succeeded one) and rewind the current index to the last
          succeeded phase, so that moving to the next phase restarts the
          first phase that has not succeeded.
      """
      for i in range(self._lastSucceeded + 1, len(self._workflowPhases)):
         self._workflowPhases[i].reset()
      self._currentPhaseIndex = self._lastSucceeded

   def updateWorkflow(self):
      """ Update the workflow state based on workflow phase state, and start
          the next phase when the workflow moved to STATE_NEXT_PHASE.
      """
      phase = self._currentPhase
      # ImageRunnable phases are polled through updateState(); an
      # updateTask() method is still used for phase types that provide one.
      updatePhase = getattr(phase, 'updateTask', None) or phase.updateState
      updatePhase()
      self._processState()

      if self._state == STATE_NEXT_PHASE:
         self._currentPhaseIndex += 1
         self._startPhase()

   def isSuccess(self):
      """ Check the runnable succeeded or not; an early success counts as
          success.
      """
      return (super(Workflow, self).isSuccess() or
              self._state == STATE_EARLY_SUCCEEDED)

