import os

import numpy as np

# Input snapshots: each file is datadir + datafilebase + zero-padded
# snapshot number + '.txt'.
datadir = '~/data/sphtxt/fidnf/'
datafilebase = 'outputs_'

# Scale factor applied to every position so the scene is larger on Sketchfab.
factor = 10.0

# RGB colour per particle type, indexed by the type number.
colors = [
    (255, 255, 0),    # 0 = Gas:   yellow
    (255, 0, 0),      # 1 = Halo:  red
    (51, 51, 255),    # 2 = Disk:  blue
    (51, 51, 255),    # 3 = Bulge: blue
    (255, 255, 255),  # 4 = Stars: white
    (0, 255, 0),      # 5 = BHs:   green
]

# Snapshot numbers to convert; the loop uses range(start, end), so end is exclusive.
nsnapshots_start = 0
nsnapshots_end = 199

# Time between consecutive snapshots (written to the Sketchfab timeframe file).
dt_snapshot = 0.1

# Decimation factor: approximate fraction of particles kept for decimated
# types (0.1 keeps ~10%, 0.01 keeps ~1%).
dfac = 0.01

# Output directory for the .ply frames and the timeframe file; ensure it exists.
outdir = '~/outputs_fixer/'

#------------------------------------------------------------------


# PLY exporter modified from code by Tomer Nussbaum; for questions contact either one of us (Tomer: tussbaum@gmail.com)
def write_particles_ply_dec(dst_filename, particles_x, particles_y, particles_z, color_vec, dec, nondec=None):
    """Write particles to an ASCII PLY point cloud, optionally decimating some types.

    Parameters
    ----------
    dst_filename : str
        Output path without extension; '.ply' is appended.
    particles_x, particles_y, particles_z : sequence of array-like
        One coordinate array per particle type; entry ``i`` holds the
        coordinates of all particles of type ``i``.
    color_vec : sequence of (r, g, b)
        One RGB triple per particle type. Only the first ``len(color_vec)``
        types are written.
    dec : float
        For decimated types, each particle is kept when a uniform draw in
        [0, 1) is ``<= dec``, i.e. roughly a fraction ``dec`` survives.
    nondec : sequence of int, optional
        Particle types that are never decimated. Decimation is applied only
        when this is non-empty; with the default every particle is written.
    """
    # None instead of a mutable [] default, which would be shared between calls.
    if nondec is None:
        nondec = []
    decimating = len(nondec) > 0

    # First pass: choose the surviving particles of every type so the header
    # can state the exact number of vertex lines that follow (PLY requires it).
    selected = []
    for i, color in enumerate(color_vec):
        px = np.asarray(particles_x[i])
        py = np.asarray(particles_y[i])
        pz = np.asarray(particles_z[i])
        # `i not in nondec` works for lists and arrays alike; the old
        # `np.where(i == nondec)` compared an int to a list (always False).
        if decimating and i not in nondec:
            keep = np.random.random(len(px)) <= dec
            px, py, pz = px[keep], py[keep], pz[keep]
        rgb = '%s %s %s' % (color[0], color[1], color[2])
        selected.append((px, py, pz, rgb))
    nv = sum(len(px) for px, _, _, _ in selected)

    header = ('ply\nformat ascii 1.0\ncomment Created by vertex2ply pyScript 21122015\n' +
              'element vertex ' + str(nv) + '\n' +
              'property float x\n' +
              'property float y\n' +
              'property float z\n' +
              'property uchar red\n' +
              'property uchar green\n' +
              'property uchar blue\n' +
              'element face 0\n' +
              'property list uchar uint vertex_indices\n' +
              'end_header\n')

    # `with` guarantees the file is closed even if a write fails.
    with open(dst_filename + '.ply', 'w') as f_par:
        f_par.write(header)
        for px, py, pz, rgb in selected:
            for x, y, z in zip(px, py, pz):
                f_par.write('%.6g %.6g %.6g %s\n' % (x, y, z, rgb))


#------------------------------------------------------------------

def CenterOfMass(x, y, z, m):
    """Return the mass-weighted mean position as a length-3 float64 array.

    x, y, z and m are numpy arrays of equal length (one entry per particle).
    Each component is sum(coord * m) / sum(m).
    """
    total_mass = m.sum()
    weighted = [(axis * m).sum() / total_mass for axis in (x, y, z)]
    return np.array(weighted, dtype=np.float64)


# Expand '~' once: neither open() nor np.genfromtxt expands it on its own,
# so '~/...' paths would otherwise be looked up relative to the CWD.
data_path = os.path.expanduser(datadir)
out_path = os.path.expanduser(outdir)
# Create the output directory instead of requiring it to exist beforehand.
os.makedirs(out_path, exist_ok=True)

# The timeframe file lists, per frame, the frame duration and its .ply name.
with open(out_path + 'sketchfab.timeframe', 'w') as f_tfile:
    # NOTE: range() excludes nsnapshots_end, so that snapshot is not converted.
    for ii in range(nsnapshots_start, nsnapshots_end):

        datafile = data_path + datafilebase + str(ii).zfill(3) + '.txt'
        data = np.genfromtxt(datafile, delimiter=',')

        # Columns 1-3 are used as x, y, z positions; column 4 as the particle type.
        positions = data[:, 1:4] * factor

        # Centre on the first snapshot only, so later frames don't jump around.
        # Unit masses are used: real particle masses are not read here.
        if ii == nsnapshots_start:
            com = CenterOfMass(positions[:, 0], positions[:, 1], positions[:, 2], np.ones(len(data)))

        shifted = positions - com
        xs = []
        ys = []
        zs = []
        for i in range(len(colors)):
            mask = data[:, 4] == i  # particles of type i
            xs.append(shifted[mask, 0])
            ys.append(shifted[mask, 1])
            zs.append(shifted[mask, 2])

        # name of output ply file, will be NAME.ply
        fname = 'outply_complex' + str(ii).zfill(3)
        print('Writing ' + out_path + fname + '.ply')
        write_particles_ply_dec(out_path + fname, xs, ys, zs, colors, dfac, nondec=[4])

        # timestamp file
        f_tfile.write(str(dt_snapshot) + ' ' + fname + '.ply \n')
    
