#!/usr/bin/python3
"""Generate hits using a pre-trained CDC-GAN.

Reads hyperparameters back out of the training job's output.log, rebuilds the
generator/discriminator networks, and (further down in this script) samples
hits from the generator into a CSV or HDF5 file.
"""
import argparse
import os
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

parser = argparse.ArgumentParser(description='Generate hits using a pre-trained CDC-GAN',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model', type=str,
        help='Path to the directory containing the saved model (.pt) and output.log file')
parser.add_argument('--job-id', type=int,
        help='Alternative to --model. ID of the job from which to generate (generate.py will look for a directory called "output_$(JOB_ID)")')
parser.add_argument('--n-hits', type=int, help='Number of hits to generate', required=True)
parser.add_argument('--epoch', type=int, help='Epoch at which to evaluate', required=False)
parser.add_argument('--batch-size', type=int,
        help=('Number of sequences to generate at a time.'
            ' Increase this for speed, lower it if the memory usage is too high.'),
        default=8)
parser.add_argument('--output', type=str, help='Name of the output file', default=None)
parser.add_argument('--format', type=str, help='Output data format',
        choices=['csv', 'hdf'], default='csv')
parser.add_argument('--seed', type=int, default=13371337, help='RNG seed')
args = parser.parse_args()

# Resolve the model directory from either --job-id or --model.
if args.job_id is not None:
    job_dir = ('output_%d/' % args.job_id).rstrip('/')
elif args.model is not None:
    job_dir = (args.model + '/').rstrip('/')
else:
    parser.print_help()
    print('\nError: neither --job-id nor --model was specified. Please provide at least one.')
    exit(1)

# Retrieve parameters from training log file
with open(job_dir + '/output.log', 'rt') as log_file:
    contents = log_file.read()

def read_from_log(var_name):
    """Return the value of the last 'var_name=value' entry in the training log.

    Prints an error and exits if the variable does not appear in the log.
    """
    _s = contents.rfind('%s=' % (var_name))
    # BUG FIX: test rfind's result *before* advancing past the '='.  The
    # original tested the already-advanced offset, which can never be -1,
    # so a missing variable silently produced garbage instead of an error.
    if _s == -1:
        print("Couldn't find '%s' in log" % var_name)
        exit(1)
    _s = contents.find('=', _s) + 1
    _e = contents.find('\n', _s)
    var = contents[_s:_e]
    return var

# Network and data hyperparameters recorded by the training run.
ngf = int(read_from_log('ngf'))
ndf = int(read_from_log('ndf'))
seq_len = int(read_from_log('sequence_length'))
latent_dims = int(read_from_log('latent_dims'))

# File extension matching the requested output format.
format_ext = ''
if args.format == 'csv':
    format_ext = 'csv'
elif args.format == 'hdf':
    format_ext = 'h5'

# Determine epoch from checkpoint files if not provided in args
import glob

def find_last_epoch():
    """Return the highest epoch among saved 'states_<epoch>.pt' files (0 if none)."""
    last_epoch = 0
    for save_file in glob.glob(job_dir + '/states_*.pt'):
        idx = int(save_file.split('/')[-1].split('_')[1].split('.')[0])
        if idx > last_epoch:
            last_epoch = idx
    return last_epoch

if args.epoch is None:
    args.epoch = find_last_epoch()

print('Generating from model in %s at epoch %d' % (job_dir, args.epoch))
print(' ngf:', ngf)
print(' ndf:', ndf)
print(' sequence length:', seq_len)
print(' latent dims:', latent_dims)

# Seed every RNG we use so generation is reproducible.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

device = torch.device('cpu')
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)
    print('Running on GPU: %s' % (torch.cuda.get_device_name()))
    device = torch.device('cuda')
else:
    print('Running on CPU')

import dataset
data = dataset.Data()

import geom_util
gu = geom_util.GeomUtil(data.get_cdc_tree())

# Initialize networks
import networks
gen = networks.Gen(ngf=ngf, latent_dims=latent_dims, seq_len=seq_len,
        n_wires=gu.n_wires).to(device)
