# $Id$ # # Copyright (C) 2002-2006 greg Landrum and Rational Discovery LLC # # @@ All Rights Reserved @@ # This file is part of the RDKit. # The contents are covered by the terms of the BSD license # which is included in the file license.txt, found at the root # of the RDKit source tree. # """Command line tool to construct an enrichment plot from saved composite models Usage: EnrichPlot [optional args] -d dbname -t tablename Required Arguments: -d "dbName": the name of the database for screening -t "tablename": provide the name of the table with the data to be screened : file name(s) of pickled composite model(s). If the -p argument is also provided (see below), this argument is ignored. Optional Arguments: - -a "list": the list of result codes to be considered active. This will be eval'ed, so be sure that it evaluates as a list or sequence of integers. For example, -a "[1,2]" will consider activity values 1 and 2 to be active - --enrich "list": identical to the -a argument above. - --thresh: sets a threshold for the plot. If the confidence falls below this value, picking will be terminated - -H: screen only the hold out set (works only if a version of BuildComposite more recent than 1.2.2 was used). - -T: screen only the training set (works only if a version of BuildComposite more recent than 1.2.2 was used). - -S: shuffle activity values before screening - -R: randomize activity values before screening - -F *filter frac*: filters the data before training to change the distribution of activity values in the training set. *filter frac* is the fraction of the training set that should have the target value. **See note in BuildComposite help about data filtering** - -v *filter value*: filters the data before training to change the distribution of activity values in the training set. *filter value* is the target value to use in filtering. **See note in BuildComposite help about data filtering** - -p "tableName": provides the name of a db table containing the models to be screened. If you use this argument, you should also use the -N argument (below) to specify a note value. - -N "note": provides a note to be used to pull models from a db table. - --plotFile "filename": writes the data to an output text file (filename.dat) and creates a gnuplot input file (filename.gnu) to plot it - --showPlot: causes the gnuplot plot constructed using --plotFile to be displayed in gnuplot. """ from __future__ import print_function from rdkit import RDConfig import numpy import copy from rdkit.six.moves import cPickle #from rdkit.Dbase.DbConnection import DbConnect from rdkit.ML.Data import DataUtils,SplitData,Stats from rdkit.Dbase.DbConnection import DbConnect from rdkit import DataStructs from rdkit.ML import CompositeRun import sys,os,types from rdkit.six import cmp __VERSION_STRING="2.4.0" def message(msg,noRet=0,dest=sys.stderr): """ emits messages to _sys.stderr_ override this in modules which import this one to redirect output **Arguments** - msg: the string to be displayed """ if noRet: dest.write('%s '%(msg)) else: dest.write('%s\n'%(msg)) def error(msg,dest=sys.stderr): """ emits messages to _sys.stderr_ override this in modules which import this one to redirect output **Arguments** - msg: the string to be displayed """ sys.stderr.write('ERROR: %s\n'%(msg)) def ScreenModel(mdl,descs,data,picking=[1],indices=[],errorEstimate=0): """ collects the results of screening an individual composite model that match a particular value **Arguments** - mdl: the composite model - descs: a list of descriptor names corresponding to the data set - data: the data set, a list of points to be screened. - picking: (Optional) a list of values that are to be collected. For examples, if you want an enrichment plot for picking the values 1 and 2, you'd having picking=[1,2]. **Returns** a list of 4-tuples containing: - the id of the point - the true result (from the data set) - the predicted result - the confidence value for the prediction """ mdl.SetInputOrder(descs) for j in range(len(mdl)): tmp = mdl.GetModel(j) if hasattr(tmp,'_trainIndices') and type(tmp._trainIndices)!=types.DictType: tis = {} if hasattr(tmp,'_trainIndices'): for v in tmp._trainIndices: tis[v]=1 tmp._trainIndices=tis res = [] if mdl.GetQuantBounds(): needsQuant = 1 else: needsQuant = 0 if not indices: indices = range(len(data)) nTrueActives=0 for i in indices: if errorEstimate: use=[] for j in range(len(mdl)): tmp = mdl.GetModel(j) if not tmp._trainIndices.get(i,0): use.append(j) else: use=None pt = data[i] pred,conf = mdl.ClassifyExample(pt,onlyModels=use) if needsQuant: pt = mdl.QuantizeActivity(pt[:]) trueRes = pt[-1] if trueRes in picking: nTrueActives+=1 if pred in picking: res.append((pt[0],trueRes,pred,conf)) return nTrueActives,res def AccumulateCounts(predictions,thresh=0,sortIt=1): """ Accumulates the data for the enrichment plot for a single model **Arguments** - predictions: a list of 3-tuples (as returned by _ScreenModels_) - thresh: a threshold for the confidence level. Anything below this threshold will not be considered - sortIt: toggles sorting on confidence levels **Returns** - a list of 3-tuples: - the id of the active picked here - num actives found so far - number of picks made so far """ if sortIt: predictions.sort(lambda x,y:cmp(y[3],x[3])) res = [] nCorrect = 0 nPts = 0 for i in range(len(predictions)): id,real,pred,conf = predictions[i] if conf > thresh: if pred == real: nCorrect += 1 nPts += 1 res.append((id,nCorrect,nPts)) return res def MakePlot(details,final,counts,pickVects,nModels,nTrueActs=-1): if not hasattr(details,'plotFile') or not details.plotFile: return dataFileName = '%s.dat'%(details.plotFile) outF = open(dataFileName,'w+') i = 0 while i < len(final) and counts[i] != 0: if nModels>1: mean,sd = Stats.MeanAndDev(pickVects[i]) confInterval = Stats.GetConfidenceInterval(sd,len(pickVects[i]),level=90) outF.write('%d %f %f %d %f\n'%(i+1,final[i][0]/counts[i], final[i][1]/counts[i],counts[i],confInterval)) else: outF.write('%d %f %f %d\n'%(i+1,final[i][0]/counts[i], final[i][1]/counts[i],counts[i])) i+=1 outF.close() plotFileName = '%s.gnu'%(details.plotFile) gnuF = open(plotFileName,'w+') gnuHdr="""# Generated by EnrichPlot.py version: %s set size square 0.7 set xr [0:] set data styl points set ylab 'Num Correct Picks' set xlab 'Num Picks' set grid set nokey set term postscript enh color solid "Helvetica" 16 set term X """%(__VERSION_STRING) print(gnuHdr, file=gnuF) if nTrueActs >0: print('set yr [0:%d]'%nTrueActs, file=gnuF) print('plot x with lines', file=gnuF) if nModels>1: everyGap = i/20 print('replot "%s" using 1:2 with lines,'%(dataFileName),end='', file=gnuF) print('"%s" every %d using 1:2:5 with yerrorbars'%(dataFileName, everyGap), file=gnuF) else: print('replot "%s" with points'%(dataFileName), file=gnuF) gnuF.close() if hasattr(details,'showPlot') and details.showPlot: try: import os from Gnuplot import Gnuplot p = Gnuplot() #p('cd "%s"'%(os.getcwd())) p('load "%s"'%(plotFileName)) raw_input('press return to continue...\n') except: import traceback traceback.print_exc() def Usage(): """ displays a usage message and exits """ sys.stderr.write(__doc__) sys.exit(-1) if __name__=='__main__': import getopt try: args,extras = getopt.getopt(sys.argv[1:],'d:t:a:N:p:cSTHF:v:', ('thresh=','plotFile=','showPlot', 'pickleCol=','OOB','noSort','pickBase=', 'doROC','rocThresh=','enrich=')) except: import traceback traceback.print_exc() Usage() details = CompositeRun.CompositeRun() CompositeRun.SetDefaults(details) details.activeTgt=[1] details.doTraining = 0 details.doHoldout = 0 details.dbTableName = '' details.plotFile = '' details.showPlot = 0 details.pickleCol = -1 details.errorEstimate=0 details.sortIt=1 details.pickBase = '' details.doROC=0 details.rocThresh=-1 for arg,val in args: if arg == '-d': details.dbName = val if arg == '-t': details.dbTableName = val elif arg == '-a' or arg == '--enrich': details.activeTgt = eval(val) if(type(details.activeTgt) not in (types.TupleType,types.ListType)): details.activeTgt = (details.activeTgt,) elif arg == '--thresh': details.threshold = float(val) elif arg == '-N': details.note = val elif arg == '-p': details.persistTblName = val elif arg == '-S': details.shuffleActivities = 1 elif arg == '-H': details.doTraining = 0 details.doHoldout = 1 elif arg == '-T': details.doTraining = 1 details.doHoldout = 0 elif arg == '-F': details.filterFrac=float(val) elif arg == '-v': details.filterVal=float(val) elif arg == '--plotFile': details.plotFile = val elif arg == '--showPlot': details.showPlot=1 elif arg == '--pickleCol': details.pickleCol=int(val)-1 elif arg == '--OOB': details.errorEstimate=1 elif arg == '--noSort': details.sortIt=0 elif arg == '--doROC': details.doROC=1 elif arg == '--rocThresh': details.rocThresh=int(val) elif arg == '--pickBase': details.pickBase=val if not details.dbName or not details.dbTableName: Usage() print('*******Please provide both the -d and -t arguments') message('Building Data set\n') dataSet = DataUtils.DBToData(details.dbName,details.dbTableName, user=RDConfig.defaultDBUser, password=RDConfig.defaultDBPassword, pickleCol=details.pickleCol, pickleClass=DataStructs.ExplicitBitVect) descs = dataSet.GetVarNames() nPts = dataSet.GetNPts() message('npts: %d\n'%(nPts)) final = numpy.zeros((nPts,2),numpy.float) counts = numpy.zeros(nPts,numpy.integer) selPts = [None]*nPts models = [] if details.persistTblName: conn = DbConnect(details.dbName,details.persistTblName) message('-> Retrieving models from database') curs = conn.GetCursor() curs.execute("select model from %s where note='%s'"%(details.persistTblName,details.note)) message('-> Reconstructing models') try: blob = curs.fetchone() except: blob = None while blob: message(' Building model %d'%len(models)) blob = blob[0] try: models.append(cPickle.loads(str(blob))) except: import traceback traceback.print_exc() print('Model failed') else: message(' <-Done') try: blob = curs.fetchone() except: blob = None curs = None else: for modelName in extras: try: model = cPickle.load(open(modelName,'rb')) except: import traceback print('problems with model %s:'%modelName) traceback.print_exc() else: models.append(model) nModels = len(models) pickVects = {} halfwayPts = [1e8]*len(models) for whichModel,model in enumerate(models): tmpD = dataSet try: seed = model._randomSeed except AttributeError: pass else: DataUtils.InitRandomNumbers(seed) if details.shuffleActivities: DataUtils.RandomizeActivities(tmpD, shuffle=1) if hasattr(model,'_splitFrac') and (details.doHoldout or details.doTraining): trainIdx,testIdx = SplitData.SplitIndices(tmpD.GetNPts(),model._splitFrac, silent=1) if details.filterFrac != 0.0: trainFilt,temp = DataUtils.FilterData(tmpD,details.filterVal, details.filterFrac,-1, indicesToUse=trainIdx, indicesOnly=1) testIdx += temp trainIdx = trainFilt if details.doTraining: testIdx,trainIdx = trainIdx,testIdx else: testIdx = range(tmpD.GetNPts()) message('screening %d examples'%(len(testIdx))) nTrueActives,screenRes = ScreenModel(model,descs,tmpD,picking=details.activeTgt, indices=testIdx, errorEstimate=details.errorEstimate) message('accumulating') runningCounts = AccumulateCounts(screenRes, sortIt=details.sortIt, thresh=details.threshold) if details.pickBase: pickFile = open('%s.%d.picks'%(details.pickBase,whichModel+1),'w+') else: pickFile = None for i,entry in enumerate(runningCounts): entry = runningCounts[i] selPts[i] = entry[0] final[i][0] += entry[1] final[i][1] += entry[2] v = pickVects.get(i,[]) v.append(entry[1]) pickVects[i] = v counts[i] += 1 if pickFile: pickFile.write('%s\n'%(entry[0])) if entry[1] >= nTrueActives/2 and entry[2]1: print('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection') else: print('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection') i = 0 while i < nPts and counts[i] != 0: if nModels>1: mean,sd = Stats.MeanAndDev(pickVects[i]) confInterval = Stats.GetConfidenceInterval(sd,len(pickVects[i]),level=90) print('%d\t%f\t%f\t%f\t%d\t%s'%(i+1,final[i][0]/counts[i],confInterval, final[i][1]/counts[i], counts[i],str(selPts[i]))) else: print('%d\t%f\t%f\t%d\t%s'%(i+1,final[i][0]/counts[i], final[i][1]/counts[i], counts[i],str(selPts[i]))) i += 1 mean,sd = Stats.MeanAndDev(halfwayPts) print('Halfway point: %.2f(%.2f)'%(mean,sd))