# xboa
# HitTest.py
# (Doxygen source listing: "Go to the documentation of this file.")
1 import sys
2 import xboa.core.Hitcore
3 from xboa import *
4 from xboa.core import *
5 from xboa.Hit import Hit
6 from xboa.Bunch import Bunch
7 from xboa.Common import rg
8 
# Short aliases for the shared test helpers in xboa.test.TestTools.
run_test = xboa.test.TestTools.run_test
run_test_group = xboa.test.TestTools.run_test_group
# NOTE: __float_tol is reassigned from Common.float_tolerance further down
# this module; that later value is the one the test functions see.
__float_tol = xboa.test.TestTools.__float_tol
parse_tests = xboa.test.TestTools.parse_tests
__test_root_hist = xboa.test.TestTools.test_root_hist
__test_root_canvas = xboa.test.TestTools.test_root_canvas
__test_root_graph = xboa.test.TestTools.test_root_graph
17 
18 import os
19 import StringIO
20 import sys
21 import time
22 import math
23 import string
24 if Common.has_numpy():
25  import numpy
26  from numpy import linalg
27 import operator
28 import bisect
29 
# Test script conventions:
#   - One test function for each application function and one test function
#     for each module; each function test is named <function>_test(...).
#   - Helper functions are private.
#   - Each test returns 'fail', 'warning' or 'pass'.
#   - The module test function is called test_<module>, so to run the Hit
#     tests call test_hit().

# Absolute tolerance used by the floating-point comparisons in this module.
__float_tol = Common.float_tolerance
40 
41 import copy
42 
def hitcore_test():
    """Exercise the low-level Hitcore object.

    Checks reference counting on assignment/deletion, the class-level
    integrity test, the exact variable lists returned by get_variables() and
    set_variables(), a set/get round trip on 'z', and that an unknown
    variable name is rejected by both set() and get().

    Returns 'pass' or 'fail'; on failure the traceback is printed through
    sys.excepthook.
    """
    try:
        a = Hitcore.Hitcore()
        b = a
        # getrefcount counts a, b and its own argument
        assert sys.getrefcount(a) == 3
        del b
        assert sys.getrefcount(a) == 2
        assert Hitcore.Hitcore.integrity_test() is None
        assert Hitcore.Hitcore.get_variables() == ['x', 'y', 'z', 't', 'px', 'py',
            'pz', 'energy', 'mass', 'local_weight', 'global_weight', 'weight',
            'bx', 'by', 'bz', 'ex', 'ey', 'ez', 'sx', 'sy', 'sz', 'path_length',
            'proper_time', 'e_dep', 'charge', 'event_number', 'station', 'pid',
            'status', 'spill', 'particle_number', 'eventNumber',
            'particleNumber']
        assert Hitcore.Hitcore.set_variables() == ['x', 'y', 'z', 't', 'px', 'py',
            'pz', 'energy', 'mass', 'local_weight', 'global_weight', 'bx', 'by',
            'bz', 'ex', 'ey', 'ez', 'sx', 'sy', 'sz', 'path_length',
            'proper_time', 'e_dep', 'charge', 'event_number', 'station', 'pid',
            'status', 'spill', 'particle_number', 'eventNumber',
            'particleNumber']

        z_target = 112.0
        a.set('z', z_target)
        # try/except/else: the "no exception" error is raised from the else
        # clause so the except clause cannot swallow it (the previous version
        # raised inside the try, making these checks impossible to fail).
        try:
            a.set('bilbo', 114.0)
        except Exception:
            pass
        else:
            raise KeyError('Did not get expected exception for bad set variable')
        if abs(a.get('z') - z_target) > 1e-9:
            raise KeyError('get or set failed')
        try:
            a.get('bilbo')
        except Exception:
            pass
        else:
            raise KeyError('Did not get expected exception for bad get variable')
        return 'pass'
    except Exception:
        sys.excepthook(*sys.exc_info())
        return 'fail'
