"""
This script, which is a little bit hackish, is designed to take model definition python files
and produce a large amount of associated C code for each model requested.

This is to simplify the process of adding a new model: all the "glue" code, such as reading parameters
from ini files or applying a prior that the parameters are within the specified ranges, is generated here.

The script takes in model names, and the template files i3_model.template.[ch], 
looks for models/model_name.py, models/model_name.c, models/model_name.h, and generates:
	models/model_name_definition.h files
	i3_model.c
	i3_model.h
	i3_options.[ini|c|h], using i3_build_models.py
	
The process by which the i3_model.[ch] files are generated from the template is:
	read i3_model.template.[ch]
	look for the text #python something in it
	replace with text selected in the ModelFile.to_files method
	... which is in turn generated in the Model.* methods.
	
"""

import sys
import re
import optparse
import os
sys.path.append("./models/")


# printf conversion used for each scalar parameter type in generated C.
c_format_codes = {
	float: '% e',
	int: '% d',
	bool: '%d',
}
# C function that parses a scalar parameter of each type from a string.
c_asci_to_num_converters = {
	float: 'atof',
	int: 'atoi',
	bool: 'atoi',
}


class Model(object):
	"""Generate C glue code for one model from its python definition module.

	The definition module is imported by name (the script appends ./models/
	to sys.path) and must define:
	  name       -- model name used in generated C identifiers and strings
	  parameters -- sequence of parameter names, in struct order
	  types      -- mapping parameter name -> float, int, bool or list
	  proposal   -- name of the C proposal function
	  starts     -- mapping parameter name -> starting value; for list-typed
	                parameters its length fixes the C array length
	It may optionally define help (mapping parameter name -> description),
	prior, derivative, start_function, map_physical_function (C function
	names, emitted as "&name", or NULL when absent) and
	get_model_image_function.

	Most *_code methods return a C fragment guarded by
	if (!strcmp(model->name,"<name>")) so that fragments for several models
	can be combined inside one C function.
	"""

	# C type used to declare each scalar parameter type (list parameters are
	# declared as arrays of the float type, see declarations()).
	_c_declaration_types = {
		float: "i3_flt",
		int: "int",
		bool: "bool",
	}
	# C cast applied by parameter_set_by_name_code() to the incoming value.
	_c_cast_types = {
		float: "float",
		int: "int",
		bool: "bool",
	}
	# i3_parameter_type enum constant describing each parameter type.
	_c_param_enums = {
		float: "i3_parameter_type_flt",
		int: "i3_parameter_type_int",
		bool: "i3_parameter_type_bool",
		list: "i3_parameter_type_array",
	}
	# printf conversions for scalar values in prior-violation messages.
	# int and bool values are passed to fprintf as ints, so they need %d;
	# handing them to %f is undefined behaviour.
	_violation_formats = {
		float: "%f",
		int: "%d",
		bool: "%d",
	}

	construction_template = """
	if (!strcmp(name,"{name}")) {{
		model->nparam=i3_{name}_nparam;
		model->likelihood = (likelihood_function) &i3_{name}_likelihood;
		model->prior = (prior_function) {prior};
		model->propose = (proposal_function) &{proposal};
		model->posterior_derivative = (derivative_function) {derivative};
		model->start = (start_position_function) {start_function};
		model->map_physical = (map_physical_function) {map_physical_function};
		/* Set up array describing whether parameters are fixed. */
		model->param_fixed = malloc(sizeof(bool)*model->nparam);
		{parameter_fix_lines}
		/*  Set up array of parameter types */
		model->param_type = malloc(sizeof(i3_parameter_type)*model->nparam);
		{parameter_type_lines}
		
		
	}}
	"""
	struct_template = """typedef struct i3_{name}_parameter_set{{
		{declarations}		

	}} i3_{name}_parameter_set;

	#define i3_{name}_nparam {nparam}
			"""

	def __init__(self,model_name):
		"""Import the definition module *model_name* and read its attributes."""
		module=__import__(model_name)
		self.name=module.name
		self.param_names=module.parameters
		self.param_types=module.types
		self.proposal=module.proposal
		self.param_defaults=module.starts
		self.nparam=len(self.param_names)
		self.help=getattr(module,'help',{})
		# Optional C callbacks become "&function" or a NULL pointer.
		self.prior=self._optional_function_pointer(module,'prior')
		self.derivative=self._optional_function_pointer(module,'derivative')
		self.start_function=self._optional_function_pointer(module,'start_function')
		self.map_physical_function=self._optional_function_pointer(module,'map_physical_function')
		# Used verbatim (no "&"); 'NULL' marks "not implemented".
		self.get_model_image_function=getattr(module,'get_model_image_function','NULL')

	@staticmethod
	def _optional_function_pointer(module,attribute):
		"""Return "&<name>" for an optional callback attribute, or "NULL" if it is absent or None."""
		function_name=getattr(module,attribute,None)
		if function_name is None:
			return 'NULL'
		return '&'+function_name

	def _help_comment(self,param):
		"""Return the help text for *param* as a C comment, or "" if there is none."""
		text=self.help.get(param,"")
		return '/* ' + text + ' */' if text else ''

	def parameter_fix_lines(self):
		"""Yield one C statement per parameter copying its *_fixed option into model->param_fixed."""
		for i,param in enumerate(self.param_names):
			yield "model->param_fixed[{i}] = options->{model_name}_{param}_fixed;  {help}".format(i=i,param=param,model_name=self.name,help=self._help_comment(param))

	def param_enum(self,param_type):
		"""Return the i3_parameter_type enum constant for *param_type*; raise KeyError if unsupported."""
		return self._c_param_enums[param_type]

	def cast_type(self, param_type):
		"""Return the C cast type for a scalar *param_type*; raise KeyError for list."""
		return self._c_cast_types[param_type]

	def parameter_type_lines(self):
		"""Yield one C statement per parameter filling model->param_type."""
		for i,param in enumerate(self.param_names):
			param_code=self.param_enum(self.param_types[param])
			yield "model->param_type[{i}] = {param_code};".format(i=i,param_code=param_code)

	def creation_code(self):
		"""Return the model-construction C branch built from construction_template."""
		parameter_fix_lines = "\n\t\t".join(self.parameter_fix_lines())
		parameter_type_lines = "\n\t\t".join(self.parameter_type_lines())
		return self.construction_template.format(name=self.name,proposal=self.proposal,parameter_fix_lines=parameter_fix_lines,parameter_type_lines=parameter_type_lines,prior=self.prior,derivative=self.derivative, start_function=self.start_function, map_physical_function=self.map_physical_function)

	def options(self):
		"""Yield (name, type, default, description) for each parameter's start option."""
		for p in self.param_names:
			yield (p,self.param_types[p],self.param_defaults[p],'Initial value of parameter {0} in model {1}'.format(p,self.name))

	def struct(self):
		"""Return the parameter-set typedef and nparam #define from struct_template."""
		declarations = "\n\t".join(self.declarations())
		return self.struct_template.format(name=self.name,declarations=declarations,nparam=len(self.param_names))

	def test_model_name_line(self):
		"""Return the opening C 'if' that selects this model by name."""
		return 'if (!strcmp(model->name,"{name}")){{'.format(name=self.name)  

	def offset_code(self):
		"""Return C code recording the struct size and each parameter's byte offset."""
		lines=[]
		lines.append(  self.test_model_name_line() )
		lines.append(  "\t\ti3_{name}_parameter_set P;".format(name=self.name)        )
		lines.append(  "\t\tmodel->nbytes = sizeof(P);"                               )
		for i,param in enumerate(self.param_names):
			lines.append( "\t\tmodel->byte_offsets[{i}] = ((off_type)&(P.{param_name})) - (off_type)&P;".format(i=i,param_name=param)   )
		lines.append("\n}")
		return "\n".join(lines)

	def scale_code(self):
		"""Return C code multiplying every parameter (array elements included) by 'scale'."""
		lines=[]
		lines.append(  self.test_model_name_line() )
		lines.append(  "\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) p;".format(name=self.name)        )
		for param in self.param_names:
			if self.param_types[param]==list:
				length=len(self.param_defaults[param])
				lines.append("\t\tfor (int ip=0;ip<{length};ip++) P->{param}[ip]*=scale;".format(param=param,length=length))
			else:
				lines.append("\t\tP->{param}*=scale;".format(param=param))
		lines.append("\n\t}")
		return "\n".join(lines)

	def extract_params_code(self,field):
		"""Return C code copying options-><name>_<param>_<field> into the parameter set 'p'."""
		lines=[]
		lines.append( self.test_model_name_line() )
		lines.append(  "\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) p;".format(name=self.name)        )
		for param in self.param_names:
			if self.param_types[param]==list:
				length=len(self.param_defaults[param])
				lines.append("\t\tfor (int ip=0; ip<{length}; ip++) P->{param}[ip]=options->{name}_{param}_{field}[ip];".format(length=length,name=self.name,param=param,field=field))
			else:
				lines.append("\t\tP->{param}=options->{name}_{param}_{field};".format(name=self.name,param=param,field=field))
		lines.append("}")
		return "\n".join(lines)

	def nonzero_width_nparam_code(self):
		"""Return C code incrementing n for each free parameter whose width is non-zero.

		NOTE(review): for list parameters W->param is an array, so the "!=0"
		test compares a pointer and is always true.
		"""
		lines=[]
		lines.append( self.test_model_name_line() )
		lines.append(  "\t\ti3_{name}_parameter_set * W = (i3_{name}_parameter_set*) width;".format(name=self.name)        )
		for p, param in enumerate(self.param_names):
			lines.append("\t\tif (W->{param}!=0 && !model->param_fixed[{p}]) n+=1;".format(param=param,p=p))
		lines.append("\t}")
		return "\n".join(lines)

	def range_code(self):
		"""Return C code storing max-min of every parameter in 'p', freeing the temporary min/max sets."""
		lines=[]
		lines.append( self.test_model_name_line() )
		lines.append(  "\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) p;".format(name=self.name)        )
		lines.append(  "\t\ti3_{name}_parameter_set * range_min = (i3_{name}_parameter_set*) i3_model_option_min(model,options);".format(name=self.name)        )
		lines.append(  "\t\ti3_{name}_parameter_set * range_max = (i3_{name}_parameter_set*) i3_model_option_max(model,options);".format(name=self.name)        )
		for param in self.param_names:
			if self.param_types[param]==list:
				length = len(self.param_defaults[param])
				lines.append("\t\tfor (int ip=0; ip<{length}; ip++) P->{param}[ip]=range_max->{param}[ip]-range_min->{param}[ip];".format(length=length,param=param))
			else:
				lines.append("\t\tP->{param}=range_max->{param}-range_min->{param};".format(param=param))
		lines.append("\t\tfree(range_min);")
		lines.append("\t\tfree(range_max);")
		lines.append("}")
		return "\n".join(lines)

	def violation_code(self):
		"""Return C code printing every min/max bound that 'p' violates and setting 'violated'."""
		lines=[]
		lines.append("\t"+self.test_model_name_line())
		lines.append(  "\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) p;".format(name=self.name)        )
		lines.append(  "\t\ti3_{name}_parameter_set * min_p = (i3_{name}_parameter_set*) model->min;".format(name=self.name)        )
		lines.append(  "\t\ti3_{name}_parameter_set * max_p = (i3_{name}_parameter_set*) model->max;".format(name=self.name)        )
		for param in self.param_names:
			ptype=self.param_types[param]
			if ptype==list:
				length = len(self.param_defaults[param])
				lines.append("\t\tfor (int ip=0; ip<{length}; ip++) if (P->{param}[ip]<min_p->{param}[ip]) {{fprintf(output,\"Violated minimum on {name} {param}[%d]: %f < %f\\n\", ip, P->{param}[ip],min_p->{param}[ip]); violated=true;}}".format(param=param,name=self.name,length=length) )
				lines.append("\t\tfor (int ip=0; ip<{length}; ip++) if (P->{param}[ip]>max_p->{param}[ip]) {{fprintf(output,\"Violated maximum on {name} {param}[%d]: %f > %f\\n\", ip, P->{param}[ip],max_p->{param}[ip]); violated=true;}}".format(param=param,name=self.name,length=length) )
			else:
				# The conversion must match the C type of the value being printed.
				fmt=self._violation_formats.get(ptype,"%f")
				lines.append("\t\tif (P->{param}<min_p->{param}) {{fprintf(output,\"Violated minimum on {name} {param}: {fmt} < {fmt}\\n\", P->{param},min_p->{param}); violated=true;}}".format(param=param,name=self.name,fmt=fmt) )
				lines.append("\t\tif (P->{param}>max_p->{param}) {{fprintf(output,\"Violated maximum on {name} {param}: {fmt} > {fmt}\\n\", P->{param},max_p->{param}); violated=true;}}".format(param=param,name=self.name,fmt=fmt) )
		lines.append("\t}")
		return "\n".join(lines)

	def prior_code(self):
		"""Return C code returning BAD_LIKELIHOOD when any parameter lies outside [min, max]."""
		lines=[]
		lines.append("\t"+self.test_model_name_line())
		lines.append(  "\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) p;".format(name=self.name)        )
		lines.append(  "\t\ti3_{name}_parameter_set * min_p = (i3_{name}_parameter_set*) model->min;".format(name=self.name)        )
		lines.append(  "\t\ti3_{name}_parameter_set * max_p = (i3_{name}_parameter_set*) model->max;".format(name=self.name)        )
		for param in self.param_names:
			if self.param_types[param]==list:
				length = len(self.param_defaults[param])
				lines.append("\t\tfor (int ip=0; ip<{length}; ip++) if (P->{param}[ip]<min_p->{param}[ip]||P->{param}[ip]>max_p->{param}[ip]) return BAD_LIKELIHOOD;".format(param=param,length=length) )
			else:
				lines.append("\t\tif (P->{param}<min_p->{param}||P->{param}>max_p->{param}) return BAD_LIKELIHOOD;".format(param=param) )
		lines.append("\t}")
		return "\n".join(lines)

	def fixes_code(self):
		"""Return C code copying each *_fixed option into model->param_fixed."""
		lines=[]
		lines.append( self.test_model_name_line() )
		for i,param in enumerate(self.param_names):
			lines.append("\t\tmodel->param_fixed[{i}]=options->{name}_{param}_fixed;".format(name=self.name,param=param,i=i))
		lines.append("}")
		return "\n".join(lines)

	def any_nan_code(self):
		"""Return C code returning true if any float parameter (or float-array element) is NaN."""
		lines = []
		lines.append( "\t"+self.test_model_name_line() )
		lines.append(  "\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) p;".format(name=self.name)        )
		for param in self.param_names:
			param_type=self.param_types[param]
			# NaN is the only value that compares unequal to itself.
			if param_type==float:
				lines.append("\t\tif (!(P->{param}==P->{param})) {{return true;}}".format(param=param))
			elif param_type==list:
				length=len(self.param_defaults[param])
				lines.append("\t\tfor (int ip=0; ip<{length}; ip++) if (!(P->{param}[ip]==P->{param}[ip])) {{return true;}}".format(param=param,length=length))
		lines.append("\t}")
		return "\n".join(lines)

	def parameter_set_by_name_code(self):
		"""Return C code assigning 'value' to the scalar parameter called 'name' (arrays are rejected)."""
		lines=[]
		lines.append( "\t"+self.test_model_name_line() )
		lines.append(  "\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) p;".format(name=self.name)        )
		for param in self.param_names:
			param_type=self.param_types[param]
			if param_type==list:
				lines.append("\t\tif (0==strcmp(name,\"{param}\")) {{fprintf(stderr, \"Cannot set array params by name yet.\\n\"); return 1;}}".format(param=param))
			else:
				lines.append("\t\tif (0==strcmp(name,\"{param}\")) {{P->{param}=({cast})value; return 0;}}".format(param=param, cast=self.cast_type(param_type)))
		lines.append("}")
		return '\n'.join(lines)

	def ellipticity_code(self):
		"""Return C code writing the e1/e2 parameters to *e1/*e2, or a fatal error if the model lacks them."""
		lines=[]
		lines.append( "\t"+self.test_model_name_line() )
		if "e1" not in self.param_names or "e2" not in self.param_names:
			lines.append("\t\tI3_FATAL(\"Tried to extract e1 and e2 from parameter set for model '{model_name}', which has no e1 or e2 component\",1);".format(model_name=self.name))
		else:
			lines.append(  "\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) p;".format(name=self.name)        )
			lines.append("\t\t * e1 = P->e1;")
			# Previously copied P->e1 into both outputs.
			lines.append("\t\t * e2 = P->e2;")
		lines.append("\t}")
		return "\n".join(lines)

	def declarations(self):
		"""Yield one C struct-member declaration per parameter (lists become float arrays)."""
		for param_name in self.param_names:
			param_type=self.param_types[param_name]
			if param_type==list:
				yield self.declare_array(param_name, float, len(self.param_defaults[param_name]))
			else:
				yield self.declare(param_name,param_type)

	def declare_array(self,name,param_type,n):
		"""Return a C declaration of an *n*-element array member, with its help comment."""
		return "{declaration} {name}[{n}];\t\t{help}".format(name=name,declaration=self._c_declaration_types[param_type], help=self._help_comment(name), n=n)

	def declare_pointer(self,name,param_type):
		"""Return a C declaration of a pointer member, with its help comment."""
		return "{declaration} * {name};\t\t{help}".format(name=name,declaration=self._c_declaration_types[param_type], help=self._help_comment(name))

	def declare(self,name,param_type):
		"""Return a C declaration of a scalar member, with its help comment."""
		return "{declaration} {name};\t\t{help}".format(name=name,declaration=self._c_declaration_types[param_type], help=self._help_comment(name))

	def c_format_code(self,param):
		"""Return the printf conversion(s) for *param*; None for unsupported types."""
		ptype = self.param_types[param]
		if ptype in c_format_codes:
			return c_format_codes[ptype]
		elif ptype is list:
			return "%e " * len(self.param_defaults[param])
		return None

	def parameter_handle(self,variable,param):
		"""Return the C expression(s) reading *param* through pointer *variable* (comma-separated for arrays)."""
		if self.param_types[param]==list:
			length = len(self.param_defaults[param])
			return ", ".join("{variable}->{param}[{i}]".format(param=param,variable=variable,i=i) for i in range(length))
		return '{variable}->{param}'.format(param=param,variable=variable)

	def parameter_string_code(self):
		"""Return C code formatting all parameters, tab-separated, into parameter_string."""
		lines=[]
		format_string = '\\t'.join(self.c_format_code(param) for param in self.param_names)
		parameter_handles=','.join(self.parameter_handle("mp",param) for param in self.param_names)
		lines.append('\tif (!strcmp(model->name,"{name}")){{'.format(name=self.name))
		lines.append('\t\t'+"i3_{name}_parameter_set * mp = (i3_{name}_parameter_set*) p;".format(name=self.name))
		lines.append('\t\t'+'return snprintf(parameter_string,string_length,"{format_string}",{parameter_handles});'.format(parameter_handles=parameter_handles,format_string=format_string))
		lines.append("\t}")
		return "\n".join(lines)

	def parameter_pretty_string_code(self):
		"""Return C code formatting parameters as "name=value" lines into parameter_string."""
		format_string_parts=[]
		for param in self.param_names:
			if self.param_types[param]==list:
				length = len(self.param_defaults[param])
				codes = ", ".join(["%e"]*length)
				part="{param}={codes}".format(param=param,codes=codes)
			else:
				part="{param}={code}".format(param=param,code=self.c_format_code(param))
			format_string_parts.append(part)
		format_string='\\n'.join(format_string_parts)

		parameter_handles=','.join(self.parameter_handle("mp",param) for param in self.param_names)
		lines=[]
		lines.append('\tif (!strcmp(model->name,"{name}")){{'.format(name=self.name))
		lines.append('\t\t'+"i3_{name}_parameter_set * mp = (i3_{name}_parameter_set*) p;".format(name=self.name))
		lines.append('\t\t'+'return snprintf(parameter_string,string_length,"{format_string}",{parameter_handles});'.format(parameter_handles=parameter_handles,format_string=format_string))
		lines.append("\t}")
		return "\n".join(lines)

	def perturb_code(self):
		"""Return C code adding scaled gaussian noise to free float/array parameters, clamped inside [min, max]."""
		lines=[]
		lines.append("\t"+self.test_model_name_line())
		lines.append("\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) p;".format(name=self.name))
		lines.append("\t\ti3_{name}_parameter_set * min_p = (i3_{name}_parameter_set*)model->min;".format(name=self.name))
		lines.append("\t\ti3_{name}_parameter_set * max_p = (i3_{name}_parameter_set*)model->max;".format(name=self.name))
		lines.append("\t\ti3_flt epsilon = 1e-6;")
		for i,param in enumerate(self.param_names):
			ptype=self.param_types[param]
			if ptype==float:
				lines.extend([
					"\t\t if (!model->param_fixed[{i}]) {{".format(i=i),
					"\t\t\ti3_flt r = (max_p->{param} - min_p->{param});".format(param=param),
					"\t\t\tP->{param} += scale * r * i3_random_normal();".format(param=param),
					"\t\t\tif (P->{param}<min_p->{param}) P->{param} = min_p->{param}+r*epsilon;".format(param=param),
					"\t\t\tif (P->{param}>max_p->{param}) P->{param} = max_p->{param}-r*epsilon;".format(param=param),
					"\t\t}",
				])
			elif ptype==list:
				lines.extend([
					"\t\t if (!model->param_fixed[{i}]) {{".format(i=i),
					"\t\t\t for (int ip=0; ip<{length}; ip++){{".format(length=len(self.param_defaults[param])),
					"\t\t\t\ti3_flt r = (max_p->{param}[ip] - min_p->{param}[ip]);".format(param=param),
					"\t\t\t\tP->{param}[ip] += scale * r * i3_random_normal();".format(param=param),
					"\t\t\t\tif (P->{param}[ip]<min_p->{param}[ip]) P->{param}[ip] = min_p->{param}[ip]+r*epsilon;".format(param=param),
					"\t\t\t\tif (P->{param}[ip]>max_p->{param}[ip]) P->{param}[ip] = max_p->{param}[ip]-r*epsilon;".format(param=param),
					"\t\t\t}",
					"\t\t}",
				])
		lines.append("\t}")
		return "\n".join(lines)

	def tester_code(self):
		"""Return the C fragment for the model test program.

		Previously returned None, which made ModelFile.to_tester fail.
		TODO(review): the trailing i3_model_posterior statement is incomplete
		C; to_tester is currently not called from main().
		"""
		lines=[]
		lines.append("\t"+self.test_model_name_line())
		lines.append(  "\t\ti3_model * {name}_model = i3_model_create(\"{name}\",options);".format(name=self.name)        )
		lines.append(  "\t\ti3_{name}_parameter_set * {name}_start = i3_model_option_starts({name}_model, options);".format(name=self.name)        )
		lines.append(  "\t\ti3_model_posterior"        )
		lines.append("\t}")
		return "\n".join(lines)

	def get_model_image_code(self):
		"""Return C code calling the model's image function, or printing that it is unavailable."""
		lines=[]
		lines.append('\tif (!strcmp(model->name,"{name}")){{'.format(name=self.name))
		if self.get_model_image_function == 'NULL':
			lines.append('\t\t printf("image generation function for this model ' + self.name + ' is not implemented.\\n");')
		else:
			lines.append('\t\t'+ self.get_model_image_function + '(params,dataset,image_out);')
		lines.append("\t}")
		return "\n".join(lines)

	def read_params_from_str_code(self):
		"""Return C code parsing parameter_strings[i] into the i-th parameter of 'params'."""
		lines=[]
		lines.append('\tif (!strcmp(model->name,"{name}")){{'.format(name=self.name))
		lines.append('\t\t'+"i3_{name}_parameter_set * m_params = (i3_{name}_parameter_set*) params;".format(name=self.name))
		for index,param in enumerate(self.param_names):
			ptype=self.param_types[param]
			if ptype==list:
				# Each array parse gets its own scope: emitting "int status;"
				# twice in one block is a C redefinition error.
				param_length = len(self.param_defaults[param])
				lines.append("\t\t{{ int status; i3_options_parse_flt_array(m_params->{param},parameter_strings[{index}],{param_length}, &status); }}".format(param=param,index=index,param_length=param_length))
			else:
				conversion_fun=c_asci_to_num_converters[ptype]
				lines.append("\t\tm_params->{param} = {conversion_fun}(parameter_strings[{index}]);".format(param=param,index=index,conversion_fun=conversion_fun))
		lines.append("\t}")
		return "\n".join(lines)

	def header_line_code(self):
		"""Return C code returning the space-padded list of parameter names."""
		lines = []
		lines.append("\t"+self.test_model_name_line())
		lines.append('\t\treturn "{names}";'.format(names='              '.join(self.param_names)))
		lines.append("\t}")
		return "\n".join(lines)

	def posterior_derivative_approx_code(self):
		"""Return C code for a forward finite-difference posterior derivative.

		For each float parameter, P is reset to P0, the parameter is shifted
		by epsilon, and (posterior(P) - L0)/epsilon is stored in 'pprime'.
		L0 is written to *like.
		"""
		lines=[]
		lines.append("\t"+self.test_model_name_line())
		lines.append("\t\ti3_{name}_parameter_set * P0 = (i3_{name}_parameter_set*) p;".format(name=self.name))
		lines.append("\t\ti3_{name}_parameter_set * Pprime = (i3_{name}_parameter_set*) pprime;".format(name=self.name))
		lines.append("\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*)malloc(model->nbytes);".format(name=self.name))
		lines.append("\t\ti3_flt L0 = i3_model_posterior(model,model_image,P0,data_set);")
		for param in self.param_names:
			if self.param_types[param]!=float:
				continue
			lines.append("\t\ti3_model_copy_parameters(model,P,P0);")
			lines.append("\t\tP->{param} += epsilon;".format(param=param))
			lines.append("\t\tPprime->{param} = (i3_model_posterior(model,model_image,P,data_set) - L0)/epsilon;".format(param=param))
		# No trailing re-append of the last line: it re-ran the last posterior
		# evaluation, or re-declared L0 when there were no float parameters.
		lines.append("\t\tfree(P);")
		lines.append("\t\t*like=L0;")
		lines.append("\t}")
		return '\n'.join(lines)

	def py_struct(self):
		"""Return source for a ctypes.Structure mirroring the parameter set.

		Returns just "\\n" when any parameter is a list (arrays unsupported).
		"""
		lines = []
		ctype_codes = {int:"ctypes.c_int", float:"c_flt", bool:"ctypes.c_bool"}
		lines.append("class {name}Params(ctypes.Structure):".format(name=self.name.title()))
		lines.append("\t_fields_ = [")
		for param in self.param_names:
			ptype=self.param_types[param]
			if ptype==list:
				return "\n"
			lines.append("\t\t(\"{param}\",{ctype}),".format(param=param, ctype=ctype_codes[ptype]))
		lines.append("\t]")
		lines.append("\tdef __str__(self):")
		lines.append("\t\treturn '\\n'.join(['% 20s = % .5g'%(name[0],getattr(self,name[0])) for name in  self._fields_])")
		return "\n".join(lines)

	def cpython(self):
		"""Return an (incomplete) C fragment casting 'params' to this model's parameter set.

		TODO(review): the original per-parameter loop meant to emit
		PyDict_SetItemString calls never appended anything; it was dead
		code and has been dropped.
		"""
		lines=[]
		lines.append("\t"+self.test_model_name_line())
		lines.append("\t\ti3_{name}_parameter_set * P = (i3_{name}_parameter_set*) params;".format(name=self.name))
		lines.append("\t}")
		return "\n".join(lines)