disc = networks.Disc(ndf=ndf, seq_len=seq_len, n_wires=gu.n_wires).to(device)
print(' Generator params: {:,}'.format(networks.get_n_params(gen)))
print(' Discriminator params: {:,}'.format(networks.get_n_params(disc)))

# Training-history containers, overwritten by load_states() below.
# BUG FIX: 'disciminator_losses' was misspelled; load_states() assigns the
# correctly-spelled global, leaving the typo'd list orphaned.
discriminator_losses = []
generator_losses = []
gradient_penalty = []
validation_losses = []
n_epochs = 0

# Load network states
def load_states(path):
    """Load generator/discriminator weights, preprocessing transforms and
    training history from a checkpoint file into the module-level globals."""
    print('Loading GAN states from %s...' % (path))
    states = torch.load(path, map_location=device)
    disc.load_state_dict(states['disc'])
    global discriminator_losses
    discriminator_losses = states['d_loss']
    gen.load_state_dict(states['gen'])
    global generator_losses
    generator_losses = states['g_loss']
    global n_epochs
    n_epochs = states['n_epochs']
    global data
    # The fitted preprocessing transforms are needed to invert the
    # generator's output back to physical units.
    data.qt = states['qt']
    data.minmax = states['minmax']
    global gradient_penalty
    if 'gradient_penalty' in states:
        gradient_penalty = states['gradient_penalty']
    global validation_losses
    if 'validation_loss' in states:
        validation_losses = states['validation_loss']
    print('OK')

load_states('%s/states_%d.pt' % (job_dir, args.epoch))

# Load in the training data for comparisons
print("Loading dataset")
data.load()
print("OK")

# Inference mode: fixes batch-norm/dropout behaviour.
gen.eval()

def sample_fake(batch_size):
    """Sample `batch_size` latent vectors and run them through the generator.

    Returns the generator's (features, wire scores) pair; shapes are
    whatever networks.Gen produces — presumably (batch, features, seq) and
    (batch, n_wires, seq) given the permute/argmax below (TODO confirm).
    """
    noise = torch.randn((batch_size, latent_dims), device=device)
    p, dec_w = gen(noise)
    return p, dec_w

n_hits = args.n_hits

# Whether to print the column header to the output file
print_header = True

# Default name for output files
if args.output is None:
    job_subdir = job_dir.split('/')[-1]
    print('JOB SUBDIR %s' % job_subdir)
    args.output = 'gan_%s_e%d_ndf%d_ld%d' % (job_subdir, args.epoch, ndf, latent_dims)

output_path = args.output + '.' + format_ext
print('Saving output to "%s"' % output_path)
if os.path.exists(output_path):
    print("WARNING: output file '%s' already exists. Appending generated data." % (output_path))
    print_header = False

def _write_hits(inv_p, w):
    """Append one batch of inverse-preprocessed hits to the output file."""
    global print_header
    df = pd.DataFrame(
            data={'edep': inv_p[:, 0], 't': inv_p[:, 1], 'doca': inv_p[:, 2], 'wire': w})
    if args.format == 'csv':
        df.to_csv(output_path, mode='a', index=False, header=print_header)
    elif args.format == 'hdf':
        df.to_hdf(output_path, key='hits', mode='a', append=True, complevel=3)
    print_header = False

print('Starting hit generation...')
# BUG FIX: the original batch/remainder arithmetic broke when the number of
# sequences was an exact multiple of the batch size (it sampled a batch of 0
# and sliced with a negative length, writing the wrong number of hits).
# Instead, generate full batches and truncate the last one so that exactly
# n_hits rows are written.  no_grad() avoids building an autograd graph
# during pure inference.
hits_left = n_hits
with torch.no_grad(), tqdm(total=n_hits) as pbar:
    while hits_left > 0:
        p, w = sample_fake(args.batch_size)
        # Flatten (batch, seq) into one long sequence of hits.
        inv_p = data.inv_preprocess(p.permute(0, 2, 1).flatten(0, 1)).cpu()
        w = torch.argmax(w, dim=1).flatten(0, 1).cpu()
        take = min(hits_left, inv_p.shape[0])
        _write_hits(inv_p[:take], w[:take])
        hits_left -= take
        pbar.update(take)

print("Hits generated with model in %s at epoch %d saved to '%s'." % (job_dir, args.epoch, output_path))