import unittest
import usbtest
import usb

# Transfer buffer size in bytes assumed by the loopback tests.
# The driver uses the same value, so both must be changed together.
BUF_SIZE = 1024

# How many times each loopback test is repeated.
REPEAT = 1

# Standard USB definitions that PyUSB does not provide.
# ENDPOINT_HALT is used both as the feature selector of the SET_FEATURE
# request issued by setHalt() and as the bit position of the halt flag in
# the status word returned by getEpStatus().
ENDPOINT_HALT = 0

# Timeout handed to every control and bulk transfer (presumably in
# milliseconds, as in the PyUSB legacy API -- confirm).
TIMEOUT = 1000

class TestCase(unittest.TestCase):
	"""Test bulk data transfers.

	Each test opens ``usbtest.dev``, claims the first interface whose
	first alternate setting has both a bulk IN and a bulk OUT endpoint,
	and exercises loopback (write to OUT, read back from IN) and
	endpoint-halt behaviour on that endpoint pair.
	"""

	def setUp(self):
		"""Find a bulk IN/OUT endpoint pair, open the device and claim
		the interface that owns it.

		Only the first alternate setting (``iface[0]``) of each interface
		is examined. The search stops at the first interface, across all
		configurations, that has both a bulk IN and a bulk OUT endpoint.
		"""
		# Initialise up front so the "found?" checks below cannot hit
		# an AttributeError when the first interface is unsuitable.
		self.bulk_iface = None
		self.in_ep = None
		self.out_ep = None

		for config in usbtest.dev.configurations:
			for iface in config.interfaces:
				in_ep = None
				out_ep = None
				for ep in iface[0].endpoints:
					if ep.type != usb.ENDPOINT_TYPE_BULK:
						continue
					# The direction bit of the endpoint address
					# tells IN endpoints apart from OUT endpoints.
					if ep.address & usb.ENDPOINT_IN:
						in_ep = ep
					else:
						out_ep = ep
				if in_ep is not None and out_ep is not None:
					self.bulk_iface = iface[0]
					self.in_ep = in_ep
					self.out_ep = out_ep
					# Stop only once a usable interface is found;
					# otherwise keep looking at later interfaces.
					break
			if self.bulk_iface is not None:
				break

		if self.bulk_iface is None:
			self.fail("no interface with both a bulk IN and a bulk OUT "
					"endpoint found")

		self.dev = usbtest.dev.open()
		self.dev.claimInterface(self.bulk_iface)

	def tearDown(self):
		"""Release the interface claimed in setUp() and drop the handle."""
		self.dev.releaseInterface()
		del self.dev

	def setHalt(self, ep):
		"""Set the halt feature on endpoint *ep*.

		PyUSB offers clearHalt() but no matching setHalt(), so the
		standard SET_FEATURE request is issued by hand.
		"""
		self.dev.controlMsg(usb.ENDPOINT_OUT
				| usb.TYPE_STANDARD | usb.RECIP_ENDPOINT,
				usb.REQ_SET_FEATURE, 0, ENDPOINT_HALT,
				ep.address, timeout=TIMEOUT)

	def getEpStatus(self, ep):
		"""Return the status word of endpoint *ep* via GET_STATUS.

		The two-byte reply is combined with its first byte as the low
		byte. Raises if the device does not return exactly two bytes.
		"""
		low, high = self.dev.controlMsg(usb.ENDPOINT_IN
				| usb.TYPE_STANDARD | usb.RECIP_ENDPOINT,
				usb.REQ_GET_STATUS, 2, 0, ep.address,
				timeout=TIMEOUT)
		return (high << 8) | low

	def doLoopbackTest(self, size):
		"""Write *size* bytes to the OUT endpoint and check that the same
		bytes come back on the IN endpoint.

		The payload is the repeating byte pattern 0x00, 0x01, ..., 0xff.
		"""
		out_data = [i & 0xff for i in range(size)]
		written = self.dev.bulkWrite(self.out_ep.address, out_data,
				TIMEOUT)
		self.assertEqual(written, size)
		# A transfer ending exactly on a packet boundary has no short
		# packet to mark its end, so follow it with a zero-length
		# packet. Transfers of BUF_SIZE or more are presumably
		# delimited by the driver's buffer filling up instead
		# (TODO confirm against the driver).
		if size < BUF_SIZE and size % self.out_ep.maxPacketSize == 0:
			self.dev.bulkWrite(self.out_ep.address, [], TIMEOUT)

		in_data = self.dev.bulkRead(self.in_ep.address, size, TIMEOUT)
		self.assertEqual(len(in_data), size)

		# Lengths are equal (asserted above), so zip covers every byte.
		for sent, received in zip(out_data, in_data):
			if received != sent:
				self.fail("mismatch: sent %02x, received %02x"
						% (sent, received))

	def testWholeBuffer(self):
		"""Loop back exactly BUF_SIZE bytes."""
		for _ in range(REPEAT):
			self.doLoopbackTest(BUF_SIZE)

	def testTwoBuffers(self):
		"""Loop back 2 * BUF_SIZE bytes."""
		for _ in range(REPEAT):
			self.doLoopbackTest(2 * BUF_SIZE)

	def testShort(self):
		"""Loop back sizes just below, at and just above 64 and 512.

		These are presumably the full-speed and high-speed bulk packet
		sizes.
		"""
		for _ in range(REPEAT):
			for size in [1, 63, 64, 65, 511, 512, 513]:
				self.doLoopbackTest(size)

	def testStallOut(self):
		"""Halt the OUT endpoint, write to it, then clear the halt and
		check that loopback works again."""
		self.setHalt(self.out_ep)
		status = self.getEpStatus(self.out_ep)
		self.assertTrue(status & (1 << ENDPOINT_HALT))
		written = self.dev.bulkWrite(self.out_ep.address, 511 * [0],
				TIMEOUT)

		# This appears to be broken on full-speed. PyUSB
		# returns 64, but the USB analyzer shows that even
		# though 64 bytes were indeed sent, the device
		# responded with a STALL handshake, so they shouldn't
		# be counted as successfully sent.
		#
		# self.assertEqual(written, 0)

		self.dev.clearHalt(self.out_ep.address)
		status = self.getEpStatus(self.out_ep)
		self.assertFalse(status & (1 << ENDPOINT_HALT))
		self.doLoopbackTest(511)

	def testStallIn(self):
		"""Halt the IN endpoint and check that a read returns no data.
		Then clear the halt and read back the data written meanwhile."""
		self.setHalt(self.in_ep)
		status = self.getEpStatus(self.in_ep)
		self.assertTrue(status & (1 << ENDPOINT_HALT))
		self.dev.bulkWrite(self.out_ep.address, 511 * [0], TIMEOUT)
		data = self.dev.bulkRead(self.in_ep.address, 511, TIMEOUT)
		self.assertEqual(len(data), 0)

		self.dev.clearHalt(self.in_ep.address)
		status = self.getEpStatus(self.in_ep)
		self.assertFalse(status & (1 << ENDPOINT_HALT))
		data = self.dev.bulkRead(self.in_ep.address, 511, TIMEOUT)
		self.assertEqual(len(data), 511)

	def testStallOutPartial(self):
		"""Halt the OUT endpoint after one full packet has been written.

		Data written after the halt is cleared must be delivered
		together with that first packet (epsize * 2 - 1 bytes in total).
		"""
		epsize = self.out_ep.maxPacketSize

		written = self.dev.bulkWrite(self.out_ep.address, epsize * [0],
				TIMEOUT)
		self.assertEqual(written, epsize)

		self.setHalt(self.out_ep)
		status = self.getEpStatus(self.out_ep)
		self.assertTrue(status & (1 << ENDPOINT_HALT))

		written = self.dev.bulkWrite(self.out_ep.address,
				(epsize - 1) * [0], TIMEOUT)

		# Disabled for the same reason as in testStallOut: PyUSB
		# appears to count data that the device answered with a
		# STALL handshake as successfully sent. The "64 bytes"
		# figure observed there does not apply verbatim here,
		# because only epsize - 1 bytes are written.
		#
		# self.assertEqual(written, 0)

		self.dev.clearHalt(self.out_ep.address)
		status = self.getEpStatus(self.out_ep)
		self.assertFalse(status & (1 << ENDPOINT_HALT))

		written = self.dev.bulkWrite(self.out_ep.address,
				(epsize - 1) * [0], TIMEOUT)
		self.assertEqual(written, epsize - 1)

		data = self.dev.bulkRead(self.in_ep.address, epsize * 2 - 1,
				TIMEOUT)
		self.assertEqual(len(data), epsize * 2 - 1)

	def testStallInPartial(self):
		"""Halt the IN endpoint after part of the looped-back data has
		been read.

		While halted, a read must return nothing. Once the halt is
		cleared, the remainder must still be readable.
		"""
		epsize = self.in_ep.maxPacketSize

		written = self.dev.bulkWrite(self.out_ep.address,
				(2 * epsize - 1) * [0],
				TIMEOUT)
		self.assertEqual(written, 2 * epsize - 1)

		data = self.dev.bulkRead(self.in_ep.address, epsize,
				TIMEOUT)
		self.assertEqual(len(data), epsize)

		self.setHalt(self.in_ep)
		status = self.getEpStatus(self.in_ep)
		self.assertTrue(status & (1 << ENDPOINT_HALT))

		data = self.dev.bulkRead(self.in_ep.address, epsize - 1,
				TIMEOUT)
		self.assertEqual(len(data), 0)

		self.dev.clearHalt(self.in_ep.address)
		status = self.getEpStatus(self.in_ep)
		self.assertFalse(status & (1 << ENDPOINT_HALT))

		data = self.dev.bulkRead(self.in_ep.address, epsize - 1,
				TIMEOUT)
		self.assertEqual(len(data), epsize - 1)

def suite():
	"""Return a TestSuite containing every ``test*`` method of TestCase.

	Uses TestLoader rather than unittest.makeSuite(), which was
	deprecated in Python 3.11 and removed in 3.13.
	"""
	loader = unittest.TestLoader()
	loader.testMethodPrefix = 'test'
	return loader.loadTestsFromTestCase(TestCase)