81 
def hit_global_weight_test():
    """Check storage of global weights in the Hitcore weight dictionary.

    A hit with spill=1, event_number=2, particle_number=3 must not create a
    dictionary entry by reading its global weight (which defaults to 1.);
    setting it to 0.5 must create exactly one entry keyed (1, 2, 3), and
    Hitcore.clear_global_weights() must empty the dictionary again.
    """
    core = Hitcore.Hitcore
    hit = Hit.new_from_dict({"spill": 1, "event_number": 2, "particle_number": 3})
    try:
        if len(core.get_weight_dict()) != 0:
            return 'fail'
        if abs(hit.get("global_weight") - 1.) > 1e-9:
            return 'fail'
        # reading the default weight must not insert an entry
        if len(core.get_weight_dict()) != 0:
            return 'fail'
        hit.set('global_weight', 0.5)
        if abs(hit.get_global_weight() - 0.5) > 1e-9:
            return 'fail'
        if len(core.get_weight_dict()) != 1:
            return 'fail'
        # list() so the comparison also works where keys() is a view
        if list(core.get_weight_dict().keys()) != [(1, 2, 3)]:
            return 'fail'
        core.clear_global_weights()
        if len(core.get_weight_dict()) != 0:
            return 'fail'
        return 'pass'
    finally:
        # never leak the 0.5 weight into later tests when a check fails early
        core.clear_global_weights()
94 
def hit_equality_test(hit1, hit2, isEqual):
    """Return 'pass' if hit1 == hit2 when isEqual, or hit1 != hit2 when not.

    Both operators are used explicitly (rather than negating ==) so that
    __ne__ is exercised as well as __eq__.
    """
    if isEqual and hit1 == hit2:
        return 'pass'
    if not isEqual and hit1 != hit2:
        return 'pass'
    return 'fail'
99 
def hit_repr_test(hit):
    """Return 'pass' if eval(repr(hit)) reproduces an object equal to hit.

    eval is acceptable here: the repr comes from an object built by the
    test itself, not from external input.
    """
    rebuilt = eval(repr(hit))
    if rebuilt == hit:
        return 'pass'
    return 'fail'
104 
def hit_copy_test(hit):
    """Check that shallow copies of a hit are the hit itself.

    Both copy.copy(hit) and hit.copy() are expected to return the very same
    object (identity, not just equality).
    """
    via_module = copy.copy(hit)
    via_method = hit.copy()
    if via_module is hit and via_method is hit:  # test for address
        return 'pass'
    return 'fail'
110 
def hit_deep_copy_test(hit):
    """Check that deep copies of a hit are equal to it but distinct objects.

    Both copy.deepcopy(hit) and hit.deepcopy() are tested.
    """
    copies = [copy.deepcopy(hit), hit.deepcopy()]
    for duplicate in copies:
        if not (duplicate == hit and duplicate is not hit):
            return 'fail'
    return 'pass'
117 
def hit_get_vector_test(hit):
    """Check hit.get_vector() with and without an origin offset.

    get_vector(['x', 'y']) must give [x, y] in its first row. With origin
    {'x': 10., 'z': 100.} the x entry must be x - 10.; y has no offset, and
    the 'z' entry must not matter because z was not requested.
    """
    vec = hit.get_vector(['x', 'y'])
    test_pass = vec[0, 0] == hit.get('x') and vec[0, 1] == hit.get('y')
    vec = hit.get_vector(['x', 'y'], {'x': 10., 'z': 100.})
    test_pass = test_pass and vec[0, 0] == hit.get('x') - 10. \
                          and vec[0, 1] == hit.get('y')
    if test_pass:
        return 'pass'
    return 'fail'
