import os
import warnings

from pyglet.gl import *
from pyglet import image
from gfx import DisplayList


class Material(object):
	"""Surface material applied through fixed-function OpenGL calls.

	The colour attributes (diffuse, ambient, specular, emission) are RGB
	sequences; `opacity` is appended to each one as the alpha component
	when the material is applied. `texture` is either None or an object
	exposing `target` and `id` attributes.
	"""

	# Class-level defaults; an instance overrides one by plain assignment
	diffuse = [.8, .8, .8]
	ambient = [.2, .2, .2]
	specular = [0., 0., 0.]
	emission = [0., 0., 0.]
	shininess = 0.
	opacity = 1.
	texture = None

	# OpenGL only accepts GL_SHININESS values in the range [0, 128]
	_MAX_SHININESS = 128.

	def __init__(self, name=None):
		"""Create a material with the default colours and the given name"""
		self.name = name

	def _rgba(self, rgb):
		"""Return rgb followed by the opacity as a 4-element GLfloat array.

		list() lets rgb be a tuple or any other iterable, not only a list.
		"""
		return (GLfloat * 4)(*(list(rgb) + [self.opacity]))

	def apply(self, face=GL_FRONT_AND_BACK):
		"""Make this material current for the given polygon face(s).

		Enables and binds the texture when one is set, otherwise disables
		GL_TEXTURE_2D, then uploads the four colour components and the
		shininess exponent.
		"""
		if self.texture:
			glEnable(self.texture.target)
			glBindTexture(self.texture.target, self.texture.id)
		else:
			glDisable(GL_TEXTURE_2D)

		glMaterialfv(face, GL_DIFFUSE, self._rgba(self.diffuse))
		glMaterialfv(face, GL_AMBIENT, self._rgba(self.ambient))
		glMaterialfv(face, GL_SPECULAR, self._rgba(self.specular))
		glMaterialfv(face, GL_EMISSION, self._rgba(self.emission))

		# The shininess attribute used to be stored but never sent to GL.
		# Clamp it, because values outside [0, 128] are rejected by OpenGL
		# (a Wavefront "Ns" value may be larger than that).
		shininess = min(max(float(self.shininess), 0.), self._MAX_SHININESS)
		glMaterialf(face, GL_SHININESS, shininess)


class MaterialGroup(object):
	"""Geometry of a mesh that is drawn with one material"""

	def __init__(self, material):
		# GL-side buffers; left unset here and filled in when the owning
		# mesh is first drawn
		self.array = self.edge_flags = None

		# Vertex position tuple -> set of normal tuples used at that position
		self.v2n = dict()

		# Flat list of floats, interleaved in GL_T2F_N3F_V3F order
		self.vertices = list()

		self.material = material


