"""
A stress test that runs several queries (both failing and successful) in
parallel in the hope of uncovering race conditions and similar bugs.
"""

import copy
import Queue
import sys
import threading
import time
import unittest

import regressiontest


def iterTests(suite):
	"""Yield a deep copy of every test case in suite, flattening nested suites.

	Every test is deep-copied exactly once, so running the yielded objects
	does not mutate the test instances stored in suite.
	"""
	for member in suite:
		if isinstance(member, unittest.TestSuite):
			# The recursive call already yields fresh copies; copying them
			# again at each nesting level would only waste time.
			for t in iterTests(member):
				yield t
		else:
			yield copy.deepcopy(member)


class TestStatistics(object):
	"""Collect the outcomes of finished tests and produce a short report."""
	def __init__(self):
		self.results = []   # unittest result objects, in order of arrival
		self.oks, self.fails, self.total = 0, 0, 0
		# NOTE(review): the three timing attributes below are not read by
		# any code in this file; kept in case outside code relies on them.
		self.globalStart = time.time()
		self.lastTimestamp = time.time()+1
		self.timeSum = 0

	def add(self, test, result):
		"""Count result as ok or failed and keep it in results.

		For failed runs, test.id() is printed so failures are visible
		while the run is still going.
		"""
		if result.wasSuccessful():
			self.oks += 1
		else:
			self.fails += 1
			print("\nFail: %s\n"%test.id())
		self.total += 1
		self.results.append(result)

	def getReport(self):
		"""Return a one-line summary of the form "<fails> of <count> bad."."""
		return "%d of %d bad."%(self.fails, self.fails+self.oks)

	def save(self, target):
		"""Pickle the list of collected results to the file named target.

		Every collected result object must be picklable.
		"""
		# cPickle is not imported at module level; Python 3 only has pickle.
		try:
			import cPickle as pickle
		except ImportError:
			import pickle
		# Pickles are binary data, and "with" closes the file even if
		# dumping fails halfway.
		with open(target, "wb") as f:
			pickle.dump(self.results, f)


class ParallelRunner(object):
	"""Run tests from a unittest suite in parallel worker threads.

	At most nSimul tests run at the same time, and nTotal tests are started
	altogether; the suite is cycled through as often as needed to reach that
	number.  Each test runs in its own daemon thread, which reports back by
	putting a "testFinished" event into inputQueue.
	"""
	def __init__(self, nSimul, nTotal):
		self.nSimul, self.nTotal = nSimul, nTotal
		self.inputQueue = Queue.Queue()
		self.threadPool = {}    # thread id -> running threading.Thread
		self.stats = TestStatistics()
		self.nextThreadId = 0

	def _runTestInThread(self, test, outputQueue, threadId):
		"""Thread body: run test and post its result to outputQueue."""
		result = test.defaultTestResult()
		test.run(result)
		outputQueue.put(("testFinished", threadId, test, result))

	def _spawnThread(self, test):
		"""Start a daemon thread running test and register it in threadPool."""
		threadId = self.nextThreadId
		self.nextThreadId += 1
		newThread = threading.Thread(target=self._runTestInThread,
			args=(test, self.inputQueue, threadId))
		newThread.test = test   # lets run() name the tests of hung threads
		# Daemon threads do not keep the interpreter alive if a test hangs.
		newThread.setDaemon(True)
		self.threadPool[threadId] = newThread
		newThread.start()

	def _handleEvent(self, ev):
		"""Record a (msg, threadId, test, result) event from a finished thread."""
		msg, threadId, test, result = ev
		assert msg=="testFinished"
		self.stats.add(test, result)
		del self.threadPool[threadId]

	def _collectRemainingThreads(self):
		"""Wait for the tests still in threadPool to finish.

		Waits at most 10 seconds for each individual result and stops
		collecting after 30 seconds in total.  Threads still in the pool
		afterwards are reported as hung.
		"""
		deadline = time.time()+30
		try:
			while self.threadPool and time.time()<deadline:
				self._handleEvent(self.inputQueue.get(block=True, timeout=10))
		except Queue.Empty:
			pass
		# Report hung threads whether we ran out of per-result or total time.
		if self.threadPool:
			print("\n*************%d hung threads!"%len(self.threadPool))

	def run(self, suite):
		"""Run nTotal tests taken from suite, keeping up to nSimul in flight.

		Progress is written to stdout after every finished test.  If no
		test finishes within 40 seconds, the tests still running are listed
		and no new ones are started.  A suite without tests runs nothing.
		"""
		def iterSuiteEndless():
			# Cycle through the suite forever, unless it is empty; an empty
			# suite would otherwise spin without ever yielding.
			while True:
				gotAny = False
				for t in iterTests(suite):
					gotAny = True
					yield t
				if not gotAny:
					return

		tests = iterSuiteEndless()
		spawned = 0
		try:
			# Keep the pool full while tests remain to be started; the last
			# ones still running are picked up by _collectRemainingThreads.
			while spawned<self.nTotal:
				while len(self.threadPool)<self.nSimul and spawned<self.nTotal:
					test = next(tests, None)
					if test is None:
						# Only possible on the very first call (the suite is
						# empty), so no threads are running yet.
						print("**** Suite contains no tests, nothing to run")
						return
					self._spawnThread(test)
					spawned += 1
				self._handleEvent(self.inputQueue.get(block=True, timeout=40))
				sys.stdout.write("\r"+self.stats.getReport())
				sys.stdout.flush()
		except Queue.Empty:
			print("**** Too many hung threads, giving up")
			print("Tests in thread pool:\n  %s"%(
				"\n  ".join(t.test.id() for t in self.threadPool.values())))
		self._collectRemainingThreads()


def main():
	"""Run 2000 tests drawn from regressiontest, 7 at a time, then print
	the summary and every collected error and failure."""
	# Debugging switch: set to True to hammer only Maxrec0Test.
	runSingleTest = False
	if runSingleTest:
		suite = unittest.TestSuite()
		suite.addTest(regressiontest.Maxrec0Test('test'))
	else:
		suite = unittest.TestLoader().loadTestsFromModule(regressiontest)

	runner = ParallelRunner(7, 2000)
	runner.run(suite)
	print("\n%s"%runner.stats.getReport())
	for result in runner.stats.results:
		# errors first, then failures, just as unittest keeps them apart
		for problems in (result.errors, result.failures):
			for errTest, tb in problems:
				print(errTest)
				print(tb)


if __name__ == "__main__":
	# Start the stress run only when executed as a script, not on import.
	main()