125 
def hit_translate_test(hit):
    """Check Hit.translate() on px and py.

    A deep copy of hit is translated by px += 10., py += 5. twice: once with
    '' as the second argument and once with 'energy' (presumably the
    mass-shell variable to recompute). In the 'energy' case the translated
    copy must also pass check().
    """
    def _shifted_ok(moved):
        # px and py must be offset by exactly the requested amounts
        return abs(moved.get('px') - hit.get('px') - 10.) < __float_tol and \
               abs(moved.get('py') - hit.get('py') - 5.) < __float_tol

    hit1 = hit.deepcopy()
    hit1.translate({'px': 10., 'py': 5.}, '')
    test_pass = _shifted_ok(hit1)

    hit1 = hit.deepcopy()
    hit1.translate({'px': 10., 'py': 5.}, 'energy')
    test_pass = test_pass and _shifted_ok(hit1) and hit1.check()
    if test_pass:
        return 'pass'
    return 'fail'
135 
def hit_abelian_transformation_test(hit):
    """Check Hit.abelian_transformation() on the (x, px) pair.

    First the identity matrix is applied twice, with explicit zero
    translation/origin and with defaults; the hit must be unchanged. Then
    matrix R is applied with translation T about origin O ('energy' as the
    last argument), and x, px are compared with R*(v - O) + T + O computed
    here. The result must also pass check(). Requires numpy.
    """
    hit1 = hit.deepcopy()
    identity = numpy.matrix([[1, 0], [0, 1]])
    hit1.abelian_transformation(['x', 'px'], identity, {'x': 0, 'px': 0}, {'x': 0, 'px': 0}, '')
    test_pass = hit == hit1
    hit1.abelian_transformation(['x', 'px'], identity)
    test_pass = test_pass and hit == hit1

    R = numpy.matrix([[0.76, 0.43], [0.76, 0.29]])
    O = numpy.matrix([[1.3], [-1.7]])
    T = numpy.matrix([[1.782], [2.35]])
    hit1.abelian_transformation(['x', 'px'], R,
                                {'x': T[0, 0], 'px': T[1, 0]},
                                {'x': O[0, 0], 'px': O[1, 0]}, 'energy')
    expected = R*(hit.get_vector(['x', 'px']).transpose() - O) + T + O
    test_pass = test_pass and abs(expected[0, 0] - hit1.get('x')) < __float_tol \
                          and abs(expected[1, 0] - hit1.get('px')) < __float_tol \
                          and hit1.check()
    if test_pass:
        return 'pass'
    return 'fail'
150 
def hit_mass_shell_condition_test(hit):
    """Check that mass_shell_condition() restores each mass-shell variable.

    For every key in hit.mass_shell_variables(), the working copy has that
    key set to 1111. and recomputed with mass_shell_condition(key). The copy
    must then pass check() and agree with hit in |px|, |py|, |pz|, |energy|
    and |mass| (signs are ignored). A hit with energy < p or energy < mass
    cannot be tested and returns 'pass'.
    """
    hit1 = hit.deepcopy()
    if hit1.get('energy') < hit1.get('p') or hit1.get('energy') < hit1.get('mass'):
        return 'pass'
    test_pass = True
    for key in hit.mass_shell_variables():
        hit1.set(key, 1111.)
        hit1.mass_shell_condition(key)
        test_pass = test_pass and hit1.check()
        # NOTE(review): nesting reconstructed (the listing lost indentation);
        # the comparison runs after every key, not only after the last one.
        for key2 in ['px', 'py', 'pz', 'energy', 'mass']:
            test_pass = test_pass and \
                abs(abs(hit1.get(key2)) - abs(hit.get(key2))) < __float_tol
    if test_pass:
        return 'pass'
    return 'fail'
