threading模块学习-生产者与消费者

threading模块学习--生产者与消费者

翻阅Python的threading模块API文档，觉得既有Java的影子（如threading.Thread类及其start/run方法），又有pthread的影子（如threading.Lock/threading.Condition）。总的来说简单易掌握，呵呵，随手写个大学OS课程里面“生产者与消费者”的Test。

相关类及方法就不多解释了，直接参考threading的API文档吧。

#! /usr/bin/env python
# coding=utf-8

import threading
import time

class Container(object):
	"""Bounded FIFO buffer shared between producer and consumer threads.

	All access is serialized by a single ``threading.Condition``. ``add``
	blocks while the buffer is full and ``get`` blocks while it is empty.
	Every mutation notifies all waiters, so threads blocked on the other
	side are woken up.
	"""

	def __init__(self, size = 10):
		"""Create an empty buffer that holds at most ``size`` items."""
		self.size = size  # capacity; add() blocks once this many items are held
		self.container = []  # queued items, oldest first
		self.condition = threading.Condition()  # guards self.container

	def isEmpty(self):
		"""Return True if the buffer currently holds no items.

		The result is only a snapshot. Other threads may change the buffer
		as soon as the lock is released.
		"""
		with self.condition:
			return len(self.container) == 0

	def isFull(self):
		"""Return True if the buffer holds ``size`` or more items (a snapshot)."""
		with self.condition:
			# Same test add() waits on, so the two can never disagree.
			return len(self.container) >= self.size

	def add(self, obj):
		"""Append ``obj`` to the buffer, blocking while the buffer is full."""
		with self.condition:
			while len(self.container) >= self.size:
				self.condition.wait()
			self.container.append(obj)
			# Wake consumers that are blocked in get() on an empty buffer.
			self.condition.notify_all()

	def get(self):
		"""Remove and return the oldest item, blocking while the buffer is empty."""
		with self.condition:
			while len(self.container) == 0:
				self.condition.wait()
			obj = self.container.pop(0)
			# Wake producers that are blocked in add() on a full buffer;
			# without this notification they would wait forever.
			self.condition.notify_all()
			return obj

class Producer(threading.Thread):
	"""Thread that puts ``size`` labelled string items into a shared container.

	Each item has the form ``"<thread name> (<index>)"``. After every
	attempt the thread sleeps for one second.
	"""

	def __init__(self, container, size, group = None, target = None, name = None, args = (), kwargs = None):
		# kwargs defaults to None rather than {} so instances never share one
		# mutable dict; threading.Thread substitutes a fresh {} for None.
		threading.Thread.__init__(self, group, target, name, args, kwargs)
		self.container = container  # buffer that receives the items
		self.size = size  # total number of items to produce
		self.count = 0  # number of items produced so far

	def produce(self):
		"""Return the next item, labelled with the calling thread's name and ``count``."""
		return '{0} ({1})'.format(threading.current_thread().name, self.count)

	def run(self):
		"""Produce ``size`` items, skipping a round whenever the buffer is full."""
		while self.count < self.size:
			# Polls isFull() and sleeps instead of blocking inside add()
			# when the buffer has no room.
			if not self.container.isFull():
				self.container.add(self.produce())
				self.count += 1
			time.sleep(1)

class Customer(threading.Thread):
	"""Thread that takes ``size`` items from a shared container and prints them.

	Each item is printed as ``"<thread name> (<index>): <item>"``. The thread
	then sleeps for one second.
	"""

	def __init__(self, container, size, group = None, target = None, name = None, args = (), kwargs = None):
		# kwargs defaults to None rather than {} so instances never share one
		# mutable dict; threading.Thread substitutes a fresh {} for None.
		threading.Thread.__init__(self, group, target, name, args, kwargs)
		self.container = container  # buffer the items are taken from
		self.size = size  # total number of items to consume
		self.count = 0  # number of items consumed so far

	def consume(self, obj):
		"""Print ``obj`` labelled with the calling thread's name and ``count``.

		``obj`` is rendered with ``str.format``, so non-string items are
		accepted as well.
		"""
		line = '{0} ({1}): {2}'.format(threading.current_thread().name, self.count, obj)
		print(line)

	def run(self):
		"""Consume ``size`` items; ``container.get()`` is called once per item."""
		while self.count < self.size:
			self.consume(self.container.get())
			time.sleep(1)
			self.count += 1

def main():
	"""Run the demo with one producer and three consumers sharing one Container.

	The producer makes 30 items and each consumer takes 10. The consumers
	are started before the producer. This function returns right after
	starting the threads and does not join them.
	"""
	shared = Container()
	producer = Producer(shared, 30, name = 'Producer 1')
	consumers = [Customer(shared, 10, name = label)
			for label in ('Customer 1', 'Customer 2', 'Customer 3')]
	for worker in consumers + [producer]:
		worker.start()

# Run the demo only when the file is executed as a script, not when imported.
if __name__ == '__main__':
	main()