class Mesh(object):
	"""Named list of material groups that can be drawn with OpenGL.

	`model` is the owning model; its `edge_threshold` attribute is read
	when the edge flags of a group are computed.
	"""

	def __init__(self, model, name):
		self.name = name
		self.groups = []
		self.model = model
		self.display_list = None

	def _prepare_group(self, group):
		"""Build the ctypes vertex array of group the first time it is needed"""
		if group.array is None:
			group.array = (GLfloat * len(group.vertices))(*group.vertices)
			# Every interleaved vertex is 8 floats (T2F + N3F + V3F), so
			# despite its name this is the vertex count passed to
			# glDrawArrays. Floor division keeps it an int on Python 3 too.
			group.triangles = len(group.vertices) // 8

	def _edge_flag_values(self, group):
		"""Return a list with one 0/1 edge flag per vertex of group.

		Flag k describes the triangle edge that starts at vertex k. An edge
		is flagged when its two end points share exactly one normal, or when
		they share several normals and the dot product of the first shared
		normal with any other one is, in absolute value, below the model's
		edge_threshold. Vertices beyond the last whole triangle keep flag 0.
		"""
		flags = [0] * group.triangles
		v2n = group.v2n
		floats = group.vertices
		edge = 0
		for tri in range(group.triangles // 3):
			tx = tri * 24 # triangle start index (3 vertices * 8 floats)
			# The position is the last three floats of each vertex
			v1 = tuple(floats[tx+5:tx+8])
			v2 = tuple(floats[tx+13:tx+16])
			v3 = tuple(floats[tx+21:tx+24])
			for spt, ept in ((v1, v2), (v2, v3), (v3, v1)):
				if spt in v2n and ept in v2n:
					common_normals = list(v2n[spt].intersection(v2n[ept]))
					if len(common_normals) > 1:
						nx, ny, nz = common_normals[0]
						for ox, oy, oz in common_normals[1:]:
							d = nx*ox + ny*oy + nz*oz
							if abs(d) < self.model.edge_threshold:
								flags[edge] = 1
								break
					elif len(common_normals) == 1:
						flags[edge] = 1
				edge += 1
		return flags

	def draw(self):
		"""Draw every group as triangles, with per-vertex edge flags.

		Vertex arrays and edge flags are created on the first call and kept
		on the group for later calls.
		"""
		glPushClientAttrib(GL_CLIENT_VERTEX_ARRAY_BIT)
		glPushAttrib(GL_CURRENT_BIT | GL_ENABLE_BIT | GL_LIGHTING_BIT)
		try:
			for group in self.groups:
				if group.material is not None:
					group.material.apply()
				self._prepare_group(group)
				if group.edge_flags is None and group.triangles:
					group.edge_flags = (GLboolean * group.triangles)(
						*self._edge_flag_values(group))
				glInterleavedArrays(GL_T2F_N3F_V3F, 0, group.array)
				glEdgeFlagPointer(0, group.edge_flags)
				glEnableClientState(GL_EDGE_FLAG_ARRAY)
				glDrawArrays(GL_TRIANGLES, 0, group.triangles)
		finally:
			# Restore the attribute stacks even if drawing raised
			glPopAttrib()
			glPopClientAttrib()

	def draw_wireframe(self):
		"""Draw the outline of every triangle as a line loop"""
		glPushClientAttrib(GL_CLIENT_VERTEX_ARRAY_BIT)
		glPushAttrib(GL_CURRENT_BIT | GL_ENABLE_BIT | GL_LIGHTING_BIT)
		try:
			glEnable(GL_CULL_FACE)
			glCullFace(GL_BACK)
			for group in self.groups:
				if group.material is not None:
					group.material.apply()
				self._prepare_group(group)
				v = group.vertices
				# Stop at the last whole triangle (24 floats) so trailing
				# floats of an incomplete one cannot cause an IndexError
				for i in range(0, len(v) - len(v) % 24, 24):
					glBegin(GL_LINE_LOOP)
					glVertex3f(v[i + 5], v[i + 6], v[i + 7])
					glVertex3f(v[i + 13], v[i + 14], v[i + 15])
					glVertex3f(v[i + 21], v[i + 22], v[i + 23])
					glEnd()
		finally:
			glPopAttrib()
			glPopClientAttrib()

	def compile(self):
		"""Compile the mesh into a display list, which will be used
		when draw() is called thereafter
		"""
		if self.display_list is None:
			self.display_list = DisplayList(self.draw)
			# Shadow the method on this instance with the display list call
			self.draw = self.display_list.execute


class Model:
	"""3D model loaded from a Wavefront .obj file"""

	def __init__(self, filename, file=None, path=None, scale=1.0, edge_threshold=.99, compile=True, load_materials=False):
		"""Parse the .obj data and build the meshes.

		filename -- path of the .obj file; opened only when file is None
		file -- optional iterable of text lines to parse instead of opening
			filename; it is left open for the caller to close
		path -- directory used to locate material libraries; defaults to
			the directory containing filename
		scale -- factor applied to every vertex position
		edge_threshold -- stored for the meshes, which compare normal dot
			products against it when computing edge flags
		compile -- when true, compile() is called on every mesh and the
			resulting display lists are kept in display_lists
		load_materials -- when true, 'mtllib' and 'usemtl'/'usemat' lines
			are honoured; by default they are ignored and every group
			gets a default Material
		"""
		self.materials = {}
		self.meshes = {}		# Name mapping
		self.mesh_list = []	 # Also includes anonymous meshes

		if path is None:
			path = os.path.dirname(os.path.abspath(filename))
		self.path = path
		self.scale = float(scale)
		self.edge_threshold = float(edge_threshold)

		if file is None:
			# Only a file opened here is closed here
			file = open(filename, 'r')
			try:
				self._parse(file, load_materials)
			finally:
				file.close()
		else:
			self._parse(file, load_materials)

		if compile:
			for mesh in self.mesh_list:
				mesh.compile()
			self.display_lists = [m.display_list for m in self.mesh_list]
		else:
			self.display_lists = None

	def _parse(self, file, load_materials):
		"""Read .obj lines from file and populate meshes and mesh_list"""
		mesh = None
		group = None
		material = None

		# Index 0 holds a placeholder: .obj indices start at 1 and a missing
		# texture/normal reference is parsed as 0
		vertices = [[0., 0., 0.]]
		normals = [[0., 0., 0.]]
		tex_coords = [[0., 0.]]

		for line in file:
			if line.startswith('#'):
				continue
			values = line.split()
			if not values:
				continue

			key = values[0]
			if key == 'v':
				vertices.append([float(v) * self.scale for v in values[1:4]])
			elif key == 'vn':
				# Real lists (not map objects) so they can be concatenated
				normals.append([float(v) for v in values[1:4]])
			elif key == 'vt':
				tex_coords.append([float(v) for v in values[1:3]])
			elif key == 'mtllib':
				if load_materials:
					self.load_material_library(values[1])
			elif key in ('usemtl', 'usemat'):
				if load_materials:
					material = self.materials.get(values[1], None)
					if material is None:
						warnings.warn('Unknown material: %s' % values[1])
					if mesh is not None:
						group = MaterialGroup(material)
						mesh.groups.append(group)
			elif key == 'o':
				mesh = Mesh(self, values[1])
				self.meshes[mesh.name] = mesh
				self.mesh_list.append(mesh)
				group = None
			elif key == 'f':
				if mesh is None:
					# Faces before any 'o' line go to an anonymous mesh
					mesh = Mesh(self, '')
					self.mesh_list.append(mesh)
				if material is None:
					material = Material()
				if group is None:
					group = MaterialGroup(material)
					mesh.groups.append(group)
				self._add_face(group, values[1:], vertices, tex_coords, normals)

	@staticmethod
	def _face_indices(corner, len_vertices, len_tex_coords, len_normals):
		"""Return (vertex, texture, normal) list indices for a face corner.

		corner is a 'v', 'v/vt', 'v//vn' or 'v/vt/vn' token; absent fields
		become 0. The len_* arguments are the lengths of the lists including
		their placeholder at index 0, so a negative (relative) index -1
		resolves to the last element of the list.
		"""
		fields = [int(field or 0) for field in corner.split('/')]
		v_index, t_index, n_index = (fields + [0, 0])[:3]
		if v_index < 0:
			v_index += len_vertices
		if t_index < 0:
			t_index += len_tex_coords
		if n_index < 0:
			n_index += len_normals
		return v_index, t_index, n_index

	def _add_face(self, group, corners, vertices, tex_coords, normals):
		"""Append one face to group as a triangle fan of interleaved vertices"""
		# For fan triangulation, remember first and latest vertices
		first = None
		latest = None
		for i, corner in enumerate(corners):
			v_index, t_index, n_index = self._face_indices(
				corner, len(vertices), len(tex_coords), len(normals))
			position = tuple(vertices[v_index])
			group.v2n.setdefault(position, set()).add(tuple(normals[n_index]))
			# GL_T2F_N3F_V3F order: texture, normal, position
			vertex = tex_coords[t_index] + normals[n_index] + vertices[v_index]

			if i >= 3:
				# Triangulate: start a new triangle from the fan origin
				# and the previous corner
				group.vertices += first + latest
			group.vertices += vertex

			if i == 0:
				first = vertex
			latest = vertex

	def open_material_file(self, filename):
		'''Override for loading from archive/network etc.'''
		return open(os.path.join(self.path, filename), 'r')

	def load_material_library(self, filename):
		"""Parse a .mtl material library and add its materials to materials.

		Problems found while parsing are reported with warnings.warn rather
		than raised.
		"""
		material = None
		file = self.open_material_file(filename)
		try:
			for line in file:
				if line.startswith('#'):
					continue
				values = line.split()
				if not values:
					continue

				if values[0] == 'newmtl':
					material = Material(values[1])
					self.materials[material.name] = material
				elif material is None:
					warnings.warn('Expected "newmtl" in %s' % filename)
					continue

				try:
					if values[0] == 'Kd':
						material.diffuse = [float(v) for v in values[1:]]
					elif values[0] == 'Ka':
						material.ambient = [float(v) for v in values[1:]]
					elif values[0] == 'Ks':
						material.specular = [float(v) for v in values[1:]]
					elif values[0] == 'Ke':
						# Material.apply() reads 'emission'; 'emissive' is
						# kept as an alias of the name that was set before
						material.emission = [float(v) for v in values[1:]]
						material.emissive = material.emission
					elif values[0] == 'Ns':
						material.shininess = float(values[1])
					elif values[0] == 'd':
						material.opacity = float(values[1])
					elif values[0] == 'map_Kd':
						# NOTE(review): the texture name is passed to
						# image.load as written, not joined with self.path
						try:
							material.texture = image.load(values[1]).texture
						except image.ImageDecodeException:
							warnings.warn('Could not load texture %s' % values[1])
				except Exception:
					# Best effort: keep loading, but let KeyboardInterrupt
					# and SystemExit propagate
					warnings.warn('Parse error in %s.' % filename)
		finally:
			# An overridden open_material_file may return any iterable
			close = getattr(file, 'close', None)
			if close is not None:
				close()

	def draw(self):
		"""Draw all meshes, through their display lists when compiled"""
		if self.display_lists is not None:
			DisplayList.execute_many(*self.display_lists)
		else:
			for mesh in self.mesh_list:
				mesh.draw()

	def draw_wireframe(self):
		"""Draw the wireframe of every mesh"""
		for mesh in self.mesh_list:
			mesh.draw_wireframe()