163 
def hit_get_test(hit):
    """Check hit.get() for every key in hit.get_variables().

    Integer keys are compared with the hit's dynamic storage. Hitcore
    variables (which include 'weight' and 'global_weight') are compared with
    the underlying Hitcore. Derived variables (p, r, phi, pt, ...) are
    compared with values recomputed here from their components. A key with
    no check prints a warning. A ZeroDivisionError while recomputing (for
    example a hit with px = py = 0) counts as a pass for that key.
    """
    tol = __float_tol
    hitcore_variables = Hitcore.Hitcore.get_variables()

    def _angle_ok(opposite, adjacent, value):
        # atan(opposite/adjacent) matches, or adjacent is ~0 and value is
        # +-pi/2 with the sign of opposite. The outer abs() was missing
        # before, which let almost any value pass the fallback.
        if abs(math.atan(opposite/adjacent) - value) < tol:
            return True
        return abs(adjacent) < tol and \
               abs(abs(opposite)/opposite*(2.*value/math.pi) - 1.) < tol

    test_pass = True
    for key in hit.get_variables():
        try:
            value = hit.get(key)
            good = True
            if isinstance(key, int):
                good = abs(hit._Hit__dynamic[key] - value) < tol
            elif key in hitcore_variables:
                good = abs(value - hit._Hit__hitcore.get(key)) < tol
            elif key == 'p':
                good = abs(math.sqrt(hit.get('px')**2 + hit.get('py')**2 + hit.get('pz')**2) - value) < tol
            elif key == 'r':
                good = abs(math.sqrt(hit.get('x')**2 + hit.get('y')**2) - value) < tol
            elif key == 'phi':
                good = _angle_ok(hit.get('y'), hit.get('x'), value)
            elif key == 'pt':
                good = abs(math.sqrt(hit.get('px')**2 + hit.get('py')**2) - value) < tol
            elif key == 'pphi':
                good = _angle_ok(hit.get('py'), hit.get('px'), value)
            elif key == "x'":
                good = abs(hit.get('px')/hit.get('pz') - value) < tol
            elif key == "y'":
                good = abs(hit.get('py')/hit.get('pz') - value) < tol
            elif key == "t'":
                good = abs(-hit.get('energy')/hit.get('pz') - value) < tol
            elif key == "ct'":
                good = abs(-hit.get('energy')/hit.get('pz') - value) < tol
            elif key == "r'":
                good = abs(hit.get('pt')/hit.get('pz') - value) < tol
            elif key == 'spin':
                good = abs(math.sqrt(hit.get('sx')**2 + hit.get('sy')**2 + hit.get('sz')**2) - value) < tol
            elif key == 'ct':
                good = abs(hit.get('t')*Common.constants['c_light'] - value) < tol
            elif key == '':
                good = value is None
            elif key == "z'":
                good = abs(1. - value) < tol  # dz/dz is always 1!
            elif key == 'r_squared':
                good = abs(hit['x']*hit['x'] + hit['y']*hit['y'] - value) < tol
            elif key == 'l_kin':
                good = abs(hit['x']*hit['py'] - hit['y']*hit['px'] - value) < tol
            elif key == 'kinetic_energy':
                good = abs(hit['energy'] - hit['mass'] - value) < tol
            else:
                print('warning: key  %s  not tested' % (key,))
            if not good:
                print('Get test failed with %s %s' % (key, value))
        except ZeroDivisionError:
            good = True
        test_pass = test_pass and good
    if test_pass:
        return 'pass'
    return 'fail'
206 
def hit_set_test(hit):
    """Check hit.set() for every key in hit.set_variables().

    Each key is tested on a fresh deep copy of hit:
      - float-valued variables are set to 1000.,
      - int-valued ones to 211,
      - string-valued ones to 'some_string',
    and then read back. The empty key must leave the copy equal to hit.
    Returns 'fail' at the first mismatch.
    """
    for key in hit.set_variables():
        try:
            hit1 = hit.deepcopy()
            ok = True
            if key == '':
                ok = hit1 == hit
            else:
                current = hit1.get(key)
                if isinstance(current, float):
                    hit1.set(key, 1000.)
                    ok = abs(hit1.get(key) - 1000.) < __float_tol
                # the previous type(hit1.get(key) == type(1)) was always
                # truthy, so every non-float (strings included) landed here
                elif isinstance(current, int):
                    hit1.set(key, 211)
                    ok = hit1.get(key) == 211
                elif isinstance(current, str):
                    hit1.set(key, 'some_string')
                    ok = hit1.get(key) == 'some_string'
            if not ok:
                print("Set test failed with key '%s' %s %s"
                      % (key, hit1.get(key), type(hit1.get(key))))
                return 'fail'
        except Exception:
            # NOTE(review): errors from get/set are deliberately ignored, as
            # before; this can hide genuine failures for individual keys.
            pass
    return 'pass'
