Source code for autosort_neuron.model

import torch.nn.functional as nnf
import torch
import torch.nn as nn
import numpy as np


def load_day_length(path):
    """Load per-day lengths from a ``.npy`` file and return cumulative offsets.

    The returned list starts with ``0`` followed by the running total of the
    entries of the loaded array along its first axis, i.e. for lengths
    ``[a, b, c]`` it returns ``[0, a, a + b, a + b + c]``.

    Args:
        path: Path passed to ``np.load``.

    Returns:
        A Python list of length ``len(lengths) + 1``. The first element is the
        int ``0``; the rest are the element type produced by numpy indexing.
    """
    lengths = np.load(path)
    offsets = [0]
    for idx in range(lengths.shape[0]):
        step = lengths[idx]
        # The first entry is taken as-is; later entries extend the running total
        # (offsets[-1] is the previous cumulative value once idx >= 1).
        offsets.append(step if idx == 0 else offsets[-1] + step)
    return offsets
class clssimp(nn.Module):
    """MLP classifier with an adaptive-average-pooling front end.

    The input's feature length is first resampled to ``ch`` with
    ``AdaptiveAvgPool1d``. It then passes through three
    Linear -> BatchNorm1d -> ReLU stages (``ch -> 1000 -> 512 -> 100``), and a
    final linear layer produces ``num_classes`` logits.

    Args:
        ch: Feature length after pooling, which is also the input width of the
            first Linear layer.
        num_classes: Number of output logits.
    """

    def __init__(self, ch=2880, num_classes=20):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool1d(output_size=ch)
        self.way1 = nn.Sequential(
            nn.Linear(ch, 1000, bias=True),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True),
        )
        self.way2 = nn.Sequential(
            nn.Linear(1000, 512, bias=True),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
        )
        self.way3 = nn.Sequential(
            nn.Linear(512, 100, bias=True),
            nn.BatchNorm1d(100),
            nn.ReLU(inplace=True),
        )
        self.cls = nn.Linear(100, num_classes, bias=True)

    def forward(self, x):
        """Return ``num_classes`` logits: ``cls`` applied to the penultimate features."""
        return self.cls(self.intermediate_forward(x))

    def intermediate_forward(self, x):
        """Return the 100-dimensional penultimate (post-``way3``) features.

        Args:
            x: Expected to be a 2-D tensor ``(batch, features)``.
                NOTE(review): inferred from the pooling/reshape below; confirm
                against callers.
        """
        # Add a leading dim so the pool runs along the feature axis:
        # (batch, F) -> (1, batch, F) -> (1, batch, ch).
        x = self.pool(x[None, :])
        # Drop that leading dim again: (1, batch, ch) -> (batch, ch).
        x = x.reshape(x.size(1), -1)
        x = self.way1(x)
        x = self.way2(x)
        return self.way3(x)
class AE(nn.Module):
    """Small autoencoder with a single-channel 1-D convolutional front end.

    The input passes through ``Conv1d(1, 1, 5, padding=2)`` -> ``BatchNorm1d`` ->
    ReLU. It is then encoded through two 128-unit Linear+ReLU layers and
    decoded through a 128-unit Linear+ReLU layer and a linear output layer
    back to ``input_shape``.

    Keyword Args:
        input_shape: Length of each input vector (required).
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Kernel 5 with padding 2 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(1, 1, 5, padding=2)
        # NOTE(review): eps=1e-3 / momentum=0.99 are Keras' BatchNormalization
        # defaults. In PyTorch, ``momentum`` is the weight given to the *new*
        # batch statistic, so 0.99 makes the running stats follow almost only
        # the latest batch (the Keras-equivalent would be 0.01). Left unchanged
        # to preserve training behaviour; confirm intent.
        self.bn = torch.nn.BatchNorm1d(1, eps=0.001, momentum=0.99)
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=128
        )
        self.encoder_output_layer = nn.Linear(in_features=128, out_features=128)
        self.decoder_hidden_layer = nn.Linear(in_features=128, out_features=128)
        self.decoder_output_layer = nn.Linear(
            in_features=128, out_features=kwargs["input_shape"]
        )

    def forward(self, features):
        """Encode a batch and reconstruct it.

        Args:
            features: 2-D tensor ``(batch, input_shape)``.

        Returns:
            ``(code, reconstructed)``. ``code`` is the ReLU-activated 128-d
            encoding. ``reconstructed`` is the linear (un-activated) decoder
            output of shape ``(batch, input_shape)``.
        """
        # (batch, length) -> (batch, 1, length) for the single-channel conv.
        x = self.conv1(features.unsqueeze(1))
        # Normalise, activate, then drop the channel dim: -> (batch, length).
        x = torch.relu(self.bn(x)).squeeze(1)
        activation = torch.relu(self.encoder_hidden_layer(x))
        code = torch.relu(self.encoder_output_layer(activation))
        activation = torch.relu(self.decoder_hidden_layer(code))
        reconstructed = self.decoder_output_layer(activation)
        return code, reconstructed
class AutoSort:
    """Multitask trainer for a noise classifier and a label classifier.

    Both heads are ``clssimp`` MLPs fed the same concatenated feature matrix
    ``[batch_features | single_waveform | pred_loc]`` (concatenated on dim 1,
    then cast to float32):

    * ``clsfier_noise``: 2 logits, trained with ``bceloss`` against ``labels``.
    * ``clsfier_label``: ``len(set_shank_id)`` logits, trained with
      ``bceloss_label`` only on rows where ``labels[:, 1] == 1``.

    Args:
        ch_num, samplepoints, loc_dim: Set the classifier input width,
            ``(ch_num + 1) * samplepoints + loc_dim``.
        device: Device the classifiers are moved to.
        set_shank_id: Sequence whose length is the number of label classes.
        save_dir: Prefix prepended verbatim to the checkpoint file names. It
            must end with a path separator to act as a directory.
        pos_weight_noise, pos_weight_label: Optional ``pos_weight`` for the two
            ``BCEWithLogitsLoss`` criteria.
    """

    def __init__(self, ch_num, samplepoints, loc_dim, device, set_shank_id, save_dir,
                 pos_weight_noise=None, pos_weight_label=None):
        in_features = (ch_num + 1) * samplepoints + loc_dim
        self.clsfier_noise = clssimp(in_features, 2).to(device)
        self.clsfier_label = clssimp(in_features, len(set_shank_id)).to(device)
        # A single optimizer updates both heads.
        self.optimizer = torch.optim.Adam([
            {'params': self.clsfier_noise.parameters()},
            {'params': self.clsfier_label.parameters()},
        ], lr=1e-4)
        # Not used by any method of this class; kept as a public attribute.
        self.criterion = nn.MSELoss()
        self.bceloss = nn.BCEWithLogitsLoss(pos_weight=pos_weight_noise)
        self.bceloss_label = nn.BCEWithLogitsLoss(pos_weight=pos_weight_label)
        # Plain string concatenation (not os.path.join): save_dir is a prefix.
        # save_model_path_1 is neither written nor read by this class.
        self.save_model_path_1 = save_dir + 'multitask_single_wave_noise_ae.pth'
        self.save_model_path_2 = save_dir + 'multitask_single_wave_clsfier_noise_clsfier.pth'
        self.save_model_path_3 = save_dir + 'multitask_single_wave_clsfier_label_clsfier.pth'
        self.set_shank_id = set_shank_id

    @staticmethod
    def _build_codes(batch_features, single_waveform, pred_loc):
        """Concatenate the three feature groups along dim 1 and cast to float32."""
        return torch.cat((batch_features, single_waveform, pred_loc), dim=1).float()

    def save_model(self):
        """Save both classifier state dicts to ``save_model_path_2`` / ``_3``."""
        torch.save(self.clsfier_noise.state_dict(), self.save_model_path_2)
        torch.save(self.clsfier_label.state_dict(), self.save_model_path_3)

    def load_model(self):
        """Load both classifier state dicts from ``save_model_path_2`` / ``_3``.

        ``map_location='cpu'`` lets checkpoints written on a GPU load on a
        CPU-only host. ``load_state_dict`` then copies the values onto whatever
        device the heads currently live on.
        """
        self.clsfier_noise.load_state_dict(
            torch.load(self.save_model_path_2, map_location='cpu'))
        self.clsfier_label.load_state_dict(
            torch.load(self.save_model_path_3, map_location='cpu'))

    def to_device(self, device):
        """Move both classifiers and both BCE criteria (and their pos_weight) to ``device``."""
        self.clsfier_noise.to(device)
        self.clsfier_label.to(device)
        self.bceloss.to(device)
        self.bceloss_label.to(device)

    def train(self):
        """Put both classifiers in training mode."""
        self.clsfier_noise.train()
        self.clsfier_label.train()

    def eval(self):
        """Put both classifiers in evaluation mode."""
        self.clsfier_noise.eval()
        self.clsfier_label.eval()

    def iter_model(self, batch_features, classify_labels, labels, single_waveform, pred_loc):
        """Run one optimisation step on a batch.

        Returns:
            ``(train_loss1, train_loss2, train_loss3, test)``:

            * ``train_loss1`` is always ``0``.
            * ``train_loss2`` and ``train_loss3`` are the 1000x-scaled noise and
              label BCE losses as Python numbers. ``train_loss3`` is ``0`` when
              the label head was skipped.
            * ``test`` is the boolean mask ``labels[:, 1] == 1`` that selects
              the rows used for the label loss.
        """
        self.optimizer.zero_grad()
        codes = self._build_codes(batch_features, single_waveform, pred_loc)
        cls_output = self.clsfier_noise(codes)

        n_label_classes = len(self.set_shank_id)
        test = labels[:, 1] == 1
        # The label head only runs when more than one row is selected, presumably
        # because BatchNorm1d in clssimp cannot normalise a single-sample batch in
        # training mode. test.sum() reduces in one tensor op instead of the
        # builtin sum(), which iterated the mask element by element.
        if int(test.sum()) > 1:
            cls_label_output = self.clsfier_label(codes[test, :])
            train_loss3 = 1000 * self.bceloss_label(
                cls_label_output, classify_labels[test, :n_label_classes])
        else:
            train_loss3 = torch.tensor(0)
        train_loss1 = 0
        train_loss2 = 1000 * self.bceloss(cls_output, labels)
        train_loss = train_loss1 + train_loss2 + train_loss3
        train_loss.backward()
        self.optimizer.step()
        return train_loss1, train_loss2.item(), train_loss3.item(), test

    def iter_model_eval(self, batch_features, classify_labels, labels, single_waveform, pred_loc):
        """Evaluate a batch without updating weights.

        Runs under ``torch.no_grad()``. None of the returned values carry a
        gradient, so building an autograd graph here would only waste memory.

        Returns:
            ``(train_loss1, train_loss2, train_loss3, gt, pred, gt_label_class,
            pred_class)``:

            * The three losses are as in ``iter_model``.
            * ``gt`` and ``pred`` are the argmax over dim 1 of ``labels`` and of
              the noise logits.
            * ``gt_label_class`` and ``pred_class`` are the argmax label indices
              for the rows with ``labels[:, 1] == 1``. They are empty long
              tensors when at most one such row exists.
        """
        with torch.no_grad():
            codes = self._build_codes(batch_features, single_waveform, pred_loc)
            cls_output = self.clsfier_noise(codes)
            gt = torch.argmax(labels, dim=1)
            pred = torch.argmax(cls_output, dim=1)

            n_label_classes = len(self.set_shank_id)
            test = labels[:, 1] == 1
            if int(test.sum()) > 1:
                cls_label_output = self.clsfier_label(codes[test, :])
                label_target = classify_labels[test, :n_label_classes]
                pred_class = torch.argmax(cls_label_output, dim=1)
                gt_label_class = torch.argmax(label_target, dim=1)
                train_loss3 = 1000 * self.bceloss_label(cls_label_output, label_target)
            else:
                train_loss3 = torch.tensor(0)
                # Same dtype/device as the argmax results of the other branch.
                gt_label_class = torch.empty(0, dtype=torch.long, device=cls_output.device)
                pred_class = torch.empty(0, dtype=torch.long, device=cls_output.device)
            train_loss1 = 0
            train_loss2 = 1000 * self.bceloss(cls_output, labels)
        return (train_loss1, train_loss2.item(), train_loss3.item(),
                gt, pred, gt_label_class, pred_class)

    def iter_model_eval_umap(self, batch_features, classify_labels, labels, single_waveform, pred_loc):
        """Return predictions plus penultimate features of both heads (for embedding plots).

        Unlike ``iter_model_eval``, the label head is applied to every row (no
        ``labels[:, 1]`` mask). Each head's penultimate features are computed
        once and reused for its logits. The original version ran every MLP
        twice, which in training mode also updated BatchNorm running
        statistics twice.

        Returns:
            ``(gt, pred, gt_label_class, pred_class, codestest, codestest_label,
            prob)``:

            * ``codestest`` and ``codestest_label`` are the noise and label
              heads' ``intermediate_forward`` outputs.
            * ``prob`` is the softmax over dim 1 of the label logits.
        """
        codes = self._build_codes(batch_features, single_waveform, pred_loc)

        codestest = self.clsfier_noise.intermediate_forward(codes)
        cls_output = self.clsfier_noise.cls(codestest)
        gt = torch.argmax(labels, dim=1)
        pred = torch.argmax(cls_output, dim=1)

        codestest_label = self.clsfier_label.intermediate_forward(codes)
        cls_label_output = self.clsfier_label.cls(codestest_label)
        pred_class = torch.argmax(cls_label_output, dim=1)
        gt_label_class = torch.argmax(classify_labels[:, :len(self.set_shank_id)], dim=1)
        prob = nnf.softmax(cls_label_output, dim=1)
        return gt, pred, gt_label_class, pred_class, codestest, codestest_label, prob