import scipy.interpolate as interpol
import numpy as np

# Cumulative-count fractions at which the quantile energies are evaluated
# (written to the E25 / E50 / E75 columns).
frac = [0.25, 0.5, 0.75]

# Model for the simulated dataset "faked": component A (xsphabs) times
# component B (xspowerlaw).
set_source("faked", "xsphabs.A*xspowerlaw.B")

# Grid nodes: par1 is assigned to A.nH and par2 to B.PhoIndex further down.
par1s = [0.01, 0.1, 0.4, 1, 4, 10]
par2s = [0, 1, 2, 3, 4]

# The power-law normalisation stays fixed for every simulation.
B.norm = 1.0

# Densely sampled parameter values used to trace each grid line.  Every
# interval between neighbouring nodes is walked with a step of
# len(nodes) * (hi - lo) / 200, i.e. roughly 200 / len(nodes) samples per
# interval.  np.arange excludes its stop value, so each interval contributes
# its lower node but not its upper one, and the last node of the list is
# never part of the dense grid.
npar1s = len(par1s)
cpar1s = [
	par
	for lo, hi in zip(par1s[:-1], par1s[1:])
	for par in np.arange(lo, hi, 1.*npar1s*(hi-lo)/200)
]

npar2s = len(par2s)
cpar2s = [
	par
	for lo, hi in zip(par2s[:-1], par2s[1:])
	for par in np.arange(lo, hi, 1.*npar2s*(hi-lo)/200)
]

# Two tab-separated header lines on stdout: the column names, then a row of
# per-column codes (NOTE(review): presumably type descriptors for a
# tab-separated table format -- confirm against whatever reads this output).
# The parenthesised single-argument form prints the same text under Python 2
# and also parses under Python 3.
print("\t".join(("gridID", "par1", "par2", "E25", "E50", "E75")))
print("\t".join(("S", "NF", "NF", "N", "N", "N")))


# Instrument response files handed to every fake_pha() call below.
arfdata = unpack_arf("example.arf")
rmfdata = unpack_rmf("example.rmf")



import pyfits as pf
import modfits as mf
import struct as st

# Build the output FITS file: a primary HDU plus one binary table whose
# columns mirror the printed header (gridID, par1, par2, E25, E50, E75).
# Each column is seeded with a single placeholder element.

# (name, FITS column format, placeholder value, unit or None)
_column_specs = [
	('gridID', '7A', ' ', None),
	('par1',   'E',  0.0, None),
	('par2',   'E',  0.0, None),
	('E25',    'E',  0.0, 'keV'),
	('E50',    'E',  0.0, 'keV'),
	('E75',    'E',  0.0, 'keV'),
]

_columns = []
for _name, _format, _placeholder, _unit in _column_specs:
	# A fresh array per column; `unit` is only passed where one was given.
	_column_kwargs = dict(name=_name, format=_format,
	                      array=np.array([_placeholder]))
	if _unit is not None:
		_column_kwargs['unit'] = _unit
	_columns.append(pf.Column(**_column_kwargs))

# Keep the individual per-column names bound, as before.
col1, col2, col3, col4, col5, col6 = _columns

cols = pf.ColDefs(_columns)
tbhdu = pf.new_table(cols)
hdu = pf.PrimaryHDU(np.arange(1))
thdulist = pf.HDUList([hdu, tbhdu])

# Provenance cards for the table extension: (keyword, value, comment),
# written in this order.
_header_cards = [
	('source', "./grid.py", "creator"),
	('frac',   "0.25,0.5,0.75", "quantile fractions"),
	('model',  "xsphabs.A*xspowerlaw.B", "spectral models"),
	('par1',   "A.nH:0.01,0.1,0.4,1,4,10:200", "par1"),
	('par2',   "B.PhoIndex:0,1,2,3,4:200", "par2"),
	('etcp',   "B.norm:1.0", "etc par"),
	('arf',    "example.arf", "arf"),
	('rmf',    "example.rmf", "rmf"),
	('range',  "0.3,8.0", "energy range"),
]
_table_header = thdulist[1].header
for _keyword, _value, _comment in _header_cards:
	_table_header.update(_keyword, _value, _comment)

# Overwrite any existing output file.
thdulist.writeto('output.qgrid.fits', clobber=True)

# One table row: a 7-character grid label followed by five big-endian
# 4-byte floats (par1, par2 and the three quantile values), matching the
# '7A' + five 'E' columns declared for the table.  Compiled once instead of
# re-parsing the format string for every row.
_ROW_FLOATS = st.Struct('>fffff')


def _quantile_energies(nh, pho_index, emin=0.3, emax=8.0):
	"""Simulate one spectrum and return its quantile x-values.

	Sets ``A.nH`` and ``B.PhoIndex``, fakes the "faked" dataset with the
	loaded ARF/RMF (exposure 1e8, ungrouped), keeps the bins whose x value
	lies in ``[emin, emax]``, and interpolates x as a function of the
	normalised cumulative counts at the fractions listed in ``frac``.

	NOTE(review): x is whatever ``get_x()`` returns for the dataset; the
	0.3-8.0 cut and the keV column units suggest energy in keV -- confirm
	the session's analysis setting.

	Args:
		nh: value assigned to ``A.nH``.
		pho_index: value assigned to ``B.PhoIndex``.
		emin: lower bound (inclusive) of the x range used.
		emax: upper bound (inclusive) of the x range used.

	Returns:
		Array with one interpolated x value per entry of ``frac``.

	Raises:
		ValueError: from ``interp1d`` (default bounds checking) when a
			requested fraction lies outside the range of the cumulative
			curve, e.g. below its first value.
	"""
	A.nH = nh
	B.PhoIndex = pho_index
	fake_pha("faked", arf=arfdata, rmf=rmfdata, exposure=1.e8, grouped=False)
	data = get_data("faked")
	x = data.get_x()
	y = data.counts

	in_band = (x >= emin) & (x <= emax)
	x = x[in_band]
	y = y[in_band]

	cumulative = np.add.accumulate(y)
	# Divide by a Python float so the normalisation is a true division even
	# if the counts array is integer-typed (a plain `/` would floor-divide
	# under Python 2).  ndarray.max() replaces the builtin max(), which
	# walks the array element by element.
	cumulative = cumulative / float(cumulative.max())
	return interpol.interp1d(cumulative, x)(frac)


def _append_grid_row(table, grid_id, par1, par2):
	"""Compute the quantiles for (par1, par2) and append one packed row."""
	qtil = _quantile_energies(par1, par2)
	table.append(grid_id + _ROW_FLOATS.pack(par1, par2, qtil[0], qtil[1], qtil[2]))


fits = mf.modfits('output.qgrid.fits')
try:
	fits.rewind(1)

	# "gx" lines: par1 (A.nH) held at each node while par2 (B.PhoIndex)
	# sweeps the dense grid.
	for idx, par1 in enumerate(par1s):
		for par2 in cpar2s:
			_append_grid_row(fits, "gx%-5d" % idx, par1, par2)

	# "gy" lines: par2 held at each node while par1 sweeps the dense grid.
	for idx, par2 in enumerate(par2s):
		for par1 in cpar1s:
			_append_grid_row(fits, "gy%-5d" % idx, par1, par2)
finally:
	# Always release the output file, even if a simulation or the
	# interpolation raises part-way through; the exception still propagates.
	fits.close()