230 
def hit_check_test(hit):
    """Check Hit.check() against pid/mass consistency and the bad-pid list.

    On a deep copy of hit:
      - a mu+ (pid -13) with muon mass on the mass shell passes;
      - changing pid or mass to an inconsistent value fails;
      - pid 11 on the electron mass shell passes;
      - an unknown pid (2e6) fails unless it is listed via Hit.set_bad_pids.

    The bad-pid list is emptied for the duration of the test (the checks
    assume none of these pids is flagged) and restored afterwards, even if
    a call raises.
    """
    old_bad_pids = list(Hit.get_bad_pids())
    Hit.set_bad_pids([])
    try:
        hit1 = hit.deepcopy()
        hit1.set('pid', -13)
        hit1.set('mass', Common.pdg_pid_to_mass[abs(hit1.get('pid'))])
        hit1.mass_shell_condition('energy')
        test_pass = hit1.check()
        hit1.set('pid', -11)
        test_pass = test_pass and not hit1.check()
        hit1.set('pid', -13)
        hit1.set('mass', Common.pdg_pid_to_mass[11])
        test_pass = test_pass and not hit1.check()
        hit1.set('pid', 11)
        test_pass = test_pass and not hit1.check()
        hit1.mass_shell_condition('energy')
        test_pass = test_pass and hit1.check()
        hit1.set('pid', int(2e6))
        hit1.set('mass', 0.)
        hit1.mass_shell_condition('energy')
        test_pass = test_pass and not hit1.check()
        Hit.set_bad_pids([int(2e6)])
        test_pass = test_pass and hit1.check()
    finally:
        Hit.set_bad_pids(old_bad_pids)
    if test_pass:
        return 'pass'
    return 'fail'
256 
def hit_clear_global_weights_test(hit):
    """Check that Hit.clear_global_weights() resets a global weight to 1.

    Note: sets the global weight of the hit passed in to 1000. before
    clearing.
    """
    hit.set('global_weight', 1000.)
    Hit.clear_global_weights()
    if abs(hit.get('global_weight') - 1.) < __float_tol:
        return 'pass'
    return 'fail'
263 
def _io_builtin_round_trip(hit, key):
    """Write hit to 'out_test' in builtin format key and read it back.

    Returns True if the outcome matches what is expected for that format,
    False if not, or None if 'out_test' could not be opened. The caller is
    responsible for removing 'out_test'.
    """
    is_maus = 'maus' in key
    is_muon1_csv = 'muon1_csv' in key
    test_pass = True

    try:
        filehandle = open('out_test', 'w')
    except IOError:
        return None
    with filehandle:
        try:
            hit.write_builtin_formatted(key, filehandle)
            if is_maus:  # maus formats are expected to refuse to write
                test_pass = False
        except IOError:
            # only maus and muon1_csv formats may refuse to write
            if not is_maus and not is_muon1_csv:
                test_pass = False

    try:
        filehandle = open('out_test', 'r')
    except IOError:
        return None
    with filehandle:
        hit1 = Hit()
        try:
            hit1.read_builtin_formatted(key, filehandle)
            # NOTE(review): this used to be guarded by
            # key.find('icool_for003') or key.find('mars'), which is true for
            # every key (find() returns -1, i.e. truthy, when absent), so
            # station was always copied. Kept unconditional to preserve
            # behaviour - TODO confirm which formats really store station.
            hit1['station'] = hit['station']
            if is_maus:
                test_pass = False
            test_pass = test_pass and hit1 == hit
        except IOError:
            if not is_maus:
                test_pass = False
        except EOFError:
            if not is_muon1_csv:  # we test muon1_csv at bunch level
                test_pass = False
    return test_pass


