from libtbx.test_utils import approx_equal
import iotbx.ccp4_map
from cctbx import miller, crystal
from cctbx.array_family import flex

def run(map_file_name = "emd_5169.map",
        d_min         = 8.4,
        k_blur        = 1,
        b_blur        = 100,
        output_file_name = "map_data.mtz"):
  """
  Convert a CCP4 map into an MTZ file of map coefficients plus
  Hendrickson-Lattman (HL) coefficients encoding the map phases.

  Steps: read the map, build the complete non-anomalous Miller set up to
  d_min, compute complex structure factors from the map, encode their
  phases as HL coefficients, write everything to an MTZ file, and finally
  check that amplitudes re-phased from the HL coefficients reproduce the
  original complex coefficients.

  Parameters
  ----------
  map_file_name : str
    Path of the input map, read with iotbx.ccp4_map.map_reader.
  d_min : float
    High-resolution limit of the generated Miller set.
  k_blur : float
    Overall scale of the HL weight t = 2*k_blur*exp(-b_blur*s**2),
    where s**2 = 1/(4*d**2) for each reflection.
  b_blur : float
    B-factor-like term in the HL weight; for positive values the weight
    decreases as d decreases (higher resolution).
  output_file_name : str
    Path of the MTZ file written. Columns are added with root labels
    "Fobs_cmpl" (complex coefficients), "Fobs" (amplitudes) and "HL".

  Raises
  ------
  AssertionError
    If the map correlation between the re-phased amplitudes and the
    original complex coefficients is not approximately 1.
  """
  # Read in the map.
  m = iotbx.ccp4_map.map_reader(file_name=map_file_name)
  # Single-argument print(...) gives identical output under Python 2 and 3.
  print("Input map information:")
  m.show_summary(prefix="  ")
  # Generate the complete set of Miller indices up to high resolution d_min,
  # using the unit cell and space group recorded in the map.
  cs = crystal.symmetry(m.unit_cell_parameters, m.space_group_number)
  complete_set = miller.build_set(
    crystal_symmetry = cs,
    anomalous_flag   = False,
    d_min            = d_min)
  print("Complete set information:")
  complete_set.show_comprehensive_summary(prefix="  ")
  # Compute Fobs from the map (note: Fobs is a complex array).
  f_obs_cmpl = complete_set.structure_factors_from_map(
    map            = m.data.as_double(),
    use_scale      = True,
    anomalous_flag = False,
    use_sg         = True)
  mtz_dataset = f_obs_cmpl.as_mtz_dataset(column_root_label="Fobs_cmpl")
  mtz_dataset.add_miller_array(
    miller_array      = abs(f_obs_cmpl),
    column_root_label = "Fobs")
  # Convert phases into HL coefficients: A = t*cos(phi), B = t*sin(phi),
  # with a resolution-dependent weight t.
  f_model_phases = f_obs_cmpl.phases().data()
  sin_f_model_phases = flex.sin(f_model_phases)
  cos_f_model_phases = flex.cos(f_model_phases)
  # s^2 = 1/(4*d^2) for each reflection.
  ss = 1./flex.pow2(f_obs_cmpl.d_spacings().data()) / 4.
  t = 2*k_blur * flex.exp(-b_blur*ss)
  hl_a_model = t * cos_f_model_phases
  hl_b_model = t * sin_f_model_phases
  # Only A and B are supplied; presumably C and D default to zero -- confirm.
  hl_data = flex.hendrickson_lattman(a = hl_a_model, b = hl_b_model)
  hl = f_obs_cmpl.customized_copy(data = hl_data)
  mtz_dataset.add_miller_array(
    miller_array      = hl,
    column_root_label = "HL")
  # Write the output MTZ file with all the data.
  mtz_object = mtz_dataset.mtz_object()
  mtz_object.write(file_name = output_file_name)
  # Sanity check: |Fobs| re-phased with the HL phases must correlate ~1 with
  # the original complex coefficients. Note: skipped under `python -O`.
  map_coeffs = abs(f_obs_cmpl).phase_transfer(phase_source = hl)
  assert approx_equal(map_coeffs.map_correlation(other=f_obs_cmpl), 1)

if __name__ == "__main__":
  # Script entry point: convert the map using run()'s default arguments.
  run()