class ModelFile(object):
	"""Write the collected C sources and headers for a list of Model objects.

	Templates are read from *template_dir*. Every "#python <name>" marker in
	a template is replaced by the generated fragment passed as <name>.
	"""

	# A "#python <name>" substitution marker; group 1 is the field name.
	_python_marker = re.compile(r'#python (\S+)')

	def __init__(self,models,template_dir):
		"""Store a copy of *models* and the paths of all template files under *template_dir*."""
		self.models=models[:]
		self.code_template = os.path.join(template_dir,"i3_model.template.c")
		self.header_template = os.path.join(template_dir,"i3_model.template.h")
		self.definition_template = os.path.join(template_dir,"i3_definition.template.h")
		self.tester_template = os.path.join(template_dir,"i3_model_test.template.c")
		self.cpython_template = os.path.join(template_dir,"i3_python.template.c")
		self.python_template = os.path.join(template_dir,"im3shape.template.py")

	@staticmethod
	def _write(filename,text):
		"""Write *text* to *filename*, closing the file deterministically."""
		with open(filename,'w') as output_file:
			output_file.write(text)

	def load_template(self,filename):
		"""Read *filename* and return it as a str.format template.

		Literal braces are doubled so that C code survives str.format, then
		each "#python name" marker becomes a "{name}" field. A single re.sub
		pass is used so that a marker whose name is a prefix of another
		(e.g. "foo" and "foo_bar") cannot corrupt the longer marker.
		"""
		with open(filename) as template_file:
			template=template_file.read()
		template=template.replace('{','{{').replace('}','}}')
		return self._python_marker.sub(lambda match: '{' + match.group(1) + '}', template)

	def to_tester(self,filename):
		"""Write the model test program generated from the tester template to *filename*."""
		template = self.load_template(self.tester_template)
		tester_code = "\n\telse ".join(model.tester_code() for model in self.models)
		self._write(filename,template.format(test_model_images=tester_code))

	def to_files(self,c_filename, h_filename, dirname):
		"""Write the collected model C file and its header.

		c_filename: output path for code generated from i3_model.template.c.
		h_filename: output path for the header from i3_model.template.h.
		dirname: directory prefix used in the generated #include lines.
		"""
		# Per-model "if" branches joined with "else", so one branch runs.
		def chained(fragment):
			return "\n\telse ".join(fragment(model) for model in self.models)

		# Per-model fragments simply concatenated, one after another.
		def stacked(fragment):
			return "\n".join(fragment(model) for model in self.models)

		c_template = self.load_template(self.code_template)
		c_code = c_template.format(
			model_creation=chained(lambda model: model.creation_code()),
			setup_offsets=chained(lambda model: model.offset_code()),
			starts_code=chained(lambda model: model.extract_params_code('start')),
			widths_code=chained(lambda model: model.extract_params_code('width')),
			min_code=chained(lambda model: model.extract_params_code('min')),
			max_code=chained(lambda model: model.extract_params_code('max')),
			scale_code=chained(lambda model: model.scale_code()),
			range_code=chained(lambda model: model.range_code()),
			fixes_code=chained(lambda model: model.fixes_code()),
			nonzero_width_nparam_code=chained(lambda model: model.nonzero_width_nparam_code()),
			parameter_string=stacked(lambda model: model.parameter_string_code()),
			pretty_parameter_string=stacked(lambda model: model.parameter_pretty_string_code()),
			prior_code=stacked(lambda model: model.prior_code()),
			prior_violations=stacked(lambda model: model.violation_code()),
			extract_ellipticity=stacked(lambda model: model.ellipticity_code()),
			get_model_image_string=stacked(lambda model: model.get_model_image_code()),
			perturb_code=stacked(lambda model: model.perturb_code()),
			header_line_code=stacked(lambda model: model.header_line_code()),
			read_params_from_str_string=stacked(lambda model: model.read_params_from_str_code()),
			posterior_derivative_approx=stacked(lambda model: model.posterior_derivative_approx_code()),
			any_nan_code=stacked(lambda model: model.any_nan_code()),
			parameter_set_by_name_code=stacked(lambda model: model.parameter_set_by_name_code()),
		)
		self._write(c_filename,c_code)

		h_template = self.load_template(self.header_template)
		include_text = '\n'.join('#include "{dirname}/i3_{name}.h"'.format(dirname=dirname,name=model.name) for model in self.models)
		self._write(h_filename,h_template.format(includes=include_text))

	def to_headers(self,dirname):
		"""Write <dirname>/i3_<name>_definition.h for each model from the definition template."""
		template = self.load_template(self.definition_template)
		for model in self.models:
			output = template.format(
				header_guard='_H_I3_' + model.name.upper() + '_DEFINITION',
				parameter_set_name='i3_' + model.name + '_parameter_set',
				declarations='\n\t'.join(model.declarations()),
				nparam='i3_' + model.name + '_nparam ' + str(model.nparam),
				avail_macro="HAVE_{0}_MODEL".format(model.name.upper()),
			)
			self._write(os.path.join(dirname,'i3_'+model.name + '_definition.h'),output)
		
		
def main(cfile,hfile,testname,name_roots,dirname,template_dir):
	"""Generate the collected model sources for every model root in *name_roots*.

	Each root ``x`` is imported as the python module ``i3_x``. *cfile* and
	*hfile* receive the collected C code and header; per-model definition
	headers go to *dirname*. *testname* is accepted for the tester output,
	whose generation is currently disabled.
	"""
	module_names=['i3_{0}'.format(root) for root in name_roots]
	generator=ModelFile([Model(module_name) for module_name in module_names],template_dir)
	generator.to_files(cfile,hfile,dirname)
	generator.to_headers(dirname)
#	generator.to_tester(testname)

# Command-line interface: positional arguments are model root names, each
# imported by main() as the python module "i3_<root>".
parser=optparse.OptionParser(usage='Usage: %prog [options] model1 model2')
for _short_flag,_long_flag,_dest,_default,_help_text in (
	('-d','--dir','dirname','models','Set the directory to look for models [models/]'),
	('-c','--output','outfile','i3_model.c','Set the collected model C file to generate [i3_model.c]'),
	('-r','--header','header','i3_model.h','Set the collected model header file to generate [i3_model.h]'),
	('-t','--tester','tester','i3_model_test.c','Set the collected model test program file to generate [i3_model_test.c]'),
	('-T','--template_dir','template_dir','tools','Set the directory that contains the template files [tools]'),
):
	parser.add_option(_short_flag,_long_flag,dest=_dest,action='store',type='str',default=_default,help=_help_text)
del _short_flag,_long_flag,_dest,_default,_help_text


if __name__=="__main__":
	options,models=parser.parse_args()
	if not models:
		# At least one model name is required.
		parser.print_help(sys.stderr)
		sys.exit(1)
	main(options.outfile,options.header,options.tester,models,options.dirname,options.template_dir)