def hit_io_builtin_formatted_test(hit):
    """Round-trip hit through every format in Hit.file_types().

    Exercises write_builtin_formatted(self, format, file_handle) and
    read_builtin_formatted(self, format, filehandle); new_from_read_builtin
    is not exercised here. Group I/O operations (open_filehandle_for_writing,
    write_list_builtin_formatted) have to be tested in Bunch.

    Returns 'pass', 'fail', or a warning string if the scratch file
    'out_test' cannot be opened. The scratch file is always removed.
    """
    test_pass_all = True
    for key in Hit.file_types():
        try:
            outcome = _io_builtin_round_trip(hit, key)
        finally:
            if os.path.exists('out_test'):
                os.remove('out_test')
        if outcome is None:
            return 'warning - could not open file out_test'
        if not outcome:
            print('Failed on builtin format %s' % (key,))
            test_pass_all = False
    if test_pass_all:
        return 'pass'
    return 'fail'
306 
307 
def hit_io_user_formatted_test(hit):
    """Round-trip hit through the user-formatted writer and both readers.

    Uses the icool_for009 column list and units. Checks write_user_formatted
    followed by read_user_formatted and Hit.new_from_read_user. Returns
    'pass', 'fail', or a warning string if 'out_test' cannot be opened for
    writing. The scratch file is always removed.

    NOTE(review): the def line was lost in the source listing and the name
    is a reconstruction; this test is not part of the list run by hit_test.
    """
    # The old code used bare _Hit__file_formats / _Hit__file_units /
    # new_from_read_user / filehandle names, all NameErrors at module level.
    # Assumes Hit defines private class tables __file_formats and
    # __file_units - TODO confirm.
    format_list = Hit._Hit__file_formats['icool_for009']
    format_units_dict = Hit._Hit__file_units['icool_for009']
    try:
        fh = open('out_test', 'w')
    except IOError:
        return 'warning - could not open file out_test'
    try:
        with fh:
            hit.write_user_formatted(format_list, format_units_dict, fh, separator=' ')

        with open('out_test', 'r') as fh:
            hit1 = Hit()
            hit1.read_user_formatted(format_list, format_units_dict, fh)
        test_pass = hit1 == hit

        with open('out_test', 'r') as fh:
            hit2 = Hit.new_from_read_user(format_list, format_units_dict, fh)
        test_pass = test_pass and hit2 == hit
    finally:
        os.remove('out_test')
    if test_pass:
        return 'pass'
    return 'fail'
333 
def hit_new_from_maus_object_test(hit): # tests generation of dict also
    """Round-trip hit through get_maus_dict / new_from_maus_object.

    The 'maus_virtual_hit' dict and event number produced by hit are fed
    back to Hit.new_from_maus_object; the result must equal a deep copy of
    hit.
    """
    maus_dict, event_number = hit.get_maus_dict('maus_virtual_hit')
    rebuilt = Hit.new_from_maus_object('maus_virtual_hit', maus_dict, event_number)
    reference = hit.deepcopy()
    if rebuilt == reference:
        return 'pass'
    return 'fail'
340 
def hit_set_g4bl_unit_test(hit):
    """Check that the g4beamline length unit is applied on write and read.

    hit is written as a 'g4beamline_bl_track_file' with unit 'cm' and read
    back with unit 'mm'. The x, y, z values must scale accordingly. The
    unit is reset to 'mm' and the scratch file is closed and removed even
    if a step raises.
    """
    file_name = 'set_g4bl_unit_test'
    hit.set_g4bl_unit('cm')
    try:
        try:
            filehandle = open(file_name, 'w')
        except IOError:
            return 'warning - could not open file ' + file_name
        with filehandle:
            hit.write_builtin_formatted('g4beamline_bl_track_file', filehandle)
        hit.set_g4bl_unit('mm')
        # the read handle was previously never closed
        with open(file_name, 'r') as filehandle:
            hit1 = Hit.new_from_read_builtin('g4beamline_bl_track_file', filehandle)
    finally:
        hit.set_g4bl_unit('mm')
        if os.path.exists(file_name):
            os.remove(file_name)

    for axis in ['x', 'y', 'z']:
        # tolerance instead of exact equality: values pass through text
        if abs(hit[axis]/Common.units['cm'] - hit1[axis]/Common.units['mm']) > __float_tol:
            print('Failed set_g4bl_unit_test %s %s' % (hit[axis], hit1[axis]))
            return 'fail'
    return 'pass'
