#--
# EstTT.py -- class for ESTimating Traversal Times via various
# algorithms.  Use "strategy" design pattern to handle alg
# framework.
# J. A. Templon, NIKHEF, PDP Group
# $Id$
# Source: $URL$

#--

# the strategy interface (ie generic algorithm interface):
class FindEstimate(object):
    """Strategy interface: every concrete estimation algorithm
    subclasses this and overrides algorithm()."""

    def algorithm(self, lrms, vo, debug):
        """Return an estimated traversal time in seconds.

        lrms  -- concrete instance of class lrms (snapshot of LRMS state)
        vo    -- virtual organization for which you want the estimate;
                 note: so far this usually maps to the unix group of the
                 mapped user
        debug -- 0: don't debug, 1: print debugging info
        """
        pass

# put concrete algorithm instances here

class Gott(FindEstimate):
    """The original Gott estimate: based on the average queue-wait time.

    To use the old scheme of matching jobs on both the VO AND the
    queue, set
       Estimator.strategy.queue = 'atlas'
    before actually using the Estimator.estimate() method.
    """
    def __init__(self):
        FindEstimate.__init__(self)
        self.queue = ''          # '' means: don't restrict by queue

    def algorithm(self, lrms, vo, debug):
        """Delegate to ett() using the 'average' wait-time algorithm."""
        return ett(lrms, self.queue, vo,
                   algorithm='average', debug=debug)

class WaitTime(FindEstimate):
    """Gott modified to use the longest wait time instead of the
    average wait time.

    To use the old scheme of matching jobs on both the VO AND the
    queue, set
       Estimator.strategy.queue = 'atlas'
    before actually using the Estimator.estimate() method.
    """
    def __init__(self):
        FindEstimate.__init__(self)
        self.queue = ''          # '' means: don't restrict by queue

    def algorithm(self, lrms, vo, debug):
        """Delegate to ett() using the 'longest' wait-time algorithm."""
        return ett(lrms, self.queue, vo,
                   algorithm='longest', debug=debug)

