from py3shape.output2 import MPIDatabaseOutput
from py3shape.analyze_meds2 import I3MedsAnalysis
import psycopg2
import argparse
import mpi4py.MPI
import numpy as np
import meds
import multiprocessing
import re

description = "MPI launch of analyze_meds2."

parser = argparse.ArgumentParser(description=description, add_help=True)
parser.add_argument('meds_list', type=str, help='the input meds FITS file and corresponding catalogs')
parser.add_argument('cat', type=str, help='The DB table with the precut catalog in it')
parser.add_argument('ini', type=str, help='Im3shape ini file')
parser.add_argument('output', type=str, help='Database table name')
parser.add_argument('--great', action='store_true', help='GREAT-DES tile names')
parser.add_argument('--chunk', type=int, default=50, help='Size of job chunks')
parser.add_argument('--limit', type=int, default=0, help='Maximum number of objects to try')
parser.add_argument('--profile', type=str, default="", help='Filename root for profiling')
parser.add_argument('--debug-mpi', action='store_true', help='Print out lots of MPI messages')
# Help text fixed: the previous text ('Number of entries to run') was a
# copy-paste error from another option; the flag controls error handling.
parser.add_argument('--fatal-errors', action='store_true', help='Treat per-object analysis errors as fatal')

# Sentinel values exchanged between the master and slave processes.
END_OF_JOBS = "END_OF_JOBS_SENTINEL"
END_OF_TASK = ("END_OF_TASK", "END_OF_TASK")

# NOTE(review): '(.)'+'+' captures only the *last* character before '.meds'.
# For names like 'nbc2.meds.000.g00.fits' that single char is '2', which
# matches the examples below, but confirm '(.+)' was not intended.
great_des_pattern = re.compile(r"nbc(.)+\.meds\.([0-9][0-9][0-9])\.g([0-9][0-9])\.fits")
tile_pattern = re.compile(r'DES[0-9][0-9][0-9][0-9][+-][0-9][0-9][0-9][0-9]')
tile_band_pattern = re.compile(r'DES[0-9][0-9][0-9][0-9][+-][0-9][0-9][0-9][0-9][_-][ugrizy]')


def find_tile_band_from_filename(filename):
    """Return the 'DESxxxx+yyyy_b' tile+band substring of filename.

    Raises AttributeError if the pattern is not found in the name.
    """
    return tile_band_pattern.search(filename).group()


def find_tile_and_band_from_filename(filename, great=False):
    """Split a meds filename into a (tile, band) pair.

    For GREAT-DES files (great=True) the band is always 'r' and the tile
    is built from the regex groups, e.g. 'GD2_000_00' for
    'nbc2.meds.000.g00.fits'.  Otherwise the DES tile name and band letter
    are extracted directly from the filename.
    """
    if great:
        # set band to r
        # pretend the nbc2.meds.000.g00.fits
        band = 'r'
        m = great_des_pattern.search(filename)
        tile = 'GD' + ('_'.join(m.groups()))  # e.g GD2_000_00 for nbc2.meds.000.g00.fits
    else:
        tile_band = find_tile_band_from_filename(filename)
        tile = tile_band[:12]   # 'DESxxxx+yyyy'
        band = tile_band[13]    # single band letter after the separator
    return tile, band
select_objects(meds_file, cat_name, output, great): #We want to remove any objects from the list that have #already been analyzed in the given table. m = meds.MEDS(meds_file) tile,band = find_tile_and_band_from_filename(meds_file, great) #Get the catalog of objects to run on in this tile #from the selected catalog cursor = output.connection.cursor() sql = "SELECT iobj from {cat_name} where tile='{tile}'".format(cat_name=cat_name, tile=tile) cursor.execute(sql) iobjs = np.array([x[0] for x in cursor.fetchall()]) cursor.close() coadd_objects_id = m['id'][iobjs].tolist() #check if these values are in the DB already. cursor = output.connection.cursor() #Thanks, user "fog" from StackOverflow: sql = """SELECT A.x = ANY(SELECT identifier FROM {table_name}) FROM (SELECT * FROM unnest(%s) x) A""".format(table_name=output.main_name) try: cursor.execute(sql, (coadd_objects_id,)) except psycopg2.ProgrammingError as error: if 'relation' in error.message and 'does not exist' in error.message: #new catalog output.connection.rollback() cursor.close() return iobjs else: raise is_new = np.array([not x[0] for x in cursor.fetchall()]) cursor.close() output.connection.commit() #Return the list of objects not analyzed already return iobjs[is_new] def make_tasks(args, output): tasks = [] for line in open(args.meds_list): #ignore empty or commented lines line=line.strip() if (not line) or line.startswith("#"): continue meds_file=line #Hijack the connection to the DB to check for objects already in the DB iobjs = select_objects(meds_file, args.cat, output, args.great) print "Found %d objects in %s" % (len(iobjs), meds_file) if args.limit: iobjs=iobjs[:args.limit] nobj = len(iobjs) if nobj==0: continue nchunk = nobj//args.chunk if nchunk==0: nchunk=1 chunks = np.array_split(iobjs, nchunk) for chunk in chunks: task = (meds_file, chunk) tasks.append(task) return tasks def slave(comm, args): analysis = I3MedsAnalysis(args) #slave output just forwards to master output = MPIDatabaseOutput(comm, 
None) args.output = output rank = comm.Get_rank() #main program loop while True: #get job from master if args.debug_mpi: print "Proc %d waiting for job"%rank task = comm.recv(source=0) #check if this is the end of jobs sentinel if task==END_OF_JOBS: break #otherwise, we have been sent the name of the #meds file and a lsit of objects to run un meds_file, iobjs = task analysis.main(meds_file, iobjs, output, args.fatal_errors) if args.debug_mpi: print "Proc %d task complete." % rank comm.send(END_OF_TASK, dest=0) def master(comm, args): #Create the output object w output = MPIDatabaseOutput(comm, args.output) size = comm.Get_size() #Make a list of the jobs to do tasks = make_tasks(args, output) n_tasks = len(tasks) completed_tasks = 0 launched_tasks = 0 print "Have %d total tasks for %d procs" % (n_tasks, size-1) #Send out first batch of tasks, one per proc initial_batch_size = min(n_tasks, size-1) for i in xrange(initial_batch_size): task = tasks[i] comm.send(task, dest=i+1) print "Task %d/%d for proc %d: %s" % (i+1, n_tasks, i+1, task) launched_tasks+=1 #In small jobs not every process will get a task at all #particularly if we are just finishing up. #in that case end the unwanted jobs immediately if initial_batch_size