356 
def hit_bad_pids_test(hit):
    """Check the global bad-pid list.

    Reading a g4mice_special_hit with an unknown pid (1e6) must register
    that pid as bad, and Hit.set_bad_pids / Hit.get_bad_pids must round-trip
    a list. The hit argument is unused.

    The previous bad-pid list is restored on exit. Previously the list was
    left as [-13, 0, 15, 17], which flags mu+ as bad and broke
    hit_check_test for every hit tested afterwards.
    """
    old_bad_pids = list(Hit.get_bad_pids())
    bad_pid = int(1e6)
    Hit.set_bad_pids([])  # clear any old badness
    try:
        some_hit = Hit()
        filehandle = StringIO.StringIO('0 12 0 0 0 0 1 1 '+str(bad_pid)+' 938.272013 792.9087832 -271.9208414 -7160 122.5049215 -118.0579616 90.15054592 2811.993882 2968.118725 11.35922989 2.368198721e-05 2.19331651e-05 0.001529992003 0 0 0 0 0 1005.281896 -288.902357 -12500 18.78433052 -77.5010436 -26.9017406 2826.69496 2 10.68162998 CoilsUS11 5344.880819 5.915402589\n')
        some_hit.read_builtin_formatted('g4mice_special_hit', filehandle)
        if Hit.get_bad_pids() != [bad_pid]:
            print('read bad pid list %s %s' % (Hit.get_bad_pids(), [bad_pid]))
            return 'fail'
        bad_pid_list = [-13, 0, 15, 17]
        Hit.set_bad_pids(bad_pid_list)
        if Hit.get_bad_pids() != bad_pid_list:
            print('set_bad_pids pid list %s %s' % (Hit.get_bad_pids(), bad_pid_list))
            return 'fail'
        return 'pass'
    finally:
        Hit.set_bad_pids(old_bad_pids)
372 
def hit_get_maus_tree_test(hit_list): # also test get_list_of_maus_dicts
    """Check Hit.get_maus_tree() and Bunch.get_list_of_maus_dicts().

    For 'maus_virtual_hit':
      - the tree must have one entry per distinct event_number in hit_list;
      - its first spill's first mc_event must contain virtual hits;
      - Bunch.get_list_of_maus_dicts of an empty spill must be [].
    On failure of the tree checks the tree is dumped as JSON.
    """
    import json  # was used below but never imported
    test_pass = True
    for name in ['maus_virtual_hit']:
        maus_tree = Hit.get_maus_tree(hit_list, name)
        event_numbers = set(a_hit['event_number'] for a_hit in hit_list)
        test_pass = test_pass and len(maus_tree) == len(event_numbers)
        test_pass = test_pass and len(maus_tree[0]["mc_events"][0]["virtual_hits"]) > 0
        if not test_pass:
            print(json.dumps(maus_tree, indent=2))
        # TODO(review): a round trip (get_list_of_maus_dicts per spill ->
        # Hit.new_from_maus_dict compared with hit_list) was intended here,
        # but it iterated over the empty list it was building and never ran.
        # Re-add it once the return type of get_list_of_maus_dicts and the
        # dict -> Hit constructor are confirmed.
        # 'and' rather than '=': the old assignment discarded earlier results
        test_pass = test_pass and Bunch.get_list_of_maus_dicts(name, {}) == []
    if test_pass:
        return 'pass'
    return 'fail'
