2-Export Template for single q bin

The following template is meant to create the output that can be loaded by Mumott v 0.2.1.

Import packages & libraries

[3]:
%matplotlib ipympl
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
import h5py
import pickle
import json
from ipywidgets import interact
from matplotlib import cm, colors
from matplotlib.patches import Rectangle

from data_processing.dataset import Dataset

Step 1 - Load the args file from the first template

Note: this already includes the q_rois chosen in the first template.

[4]:
#--- User Input to load the exported args file --------------------#

filepath_args = '/data/visitors/formax/20220566/2023031408/process/work_dir_Christian/big_brain/args_big_brain.pickle'
#--- End Input to load the exported args file ---------------------#


## ---------- Don't change this if not needed ----------------- ##
# The args dictionary exported by the first template (it already contains the
# q_rois selected there).
# NOTE(review): pickle.load can execute arbitrary code - only open files you created yourself.
with open(filepath_args, 'rb') as args_file:
    args = pickle.load(args_file)

frame_id_range = args['frame_id_range']

# Show which entries the args file provides, one key per line.
print(*args.keys(), sep='\n')
base_path
work_directory
integration_folder
frameID
sample_name
air_id
proposal
visit
bkg_scan
detector_name
normkey
valid_pixels
fast_scan_direction
slow_scan_direction
projection_direction
detector_angle_origin
detector_angle_pos_dir
norm_transmission
flat_field_level
correct_background
tomodata
beamline
inner_rotation_axis
inner_rotation_key
first_rotation_indexes
tilt_axis
tilt_key
data_sorting
data_index_origin
principal_rotation_right_handed
secondary_rotation_right_handed
detector_angle_0
detector_angle_right_handed
offset_positive
air_transmission
snake
phi_det
q
norm_sum
air_scattering
multiprocessing
frame_id_range
q_rois

Load beamline specific loader functions based on args input

[5]:
#do not modify
print(f"Loading beamline specific functions from {args['beamline']}_utils")
# Import the cSAXS loader functions
if args['beamline'] == 'csaxs':
    from data_processing.cSAXS_utils import metadata_reader_complete as metadata_reader
    from data_processing.cSAXS_utils import transmission_loader_mcs as transmission_loader
    from data_processing.cSAXS_utils import scattering_loader as scattering_loader_eiger
    from data_processing.cSAXS_utils import create_args

# Import the PX loader functions
elif args['beamline'] == 'px':
    from data_processing.PX_utils import metadata_reader_json as metadata_reader
    from data_processing.PX_utils import transmission_loader_eiger as transmission_loader
    from data_processing.PX_utils import scattering_loader_eiger
    from data_processing.PX_utils import create_args

# Import the ForMAX loader functions
elif args['beamline'] == 'formax':
    from data_processing.ForMAX_utils import metadata_reader
    from data_processing.ForMAX_utils import transmission_loader
    from data_processing.ForMAX_utils import scattering_loader_eiger
    from data_processing.ForMAX_utils import create_args

else:
    # Fail here with a clear message instead of a NameError in the Dataset cell,
    # where metadata_reader / transmission_loader would otherwise be undefined.
    raise ValueError(f"Unsupported beamline {args['beamline']!r}; expected one of 'csaxs', 'px', 'formax'")
Loading beamline specific functions from formax_utils

Load the dataset

Loading the data might take a couple of minutes depending on the allocation (8 cores or more recommended).
GPU nodes are a bit slower here. With 8–12 cores it takes around 1 min for the full dataset.
[6]:
%%time
## ---------- Don't change this if not needed ----------------- ##
# Build the Dataset for all frames in frame_id_range using the beamline-specific
# loader functions imported in the cell above; the remaining args entries are
# forwarded as keyword arguments.
ds = Dataset(frame_id_range, metadata_reader, transmission_loader, scattering_loader_eiger, **args)
CPU times: user 412 ms, sys: 392 ms, total: 804 ms
Wall time: 4.23 s

Step 2 - Choose q range for reconstructions

  • The following few cells should help you pick a suitable q range

  • You may already have picked some q ranges in the 0-plotting template and can now choose one of them for the reconstruction. They are stored in `args['q_rois']`

  • Alternatively, you can choose a new q range manually

When choosing a q range, avoid the detector gaps visible in the symmetric cake plot of the data.

2.1 Check the ROIs stored in `args['q_rois']`

[7]:
#--- User Input which projection to show --------------------#

proj_nr = 0

#--- End Input which projection to show --------------------#

# Use dedicated names for this figure. The slider callback below looks them up
# when it runs, so generic `fig`/`axs` reassigned by later cells would redirect it.
fig_rois, axs_rois = plt.subplots(2, 3, figsize=(12, 8))

# Bind the projection now so the slider keeps showing the projection chosen
# above, even if a later cell changes `proj_nr`.
roi_projection = ds.stack[proj_nr]
# Symmetric cake data summed over both scan axes.
# It is shown with azimuthal bins on the y axis and q on the x axis.
cake_img_avg = np.nansum(roi_projection.scaled_scattering_data_symmetric, axis=(0, 1))
roi_cake_extent = [np.min(args['q']), np.max(args['q']), 0, cake_img_avg.shape[0]]


def _draw_q_roi(ax, roi, color):
    """Shade the q ROI ``(q_min, q_max, label)`` on ``ax`` and write its label.

    The label's x position is in data (q) units and its y position in axes
    units. The label therefore stays inside the panel whatever the intensity
    range of the curves is, unlike a fixed data value such as y=1e3 on a
    log-scaled axis.
    """
    ax.axvspan(roi[0], roi[1], alpha=0.3, color=color)
    ax.text(0.5 * (roi[0] + roi[1]), 0.95, roi[2], horizontalalignment='center',
            verticalalignment='top', transform=ax.get_xaxis_transform())


if len(args['q_rois']) == 0:
    # interact() cannot build a slider for an empty range (0, -1).
    print("No q_rois stored in args - pick new rois in section 2.2 instead")
else:
    # interact() calls plot_roi once immediately, so no separate initial drawing is needed.
    @interact(roi_nr=(0, len(args['q_rois']) - 1))
    def plot_roi(roi_nr):
        """Redraw every panel for ROI number ``roi_nr`` of ``args['q_rois']``."""
        for axis in axs_rois.flat:
            axis.clear()
        roi = args['q_rois'][roi_nr]

        # Real-space maps of the selected q range
        f1amp, f2amp, _f2phase, colorfulplot = roi_projection.colorful_image_plot(q_range=roi[:2], symmetric = True)
        axs_rois[0, 0].imshow(colorfulplot, picker=True)
        axs_rois[0, 0].set_title('colorful plot')
        axs_rois[0, 1].imshow(f1amp)
        axs_rois[0, 1].set_title('symmetric intensity')
        axs_rois[0, 2].imshow(f2amp)
        axs_rois[0, 2].set_title('assymmetric intensity')

        # 1D curves of the projection: all stored ROIs in grey, the selected one in red
        axs_rois[1, 0].loglog(roi_projection.metadata['q'], cake_img_avg.T)
        axs_rois[1, 0].set_title('Average 1D curves for projection')
        axs_rois[1, 0].set_xlabel('q / A$^{-1}$')
        axs_rois[1, 0].set_ylabel('I / a.u.')
        for other_nr, other_roi in enumerate(args['q_rois']):
            if other_nr != roi_nr:
                _draw_q_roi(axs_rois[1, 0], other_roi, 'grey')
        _draw_q_roi(axs_rois[1, 0], roi, 'red')

        # Cake plot with the selected ROI outlined
        axs_rois[1, 1].imshow(cake_img_avg, norm=colors.LogNorm(), aspect = 'auto', extent=roi_cake_extent)
        rect = Rectangle((roi[0], 0), roi[1] - roi[0], cake_img_avg.shape[0], fill = False, color = "red", linewidth = 2)
        axs_rois[1, 1].add_patch(rect)
        axs_rois[1, 1].set_title('cake plot')
        axs_rois[1, 1].set_xlabel('q / A$^{-1}$')
        axs_rois[1, 1].set_ylabel('azimuthal bin')
        axs_rois[1, 2].axis('off')

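2.2 Alternatively, pick new ROIs for the reconstruction (click a pixel in the left image to show its curves, then click the lower and upper q-edge of each ROI in the right 1D plot)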

[8]:
#------------------------------------------------- User Input for Plot ------------------------------------------------#

proj_nr = 0 # Choose a projection from the stack
q_range = None#[0.1, 0.15] # Choose a q range, or use None for full q range

# Adjust range for plotting, if None autoscaling is applied
vmin = None # None is autoscaling
vmax = None # None is autoscaling

#------------------------------------------------- End Input for Plot ------------------------------------------------#
# One figure combines pixel selection (left), the pixel's cake plot (middle) and
# q-edge selection on its 1D curves (right).
hor_size = 10
vert_size = 4

if q_range is None:
    q_range = [args['q'][0], args['q'][-1]]

f1amp, f2amp, f2phase, colorfulplot = ds.stack[proj_nr].colorful_image_plot(q_range=q_range, normalized = True, symmetric = True)

#------------------------------------------------- User Input for Plot ------------------------------------------------#
# Choose which contrast, f1amp, f2amp, f2phase, colorfulplot
toplot = f1amp# colorfulplot#f1amp
#------------------------------------------------- End Input for Plot ------------------------------------------------#

# Use the symmetric intensity (opposing azimuthal bins already averaged).
cake_img = ds.stack[proj_nr].scaled_scattering_data_symmetric
pick_cake_extent = [np.min(args['q']), np.max(args['q']), 0, cake_img.shape[2]]

# Use dedicated names for this figure. onpick looks them up on every click, so
# generic `fig`/`axs` reassigned by later cells would redirect the callback.
fig_pick, axs_pick = plt.subplots(1, 3, figsize=(hor_size, vert_size))

# Note there are nan values in the data, use nansum to avoid running into trouble with them!
axs_pick[0].imshow(toplot, norm=colors.LogNorm(vmin = vmin, vmax = vmax), picker=True)
axs_pick[1].set_ylabel('azimuthal bin')
axs_pick[1].set_xlabel('q / Ang-1')
axs_pick[1].imshow(np.nansum(cake_img, axis=(0, 1)), norm=colors.LogNorm(), aspect = 'auto', extent=pick_cake_extent)
axs_pick[2].loglog(np.nansum(np.nansum(cake_img, axis=(1, 2)), axis=0).T, picker=True)
axs_pick[2].set_xlim(10, 1000)
axs_pick[2].set_xlabel(r'index q')

# Click results
xcord_transfer = []   # column of each picked pixel
ycord_transfer = []   # row of each picked pixel
position_list = []    # selected q-edges in q units (used by the next cell)
position_plot = []    # the same q-edges as q-bin indices (for redrawing markers)


def _nearest_index(value, size):
    """Round a data coordinate to the nearest integer index, clamped to [0, size-1].

    imshow pixel i covers [i-0.5, i+0.5), so truncating with int() picks the
    wrong pixel for half of every pixel. Rounding picks the nearest pixel/bin.
    """
    return min(max(int(np.floor(value + 0.5)), 0), size - 1)


def onpick(event):
    """Handle a pick on the image panel (select a pixel) or on the 1D panel (select a q-edge)."""
    mouseevent = event.mouseevent
    # Decide which panel was clicked from the picked artist itself, not from
    # a screen-pixel threshold that depends on figure size and dpi.
    clicked_axes = event.artist.axes

    if clicked_axes is axs_pick[0]:
        xcord = _nearest_index(mouseevent.xdata, cake_img.shape[1])
        ycord = _nearest_index(mouseevent.ydata, cake_img.shape[0])
        xcord_transfer.append(xcord)
        ycord_transfer.append(ycord)
        pixel_cake = cake_img[ycord, xcord, ...]

        # update image plot with a marker at the picked pixel
        axs_pick[0].clear()
        axs_pick[0].plot(xcord, ycord, '+', color='r', ms=12, picker = True)
        axs_pick[0].imshow(toplot, norm=colors.LogNorm(vmin = vmin, vmax = vmax), picker=True)

        # update 1D curves of that pixel (x axis = q-bin index)
        axs_pick[2].clear()
        axs_pick[2].loglog(pixel_cake.T, picker=True)
        axs_pick[2].set_xlim(10, 1000)
        #axs_pick[2].set_ylim(1e-1,1e6) Deactivate autoscaling in 1D scattering
        axs_pick[2].set_xlabel(r'index q')

        # update cake plot of that pixel
        axs_pick[1].clear()
        vmin2, vmax2 = np.nanpercentile(pixel_cake, [10, 90])
        axs_pick[1].imshow(pixel_cake, norm=colors.LogNorm(vmin = vmin2, vmax = vmax2),
                           aspect = 'auto', extent=pick_cake_extent)

        # Redraw the already selected q-edges after clearing. The 1D panel uses
        # q-bin indices, while the cake panel's x axis is in q units.
        for q_index in position_plot:
            axs_pick[2].axvline(q_index)
            axs_pick[1].axvline(args['q'][q_index])

    elif clicked_axes is axs_pick[2]:
        q_index = _nearest_index(mouseevent.xdata, len(args['q']))
        position_plot.append(q_index)
        position_list.append(args['q'][q_index])
        axs_pick[2].axvline(q_index)
        axs_pick[1].axvline(args['q'][q_index])

    axs_pick[1].set_ylabel('azimuthal bin')
    axs_pick[1].set_xlabel('q / Ang-1')
    # Callbacks do not trigger a redraw by themselves.
    fig_pick.canvas.draw_idle()


# Keep the connection id in a variable (also stops it being echoed as cell output).
pick_connection_id = fig_pick.canvas.mpl_connect('pick_event', onpick)
[8]:
9

Store the new rois in list

[9]:
# ----------- do not modify this ---------------- #
# ----------- Run it only after running the cell above ---------------- #
# One click can fire several pick events (overlapping artists), so the same
# q-edge may appear repeatedly. dict.fromkeys removes the duplicates while
# keeping the click order, so "last selected" below really is the last click.
# A new list is built so that position_list stays untouched and the cell can be re-run.
q_edges = list(dict.fromkeys(position_list))
if len(q_edges) % 2 != 0:
    print('Odd number of postions choosen; dropping last selected q-edge from list')
    q_edges = q_edges[:-1]
q_edges.sort()

# Pair consecutive sorted edges into (q_min, q_max, label), with labels starting at '1'.
q_rois = [(q_edges[i], q_edges[i + 1], str(number + 1))
          for number, i in enumerate(range(0, len(q_edges), 2))]
print(q_rois)
[]

Step 2.3: Choose a qrange

Pick one either from the preselected `args['q_rois']` or from the manually selected `q_rois`.

[10]:
#------------------------------------------------- User Input for Loading -----------------------------------------------#

#Example 1 from args['q_rois']
# which_roi is a 0-based index into the list (0 = first stored roi); choose the source list below
which_roi = 0
# preselected in 0-plotting templates
q_range = args['q_rois'][which_roi][:2]

#Example 2 from manual selection
#which_roi = 2
#q_range = q_rois[which_roi][:2]

#Example 3 just a random range
#q_range = (0.11, 0.125)

#------------------------------------------------- End Input for Loading ------------------------------------------------#

# ----------- do not modify this ---------------- #
#Printout relevant q ranges from args
print('\nSelection from args \n')
for roi in args['q_rois']:
    print(f"Range {roi[2]:s}: values are {roi[:2]}")

#Printout relevant q ranges from manual selection
print('\nManual selection \n')
try:
    manual_rois = q_rois
except NameError:
    # The "Store the new rois" cell has not been run in this kernel.
    # Only this case is caught; any other error should still surface.
    manual_rois = []

if manual_rois:
    for roi in manual_rois:
        print(f"Range {roi[2]:s}: values are {roi[:2]}")
else:
    print('No manual selection of q_rois performed')

print(f"\nSelected qrange is {q_range}\n")


Selection from args

Range 1: values are (0.01594954951718868, 0.019637221455095562)
Range 2: values are (0.04913859695835064, 0.05430133767142027)
Range 3: values are (0.0786399724616057, 0.08822791950016359)
Range 4: values are (0.14944327366941787, 0.17230683968444055)

Manual selection


Selected qrange is (0.01594954951718868, 0.019637221455095562)

Step 3 - Load scattering data

  • Pick a number of azimuthal directions that evenly divides the number of symmetric azimuthal bins.

  • Note: for instance, 32 azimuthal bins in the radial integration give 16 symmetric bins, which can be reduced to 8. It is recommended not to go below 8 azimuthal bins.

Check user input for azimuthal bins

[11]:
#------------------------------------------------- User Input for number of azimuthal bins -----------------------------------------------#

# Chose how many azimuthal bins should be used for the reconstruction
# 8 is typically our minimum, more will be more time consuming but might be needed for higher anisotropic data
n_directions = 8

#------------------------------------------------- User Input for number of azimuthal bins -----------------------------------------------#

# Check if n_direction is possible:
# opposing azimuthal bins are combined into symmetric bins (e.g. 32 -> 16).
# n_directions must therefore be a positive divisor of half the number of
# azimuthal bins. Integer arithmetic avoids the float-modulo check, which
# accepted negative values and crashed on 0.
n_azimuthal_bins = args['norm_sum'].shape[0]
n_directions_valid = (n_directions > 0
                      and n_azimuthal_bins % 2 == 0
                      and (n_azimuthal_bins // 2) % n_directions == 0)
if not n_directions_valid:
    print("This number of segments doesn't work out.")
else:
    print(f"The current selection of n_directions {n_directions:d} works perfectly fine with shape of norm_sum {args['norm_sum'].shape}; (azimuthal bins, q bins)")


The current selection of n_directions 8 works perfectly fine with shape of norm_sum (32, 1000); (azimuthal bins, q bins)

Start data loading

Note: the `normalized` flag in `projection.bin_qranges` means that the data are normalized by the transmission. Multiprocessing is used to load the data when `args['multiprocessing']` is set.

[12]:
# ----------- do not modify this ---------------- #
import multiprocessing

def load_data(projection, n_directions=8, q_range=None):
    """Integrate one projection over ``q_range`` and bin it into ``n_directions`` azimuthal segments.

    Parameters
    ----------
    projection : element of ``ds.stack``
        Projection whose scattering data are loaded.
    n_directions : int
        Number of azimuthal segments; must divide the number of symmetric bins.
    q_range : sequence of two floats
        (q_min, q_max) passed to ``projection.bin_qranges``.

    Returns
    -------
    numpy.ndarray
        Binned data after ``squeeze()``; the callers store it as ``array[:, :, i, :]``.

    ``projection.clear_data()`` is called before returning (presumably to release
    the loaded raw data - TODO confirm).
    """
    # normalized=True: data are normalized by the transmission here, so bin_det_phi uses normalized=False
    data_q, data_norm = projection.bin_qranges(q_range, normalized = True, symmetric = True)
    bins_per_direction = data_norm.shape[0] // n_directions
    phi_range = list(range(0, data_norm.shape[0] + 1, bins_per_direction))
    data, _ = projection.bin_det_phi(phi_range = phi_range, data=data_q, norm=data_norm, normalized = False, symmetric=False)
    data = data.squeeze()
    projection.clear_data()
    return data

def load_data_parallel(dataset, n_directions=8, array=None, q_range=None):
    """Run ``load_data`` for every projection of ``dataset`` in a process pool.

    Parameters
    ----------
    dataset : Dataset
        Dataset whose ``stack`` is loaded (previously the global ``ds`` was used
        regardless of this argument).
    n_directions : int
        Number of azimuthal segments.
    array : numpy.ndarray, optional
        Preallocated output with shape (rows, cols, n_projections, n_directions);
        created from the first projection's ``padded_shape`` if None.
    q_range : sequence of two floats
        (q_min, q_max) forwarded to ``load_data``.

    Returns
    -------
    numpy.ndarray
        ``array`` filled with the binned data of every projection.
    """
    projections = dataset.stack
    if array is None:
        shp = (*projections[0].metadata['padded_shape'], n_directions)
        array = np.zeros((shp[0], shp[1], len(projections), shp[2]))
    # Cap the pool at 16 worker processes.
    num_cores = min(multiprocessing.cpu_count(), 16)
    # The context manager terminates the pool on exit, also when a worker raises.
    with multiprocessing.Pool(processes=num_cores) as pool:
        pending = [pool.apply_async(load_data, args=(projection, n_directions, q_range))
                   for projection in projections]
        # Advance the progress bar only once each result has actually arrived.
        for i, result in enumerate(tqdm(pending)):
            array[..., i, :] = result.get()
    return array

print(f'Bin into {n_directions} directions')
print(f'Using the q_range from {q_range[0]} to {q_range[1]}')

shp = (*ds.stack[0].metadata['padded_shape'], n_directions)
data_array = np.zeros((shp[0], shp[1], len(ds.stack), shp[2]))

# Same validity rule as the check cell: a positive divisor of the number of symmetric bins.
n_symmetric_bins, odd_bin_count = divmod(args['norm_sum'].shape[0], 2)
if n_directions <= 0 or odd_bin_count or n_symmetric_bins % n_directions != 0:
    print("The current selection of n_directions does not work out, nothing no data loading is performed!! - please correct n_direction in the cell above")
else:
    if args.get('multiprocessing', False):
        data_array = load_data_parallel(ds, n_directions=n_directions, array = data_array, q_range=q_range)
    else:
        # Iterate tqdm over the stack itself so the bar knows the total.
        for ii, projection in enumerate(tqdm(ds.stack)):
            data_array[:, :, ii, :] = load_data(projection, n_directions=n_directions, q_range=q_range)

    # Create detector angle vector from first projection.
    # Each direction merges `stepping` consecutive symmetric bins. Its angle is the
    # mean of the first and last phi_det value of that group, which is the group
    # centre assuming evenly spaced phi_det. The previous `stepping//2` start index
    # was only correct for stepping 1 or 2.
    _, data_norm = ds.stack[0].bin_qranges(q_range, normalized = True, symmetric = True)
    max_angle = data_norm.shape[0]
    stepping = max_angle // n_directions
    angles = (args['phi_det'][:max_angle:stepping] +
              args['phi_det'][stepping - 1:max_angle:stepping]) / 2

Bin into 8 directions
Using the q_range from 0.01594954951718868 to 0.019637221455095562
100%|██████████| 419/419 [03:38<00:00,  1.92it/s]

Step 4 Import alignment and export data for Mumott

Step 4.1 Import mask and absorption tomogram

[13]:
# ----------- do not modify this ---------------- #
# Load the mask stack and the absorption tomogram of this sample.
# Both files are read with pickle (not np.load) despite their .npy extension.
# NOTE(review): pickle.load can execute arbitrary code - only open files you produced yourself.
sample_dir = f"{args['work_directory']}/{args['sample_name']}"

with open(f"{sample_dir}/mask_arrays.npy", 'rb') as mask_file:
    masks = pickle.load(mask_file)

with open(f"{sample_dir}/tomogram.npy", 'rb') as tomogram_file:
    tomogram = pickle.load(tomogram_file)

Step 4.2 Import alignment shifts

[14]:
# ----------- do not modify this ---------------- #
# Load the per-projection alignment offsets and plot them against the frame id
# as a quick sanity check before the export.
shift_file_path = f"{args['work_directory']}/{args['sample_name']}/shift_values.json"
with open(shift_file_path, 'r') as shift_file:
    shifts = json.load(shift_file)

fig_shifts, ax_shifts = plt.subplots()
for offset_key in ('offsets_j', 'offsets_k'):
    ax_shifts.plot(frame_id_range, shifts[offset_key], '.-')
ax_shifts.legend(['offsets_j', 'offsets_k'])
ax_shifts.set_ylabel('pixels')
ax_shifts.set_xlabel('frameid')
plt.show()

Step 4.3 Last check before data is written to file for Mumott

  • Important: also check that there are no NaNs in the data. Mumott cannot deal with them; they should be set to 0. If the data contain NaNs, there is a section in the next script to correct for that.

  • There will be a red printout in case your data or mask has Nan values

  • It also plots the raw data, the mask and the masked data one last time before the export

  • Note: the histograms are slow to draw and make scrolling laggy; deactivating them (`show_histograms = False`) makes it run more smoothly

[15]:
# Choose option to show histograms
show_histograms = False #True

# ----------- do not modify this ---------------- #

# Use dedicated names for this figure. The slider callback looks them up when it
# runs, so reusing generic `fig`/`axs` would let other cells redirect it (or be redirected by it).
fig_check, axs_check = plt.subplots(2, 3, figsize=(10, 8))

# Check the data and the mask independently. With if/elif, NaNs in the data hid NaNs in the mask.
# (The old "\T" in these strings was an invalid escape that printed a stray backslash.)
if np.isnan(data_array).any():
    print("\x1b[31mThe data array has Nan values, this needs to be accounted for in the next template --> Mumott!!! Please run the step that mentions NaN values to avoid crashing reconstructions\x1b[0m")
if np.isnan(masks).any():
    print("\x1b[31mThe mask has Nan values, this needs to be accounted for in the next template --> Mumott!!! Please run the step that mentions NaN values to avoid crashing reconstructions\x1b[0m")


@interact(proj_nr=(0, data_array.shape[2] - 1))
def plot_projection_check(proj_nr):
    """Show summed data, masked data, mask and transmission (and optional histograms) of one projection."""
    for axis in axs_check.flat:
        axis.clear()
    summed = np.nansum(data_array[:, :, proj_nr, :], axis=-1)
    transmission = ds.stack[proj_nr].transmission_factor

    axs_check[0, 0].imshow(summed)
    axs_check[0, 0].set_title(f"Sum over all azimuthals, proj {proj_nr:03d}")
    axs_check[0, 1].imshow(summed * masks[:, :, proj_nr])
    axs_check[0, 1].set_title('Sum times mask')
    axs_check[0, 2].imshow(masks[:, :, proj_nr], vmin = 0, vmax = 1)
    axs_check[0, 2].set_title('Mask')
    axs_check[1, 0].imshow(transmission, cmap = 'Greys_r')
    axs_check[1, 0].set_title('normalized transmission')
    if show_histograms:
        axs_check[1, 1].hist(transmission.flatten(), 500)
        axs_check[1, 1].set_title('histogram transmission')
        axs_check[1, 2].hist(summed.flatten(), 500)
        axs_check[1, 2].set_yscale('log')
        axs_check[1, 2].set_title('histogram scattering')

Step 4.4 Output data for Mumott

[16]:
 # ----------- do not modify this ---------------- #
image_shape = ds.stack[0].metadata['padded_shape']
volume_shape = tomogram.shape

fname = f"{args['work_directory']}/{args['sample_name']}/dataset_q_{q_range[0]:.3f}_{q_range[1]:.3f}.h5"
with h5py.File(fname, 'w') as file:

    # Global parameters; detector angles are converted from degrees to radians
    file.create_dataset('detector_angles', data = np.array(angles)*np.pi/180)
    file.create_dataset('volume_shape', data = volume_shape)

    # One sub-group per projection, named by its index in ds.stack
    grp = file.create_group('projections')

    # Iterate tqdm over the stack itself so the bar knows the total number of projections
    for ii, projection in enumerate(tqdm(ds.stack)):
        subgrp = grp.create_group(str(ii))

        #Choose here between using reconstructed transmission or raw data
        diode = projection.transmission_factor

        # Binned data (loaded with normalized=True), i.e. already normalized by the transmission
        data = data_array[:, :, ii, :]

        # Weights come from the mask of this projection. They are forced to exactly 0
        # wherever the mask is False/0. astype(bool) handles bool, int and float masks,
        # whereas `~mask` was only a logical NOT for bool masks (on ints it produced
        # -1/-2 and turned the assignment into fancy indexing).
        projection_mask = masks[:, :, ii]
        weights = projection_mask.astype(float)
        weights[~projection_mask.astype(bool)] = 0

        offset_j = shifts['offsets_j'][ii]
        offset_k = shifts['offsets_k'][ii]

        # Rotation and tilt angles from the metadata, stored below in radians
        rotations = projection.metadata[args['inner_rotation_key']]
        tilts = projection.metadata[args['tilt_key']]

        # Assign data
        subgrp.create_dataset('data', data = np.ascontiguousarray(data))
        subgrp.create_dataset('diode', data = diode)
        subgrp.create_dataset('weights', data = weights)
        subgrp.create_dataset('offset_j', data = np.array([offset_j]))
        subgrp.create_dataset('offset_k', data = np.array([offset_k]))
        subgrp.create_dataset('rotations', data = np.array([rotations*np.pi/180]))
        subgrp.create_dataset('tilts', data = np.array([tilts*np.pi/180]))

print(f"Data exported to file {fname}")
419it [00:01, 403.79it/s]
Data exported to file /data/visitors/formax/20220566/2023031408/process/work_dir_Christian/big_brain/dataset_q_0.016_0.020.h5

[ ]: