Source code for autosort_neuron.waveform_loader
import numpy as np
import torch
from torch.utils import data
import pickle
[docs]
def location_cal(sensor_positions, batch_features):
    """Estimate one location per spike as an amplitude-weighted channel centroid.

    Each channel is weighted by its peak-to-peak amplitude raised to the fourth
    power, so the estimate is pulled strongly toward the largest channels.

    Parameters
    ----------
    sensor_positions : ndarray, shape (n_channels, n_cols)
        Per-channel coordinates; every column is averaged independently.
    batch_features : ndarray, shape (n_spikes, n_channels, n_last)
        Waveform snippets; the last axis is reduced with max/min (presumably
        time samples). Integer dtypes are supported.

    Returns
    -------
    ndarray, shape (n_spikes, n_cols)
        Weighted centroid per spike (float64). A spike whose amplitude is zero
        on every channel yields NaN (0/0 division).
    """
    # Cast to float64 *before* subtracting and raising to the 4th power: with
    # integer waveforms (e.g. int16) both steps would silently overflow/wrap.
    peak = batch_features.max(-1).astype(np.float64)
    trough = batch_features.min(-1).astype(np.float64)
    weights = np.square(np.square(peak - trough))  # amplitude ** 4, (n_spikes, n_channels)
    total_weight = np.sum(weights, axis=1)

    # One weighted average per coordinate column -> (n_cols, n_spikes), then transpose.
    location_day = [
        np.divide(np.dot(sensor_positions[:, dim], weights.T), total_weight)
        for dim in range(sensor_positions.shape[1])
    ]
    return np.array(location_day).T
[docs]
def location_cal_group(sensor_positions, batch_features, group_id):
    """Localize spikes using only the channels of the group each spike belongs to.

    Parameters
    ----------
    sensor_positions : ndarray, shape (n_channels, n_cols)
        Per-channel coordinates whose LAST column is the channel's group label.
        All columns (label included) are forwarded to ``location_cal``.
    batch_features : ndarray, shape (n_spikes, n_channels, n_last)
        Waveform snippets for every spike.
    group_id : array-like, shape (n_spikes,)
        Per-spike channel index; a spike is assigned to the group containing
        that channel.

    Returns
    -------
    ndarray of float64, shape (n_spikes, n_cols)
        ``location_cal`` output computed on the spike's group channels only.
        Spikes whose ``group_id`` matches no channel index stay all-zero.
    """
    group_batch = sensor_positions[:, -1]
    spike_channels = np.ravel(group_id)
    # Width follows the number of position columns instead of a hard-coded 3,
    # so layouts such as (x, y, z, group) work as well.
    location_day = np.zeros((batch_features.shape[0], sensor_positions.shape[1]))
    for group in np.unique(group_batch):
        care_loc = np.where(group_batch == group)[0]
        # np.isin on the flattened ids is the replacement for deprecated np.in1d.
        look_spike_loc = np.nonzero(np.isin(spike_channels, care_loc))[0]
        if look_spike_loc.size == 0:
            continue
        # np.ix_ selects the (group spikes x group channels) sub-block in one copy.
        group_features = batch_features[np.ix_(look_spike_loc, care_loc)]
        location_day[look_spike_loc, :] = location_cal(
            sensor_positions[care_loc, :], group_features
        )
    return location_day
[docs]
class waveformLoader(data.Dataset):
    """Dataset of spike waveforms, unit labels and estimated spike locations.

    Three pickles are read, each path built as ``root + <file name>``:

    * ``X_waveform.pkl`` -- 3-D waveform array; axis 0 indexes spikes and
      axis 1 channels (the last axis presumably holds time samples).
    * ``Y_spike_id.pkl`` -- per-spike unit id. Optional: when the file is
      missing every spike is labelled -1.
    * ``Y_spike_id_noise.pkl`` -- per-spike channel index, used to pick the
      single-channel waveform and the channel group for localization.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point
    ``root`` at trusted files.
    """

    def __init__(self, root, shank_channel,
                 sensor_positions, Keep_id=None):
        """Load the pickles and precompute labels, class weights and locations.

        Parameters
        ----------
        root : str
            Prefix prepended verbatim to each file name (include the trailing
            path separator when ``root`` is a directory).
        shank_channel :
            Not used by this constructor.
        sensor_positions : ndarray
            Per-channel positions with the group label in the last column;
            forwarded to ``location_cal_group``.
        Keep_id : sequence, optional
            Unit ids treated as real classes. Defaults to every id present in
            the labels except -1. Spikes with any other id count as noise.
        """
        with open(root + "X_waveform.pkl", "rb") as openfile:
            datafile = pickle.load(openfile)
        try:
            with open(root + "Y_spike_id.pkl", "rb") as openfile:
                GT = pickle.load(openfile)
        except FileNotFoundError:
            # No ground truth available: mark every spike as unlabelled (-1).
            GT = np.zeros(datafile.shape[0]) - 1
        with open(root + "Y_spike_id_noise.pkl", "rb") as openfile:
            channel_id = np.array(pickle.load(openfile))

        GT = np.array(GT)
        if Keep_id is None:
            Keep_id = np.unique(GT)
            Keep_id = Keep_id[Keep_id != -1]
        # Always hold a plain list: with an ndarray, ``Keep_id + [-1]`` below
        # would be element-wise arithmetic instead of appending -1.
        Keep_id = list(Keep_id)
        self.keep_id = Keep_id

        # Column 0 flags noise (id not kept), column 1 flags spikes with a kept id.
        mask = ~np.isin(GT, Keep_id)
        GT_binary = np.zeros((GT.shape[0], 2))
        GT_binary[mask, 0] = 1
        GT_binary[~mask, 1] = 1
        self.GT_unique = Keep_id + [-1]
        self.GT_binary = GT_binary
        # Each spike's waveform on its own channel (channel taken from channel_id).
        self.Img_single = datafile[np.arange(datafile.shape[0]), channel_id.astype('int'), :]
        self.GT_LIST = GT

        # One-hot (n_spikes, n_kept) label matrix; noise rows stay all-zero.
        GT_array = np.zeros((len(GT), len(Keep_id)))
        for idx, unique_id in enumerate(Keep_id):
            GT_array[GT == unique_id, idx] = 1
        self.GT = GT_array
        self.Img = datafile

        # negatives / positives for each of the two binary columns.
        self.pos_weight_noise = torch.tensor([
            -np.sum(GT_binary[:, 0] - 1) / np.sum(GT_binary[:, 0]),
            -np.sum(GT_binary[:, 1] - 1) / np.sum(GT_binary[:, 1]),
        ])
        # Per class: (n_spikes - positives - unlabelled) / positives, i.e.
        # noise rows are excluded from the negatives. The unlabelled count does
        # not depend on the class, so it is computed once here.
        n_unlabelled = np.sum(np.sum(GT_array, axis=1) == 0)
        self.pos_weight_label = torch.tensor([
            -(np.sum(GT_array[:, i] - 1) + n_unlabelled) / np.sum(GT_array[:, i])
            for i in range(GT_array.shape[1])
        ])

        pred_location = location_cal_group(sensor_positions, datafile, channel_id)
        self.pred_location = pred_location
        print('pred_location', pred_location.shape)
        self.n_classes = len(set(self.GT_unique))

    def __len__(self):
        """Return the number of spikes (rows of the one-hot label matrix)."""
        return len(self.GT)

    def __getitem__(self, index):
        """Return (waveform, one-hot label, noise flags, single-channel waveform, location)."""
        return (
            self.Img[index, ...],
            self.GT[index, ...],
            self.GT_binary[index, ...],
            self.Img_single[index, ...],
            self.pred_location[index, ...],
        )