391 
def hit_force_unique_particle_number_test(hit_list):
    """Check that Hit.force_unique_particle_number gives distinct particle numbers.

    The helper is applied to a deep copy, so hit_list itself is untouched
    even if the helper renumbers in place.
    """
    renumbered = Hit.force_unique_particle_number(copy.deepcopy(hit_list))
    seen = set()
    for a_hit in renumbered:
        number = a_hit['particle_number']
        if number in seen:
            return 'fail'
        seen.add(number)
    return 'pass'
401 
def hit_test(hit):
    """Run the per-hit battery of tests on one hit; return (passes, fails, warns).

    hit should be physical, otherwise tests will give false negatives
    (e.g. no negative energy).
    """
    test_results = []
    tests = [hit_repr_test, hit_copy_test, hit_deep_copy_test, hit_get_vector_test,
             hit_translate_test, hit_abelian_transformation_test,
             hit_mass_shell_condition_test, hit_get_test, hit_set_test,
             hit_check_test, hit_set_g4bl_unit_test, hit_io_builtin_formatted_test,
             hit_new_from_maus_object_test, hit_clear_global_weights_test,
             hit_bad_pids_test]
    # every test takes the same single argument
    run_test_group(test_results, tests, [(hit,)]*len(tests))
    return parse_tests(test_results)
409 
def test_hit():
    """Run all Hit tests, print a summary and return (passes, fails, warns).

    Runs:
      - the Hitcore and global-weight tests;
      - the per-hit battery (hit_test) on three muon hits;
      - the list-level tests (maus tree, unique particle numbers);
      - four equality checks.
    """
    test_results = []
    (passes, fails, warns) = (0, 0, 0)
    run_test(test_results, hitcore_test, ())
    run_test(test_results, hit_global_weight_test, ())
    hit = Hit.new_from_dict({'x':1.,'y':2.,'z':10.,'t':1.,'px':3.,'py':10.,'pz':200.,'pid':13,'mass':Common.pdg_pid_to_mass[13], 'charge':-1}, 'energy')
    hit_list = [hit]
    hit_list.append(Hit.new_from_dict({'x':0.,'y':0.,'z':0.,'t':0.,'px':0.,'py':0.,'pz':0.,'pid':13,'mass':Common.pdg_pid_to_mass[13], 'station':1, 'charge':-1}, 'energy'))
    hit_list.append(Hit.new_from_dict({'x':0.,'y':0.,'z':0.,'t':0.,'px':0.,'py':0.,'pz':-200.,'pid':13,'mass':Common.pdg_pid_to_mass[13], 'station':2, 'charge':-1}, 'energy'))

    # hit1 equals hit; hit2 differs only in station; junk is not a Hit at all
    hit1 = Hit.new_from_dict({'x':1.,'y':2.,'z':10.,'t':1.,'px':3.,'py':10.,'pz':200.,'pid':13,'mass':Common.pdg_pid_to_mass[13], 'charge':-1}, 'energy')
    hit2 = Hit.new_from_dict({'x':1.,'y':2.,'z':10.,'t':1.,'px':3.,'py':10.,'pz':200.,'pid':13,'mass':Common.pdg_pid_to_mass[13],'station':1, 'charge':-1}, 'energy')
    junk = 0

    for a_hit in hit_list:
        (apass, afail, awarn) = hit_test(a_hit)
        passes += apass
        fails += afail
        warns += awarn

    run_test_group(test_results, [hit_get_maus_tree_test, hit_force_unique_particle_number_test],
                   [(hit_list,), (hit_list,)])

    args = [(hit, hit, True), (hit, junk, False), (hit, hit1, True), (hit, hit2, False)]
    run_test_group(test_results, [hit_equality_test]*4, args)

    (passesEq, failsEq, warnsEq) = parse_tests(test_results)
    passes += passesEq
    fails += failsEq
    warns += warnsEq
    print('\n============\n|| HIT ||\n============')
    print('Passed  %s  tests\nFailed  %s  tests\n%s  warnings\n\n\n' % (passes, fails, warns))
    return (passes, fails, warns)
442 
# Run the Hit test suite when executed as a script.
if __name__ == "__main__":
    test_hit()
445 
446