class Throughput(FindEstimate):
    """Estimate based on observed job throughput.

    Looks at recent job starts and calculates how long it will take to
    empty all waiting jobs for the desired VO at the observed start
    rate.  Motivation: Gott & WaitTime do very poorly when the job
    population is not stable, i.e. when the rates of job arrival and
    job starts are not roughly the same.
    """
    def __init__(self):
        FindEstimate.__init__(self)
        # Tunable parameter specifying how far back (how many jobs) in
        # the system start history to look when calculating job
        # throughput.  IE if twin (Througput WINdow) is 4, code will
        # look at no more than the last four jobs started to see how
        # many jobs per second are being started.
        self.twin = 10

    def algorithm(self, lrms, vo, debug):
        """Return the estimated traversal time (seconds) for vo.

        lrms  -- snapshot of LRMS state; must provide slotsUp,
                 slotsFree, now, schedCycle, matchingJobs(), jobs()
        vo    -- VO (unix group) for which to estimate
        debug -- accepted for interface compatibility; unused here

        Returns the sentinel 7777777 when no sensible estimate can be
        made (no slots, nothing running, or stale server history).
        """

        # get rid of the silliest case first

        if lrms.slotsUp < 1:  # system has no job slots at all
            return 7777777

        # gather a couple lists to help us decide

        def allrunfunc(job):
            return job.get('state') == 'running'
        def vowaitfunc(job):
            return job.get('state') == 'queued' and job.get('group') == vo
        allRunningJobs  = lrms.matchingJobs(allrunfunc)
        waitingJobsMyVO = lrms.matchingJobs(vowaitfunc)

        # another very silly case; this VO has waiting jobs, but absolutely no
        # jobs are running.  This is either a short transient that will only last
        # for one sched cycle, or else something is wrong.  Assume the latter.

        if len(allRunningJobs) < 1 and len(waitingJobsMyVO) > 0:
            return 7777777

        # another silly case: someone just restarted the server.
        # symptom (at least on PBS): no free slots, lots of running
        # jobs, many (or all) of which have no 'start' record.

        # note this is not the only symptom ... there may be free
        # slots but still no jobs with historical info.  Some code
        # later on to handle this in most_recent_run ...

        if lrms.slotsFree == 0:
            nr = len(allRunningJobs)
            nns = 0                     # Number with No Start
            for j in allRunningJobs:
                if j.get('start') == None: nns = nns + 1
            # guard nr > 0: with no running jobs at all the original
            # float(nns)/nr raised ZeroDivisionError
            if nr > 0 and float(nns)/nr >= 0.5:
                return 7777777

        nw = len(waitingJobsMyVO)
        # (an unreachable "nw < 0" branch was removed here: len() can
        # never be negative, and it called sys.exit() without 'sys'
        # ever being imported in this module)
        if nw == 0:
            if lrms.slotsFree > 0:
                return lrms.schedCycle / 2
            else:
                if len(allRunningJobs) == 0:
                    return lrms.schedCycle / 2
                else:
                    startList = [ ]
                    for j in allRunningJobs:
                        tval = j.get('start')
                        if tval:  # jobs started during last poll cycle won't have 'start'
                            startList.append(tval)
                        else:
                            startList.append(lrms.schedCycle / 2)  # rough approx
                    startList.sort()  # last element will be most recent start
                    # time for which the system has been completely full
                    time_sys_full = lrms.now - startList[-1]
                    return time_sys_full + lrms.schedCycle/2
        else: # there are waiting jobs

            # first check one special case: the one where *all* waiting jobs
            # have been in the system for less than the schedule cycle
            # if there are enough slots to run them all, skip the throughput
            # calculation.

            waitList = [ ]
            for j in waitingJobsMyVO:
                waitList.append(lrms.now - j.get('qtime'))
            waitList.sort()
            if waitList[-1] < 1.1*lrms.schedCycle \
                   and len(waitList) < lrms.slotsFree:
                return lrms.schedCycle / 2

            # calcs based on throughput.  we know since we checked above that at
            # least *some* jobs are running.  Use jobs from our VO if we have
            # them, otherwise use list of all running jobs.

            def vorunfunc(job):
                return job.get('state') == 'running' and job.get('group') == vo
            runningJobsMyVO = lrms.matchingJobs(vorunfunc)
            if len(runningJobsMyVO) > 0:
                runlist = runningJobsMyVO
            else:
                runlist = allRunningJobs
            startList = [ ]
            for j in runlist:
                st = j.get('start')
                if st:
                    startList.append(st)
                else:
                    # no start record yet: approximate with half a cycle ago
                    startList.append(lrms.now - lrms.schedCycle/2)
            startList.sort()

            # decide how many jobs to use to calc throughput; take minimum
            # of number running, number waiting, and the 'throughput window'
            # that can be set on the algorithm.  Note we have to use the
            # startList instead of runlist since there may be starts for
            # which we have no records yet (PBS caching).

            num = min(nw, len(startList), self.twin)
            period = lrms.now - startList[-num]
            thruput = float(num)/period   # jobs per sec

            if len(runningJobsMyVO) > 0:
                return int((nw+1) / thruput)
            else:
                # jobs are running but not yours, so assume you need to finish all
                # waiting jobs before your new one starts
                naw = len(lrms.jobs()) - len(runlist) # counts all waiting jobs
                return int((naw+1) / thruput)

# the "context" controls the strategy:
class TTEstimator(object):
    """The 'context' of the strategy pattern: holds a concrete
    FindEstimate instance and delegates estimation to it."""

    def __init__(self, strategy):
        # strategy: a FindEstimate subclass instance; exposed as an
        # attribute so callers can tune it (e.g. strategy.queue)
        self.strategy = strategy

    def estimate(self, lrms, vo, debug):
        """Return the strategy's traversal-time estimate for vo on lrms."""
        return self.strategy.algorithm(lrms, vo, debug)

# there are several algorithms for what to do with queued jobs.
# define them here and provide means to select between them.

def ett_longest_queue_time(server, joblist):
    """Estimate: the longest time any queued job in joblist has been
    waiting (server.now - qtime).

    Returns the sentinel 7777777 when joblist contains no queued jobs
    (the original code unconditionally indexed the wait list and
    raised IndexError in that case, despite initializing the sentinel).
    """
    qtlist = []
    for j in joblist:
        if j.get('state') == 'queued':
            qtlist.append(server.now - j.get('qtime'))

    if not qtlist:
        return 7777777          # "don't know" sentinel, as used elsewhere

    return max(qtlist)          # the longest-waiting queued job

def ett_avg_queue_time(server, joblist):
    """Estimate: twice the average time-in-queue of all queued jobs in
    joblist.  Rationale: the average time spent in the queue so far is
    likely to be half of the expected total time in queue, so twice
    the average is returned as the estimate.

    Returns the sentinel 7777777 when joblist contains no queued jobs
    (the original code divided by a zero count in that case, despite
    initializing the sentinel).
    """
    count = 0
    waited = 0.0
    for j in joblist:
        if j.get('state') == 'queued':
            count = count + 1
            waited = waited + (server.now - j.get('qtime'))

    if count == 0:
        return 7777777          # "don't know" sentinel, as used elsewhere

    avgtime = waited / count
    return int(2 * avgtime)

# make dict of functions to allow run-time selection
# keys are the values accepted by ett()'s 'algorithm' argument
_ALGS = {
    'average' : ett_avg_queue_time,
    'longest' : ett_longest_queue_time }

# adjusted_ett provides a central place to make any changes to
# return values; right now it replaces any raw estimate shorter than
# one scheduling period with half a period.  Could be used e.g. to
# change units.
def adjusted_ett(rawval, period):
    """Post-process a raw estimate: any value shorter than one
    scheduling period is replaced by half a period; otherwise the raw
    value is returned truncated to an integer."""
    if rawval >= period:
        return int(rawval)
    return int(period / 2.)

def ett(server,queue,voname='',algorithm='longest',debug=0) :
    """Estimate traversal time (seconds) for voname on this server.

    server    -- lrms-like snapshot.  NOTE: nmatch() both returns a
                 count and, as a side effect, determines what
                 jobs_last_query() subsequently returns -- the call
                 order below is load-bearing.
    queue     -- restrict matching to this queue name; '' matches any
    voname    -- VO / unix group to match; '' or None matches any
    algorithm -- key into _ALGS: 'average' or 'longest'
    debug     -- nonzero prints diagnostic info

    The raw estimate is post-processed through adjusted_ett().
    """

    if debug: print 'running for VO', voname

    # filter for queued jobs, optionally narrowed by queue and VO
    jfilt = {'state' : 'queued'}
    
    if queue != '':              jfilt['queue'] = queue
    if voname not in ['', None]: jfilt['group'] = voname

    jobsrunning = {'state': 'running'}
    
    nq = server.nmatch(jfilt)
    if debug: print 'found', nq, 'queued jobs in queue \'' + queue + \
       '\' for VO', voname
    if nq > 0: # there are queued jobs for this VO
        # jobs_last_query() returns the jobs matched by nmatch(jfilt) above
        est = _ALGS[algorithm](server,server.jobs_last_query())
    elif server.slotsFree > 0 :
        if debug: print 'found', server.slotsFree, 'free job slots'
        est = 0
    else :
        # nothing queued and no free slots: estimate from how long the
        # system has been full, based on the most recent job start
        server.nmatch(jobsrunning) # get all running jobs
        est = ett_most_recent_run(server.now,
                                      server.jobs_last_query(),
                                      debug)

    return adjusted_ett(est,server.schedCycle)

def ett_most_recent_run(now,joblist,debug=0) :

    # find most recently run job
    # make list of start times
    
    starts = list()
    for j in joblist:
        st = j.get('start')
        if st :
            starts.append(st)

    if len(starts) > 0:
        starts.sort()
        last_started = starts[-1]  # last in list - start with largest timestamp
        if debug : print now, last_started
        timefilled = now - last_started
        if debug : print 'most recent run', timefilled
        return timefilled
    else:
        return 777777   # error code for 'i don't know'
