1688 字
8 分钟

LA Train

2026-03-10
浏览量 加载中...
温馨提示:最后更新于 2026-03-10,距今已过 58 天,部分内容可能已过时。

BraTS2019 dataset

brats.py
import os
import torch
import numpy as np
from glob import glob
from torch.utils.data import Dataset
import h5py
import itertools
from torch.utils.data.sampler import Sampler
class BraTS2019(Dataset):
    """BraTS2019 h5 dataset.

    Reads case names from ``<base_dir>/train.txt`` or ``<base_dir>/val.txt``
    and lazily loads image/label volumes from ``<base_dir>/data/<name>.h5``.

    Args:
        base_dir: dataset root containing train.txt / val.txt and data/.
        split: 'train' or 'test' (test reads val.txt); anything else raises.
        num: optional cap on the number of cases used.
        transform: optional callable applied to each sample dict.
    """

    def __init__(self, base_dir=None, split='train', num=None, transform=None):
        self._base_dir = base_dir
        self.transform = transform
        self.sample_list = []
        train_path = self._base_dir + '/train.txt'
        test_path = self._base_dir + '/val.txt'
        if split == 'train':
            list_path = train_path
        elif split == 'test':
            list_path = test_path
        else:
            # Fail fast: the original left self.image_list undefined for an
            # unknown split and crashed later with an opaque AttributeError.
            raise ValueError("split must be 'train' or 'test', got {!r}".format(split))
        with open(list_path, 'r') as f:
            self.image_list = f.readlines()
        # Keep only the case id (text before the first comma), strip newlines.
        self.image_list = [item.replace('\n', '').split(",")[0]
                           for item in self.image_list]
        if num is not None:
            self.image_list = self.image_list[:num]
        print("total {} samples".format(len(self.image_list)))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_name = self.image_list[idx]
        # Context manager guarantees the h5 handle is closed (the original
        # leaked one open file descriptor per __getitem__ call).
        with h5py.File(self._base_dir + "/data/{}.h5".format(image_name), 'r') as h5f:
            image = h5f['image'][:]
            label = h5f['label'][:]
        sample = {'image': image, 'label': label.astype(np.uint8)}
        if self.transform:
            sample = self.transform(sample)
        return sample
class CenterCrop(object):
    """Deterministically crop the central region of a 3-D sample.

    Volumes smaller than the target size are first zero-padded symmetrically.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        ow, oh, od = self.output_size
        # Pad symmetrically when any dimension is not larger than the crop.
        if label.shape[0] <= ow or label.shape[1] <= oh or label.shape[2] <= od:
            pw = max((ow - label.shape[0]) // 2 + 3, 0)
            ph = max((oh - label.shape[1]) // 2 + 3, 0)
            pd = max((od - label.shape[2]) // 2 + 3, 0)
            pad_spec = [(pw, pw), (ph, ph), (pd, pd)]
            image = np.pad(image, pad_spec, mode='constant', constant_values=0)
            label = np.pad(label, pad_spec, mode='constant', constant_values=0)
        w, h, d = image.shape
        w1 = int(round((w - ow) / 2.))
        h1 = int(round((h - oh) / 2.))
        d1 = int(round((d - od) / 2.))
        window = (slice(w1, w1 + ow), slice(h1, h1 + oh), slice(d1, d1 + od))
        return {'image': image[window], 'label': label[window]}
class RandomCrop(object):
    """Randomly crop a fixed-size sub-volume from a sample.

    Args:
        output_size (tuple): desired (W, H, D) of the crop.
        with_sdf (bool): also crop the 'sdf' entry of the sample.
    """

    def __init__(self, output_size, with_sdf=False):
        self.output_size = output_size
        self.with_sdf = with_sdf

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if self.with_sdf:
            sdf = sample['sdf']
        ow, oh, od = self.output_size
        # Pad symmetrically when the volume is not larger than the crop.
        if label.shape[0] <= ow or label.shape[1] <= oh or label.shape[2] <= od:
            pw = max((ow - label.shape[0]) // 2 + 3, 0)
            ph = max((oh - label.shape[1]) // 2 + 3, 0)
            pd = max((od - label.shape[2]) // 2 + 3, 0)
            pad_spec = [(pw, pw), (ph, ph), (pd, pd)]
            image = np.pad(image, pad_spec, mode='constant', constant_values=0)
            label = np.pad(label, pad_spec, mode='constant', constant_values=0)
            if self.with_sdf:
                sdf = np.pad(sdf, pad_spec, mode='constant', constant_values=0)
        w, h, d = image.shape
        # Draw the crop origin uniformly (same RNG order as before: w, h, d).
        w1 = np.random.randint(0, w - ow)
        h1 = np.random.randint(0, h - oh)
        d1 = np.random.randint(0, d - od)
        window = (slice(w1, w1 + ow), slice(h1, h1 + oh), slice(d1, d1 + od))
        label = label[window]
        image = image[window]
        if self.with_sdf:
            return {'image': image, 'label': label, 'sdf': sdf[window]}
        return {'image': image, 'label': label}
class RandomRotFlip(object):
    """Randomly rotate a sample by a multiple of 90° in the (0, 1) plane,
    then randomly flip it along axis 0 or 1."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # Same RNG consumption order as before: rotation count, then flip axis.
        num_rots = np.random.randint(0, 4)
        rotated_image = np.rot90(image, num_rots)
        rotated_label = np.rot90(label, num_rots)
        flip_axis = np.random.randint(0, 2)
        flipped_image = np.flip(rotated_image, axis=flip_axis).copy()
        flipped_label = np.flip(rotated_label, axis=flip_axis).copy()
        return {'image': flipped_image, 'label': flipped_label}
class RandomNoise(object):
    """Add clipped Gaussian noise to the image; the label is untouched."""

    def __init__(self, mu=0, sigma=0.1):
        self.mu = mu
        self.sigma = sigma

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        shape = (image.shape[0], image.shape[1], image.shape[2])
        # Scale unit-normal noise by sigma, clip to +/- 2*sigma, shift by mu.
        raw = self.sigma * np.random.randn(*shape)
        noise = np.clip(raw, -2 * self.sigma, 2 * self.sigma) + self.mu
        return {'image': image + noise, 'label': label}
class CreateOnehotLabel(object):
    """Attach a one-hot float32 encoding of the label to the sample."""

    def __init__(self, num_classes):
        self.num_classes = num_classes

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # Broadcast comparison: channel c is the indicator map for class c,
        # giving shape (num_classes, W, H, D) just like the per-class loop.
        class_ids = np.arange(self.num_classes).reshape(-1, 1, 1, 1)
        onehot_label = (label[np.newaxis, ...] == class_ids).astype(np.float32)
        return {'image': image, 'label': label, 'onehot_label': onehot_label}
class ToTensor(object):
    """Convert ndarrays in sample to Tensors; the image gains a channel dim."""

    def __call__(self, sample):
        image = sample['image']
        # Prepend a channel axis and force float32 before tensor conversion.
        image = image.reshape(1, *image.shape[:3]).astype(np.float32)
        tensors = {'image': torch.from_numpy(image),
                   'label': torch.from_numpy(sample['label']).long()}
        if 'onehot_label' in sample:
            tensors['onehot_label'] = torch.from_numpy(sample['onehot_label']).long()
        return tensors
class TwoStreamBatchSampler(Sampler):
    """Iterate two sets of indices simultaneously.

    An 'epoch' is one pass through the primary indices; the secondary
    indices are recycled (reshuffled each pass) as many times as needed
    to fill every batch.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size
        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        # One shuffled pass over the primary stream; endless secondary stream.
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        batch_pairs = zip(grouper(primary_iter, self.primary_batch_size),
                          grouper(secondary_iter, self.secondary_batch_size))
        return (p_batch + s_batch for (p_batch, s_batch) in batch_pairs)

    def __len__(self):
        # Number of full primary batches; a trailing partial batch is dropped.
        return len(self.primary_indices) // self.primary_batch_size
def iterate_once(iterable):
    """Return the elements of *iterable* as one randomly permuted array."""
    permuted = np.random.permutation(iterable)
    return permuted
def iterate_eternally(indices):
    """Yield elements of *indices* forever, reshuffling on every pass."""
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    # Flatten the stream of permutations into a stream of single indices.
    return itertools.chain.from_iterable(infinite_shuffles())
def grouper(iterable, n):
    """Collect data into fixed-length tuples of size *n*.

    The final partial group, if any, is silently dropped:
    grouper('ABCDEFG', 3) --> ('A','B','C') ('D','E','F')
    """
    it = iter(iterable)
    # zip over n references to the SAME iterator consumes n items per tuple.
    return zip(*(it,) * n)

计算模型参数量和FLOPs

test_ACDC.py
import os
import argparse
import torch
import numpy as np
import h5py
import nibabel as nib
from tqdm import tqdm
from medpy import metric
from networks.net_factory import net_factory
from utils import val_2d
from thop import profile, clever_format
# ── CLI arguments ─────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str, default='../data/ACDC', help='Root path of dataset')
parser.add_argument('--exp', type=str, default='CPAM', help='Experiment name')
parser.add_argument('--model', type=str, default='unet', help='Model name')
parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
parser.add_argument('--detail', type=int, default=1, help='Print metrics for every sample?')
parser.add_argument('--nms', type=int, default=1, help='Apply NMS (LargestCC) post-processing?')
parser.add_argument('--labelnum', type=int, default=3, help='Number of labeled patients used in training')
parser.add_argument('--stage_name', type=str, default='self_train', help='Stage: self_train or pre_train')
FLAGS = parser.parse_args()
# Restrict visible GPUs before any CUDA context is created.
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
# ── Paths (mirror the training script conventions) ──────────────────────────
# NOTE(review): snapshot_path keys on stage_name while test_save_path keys on
# model — presumably intentional to group predictions per network; confirm.
snapshot_path = "./model/CPAM/ACDC_{}_{}_labeled/{}".format(
    FLAGS.exp, FLAGS.labelnum, FLAGS.stage_name)
test_save_path = "./model/CPAM/ACDC_{}_{}_labeled/{}_predictions/".format(
    FLAGS.exp, FLAGS.labelnum, FLAGS.model)
num_classes = 4  # background(0) + RV(1) + Myo(2) + LV(3)
if not os.path.exists(test_save_path):
    os.makedirs(test_save_path)
print("Predictions will be saved to:", test_save_path)
# ── Load test list ────────────────────────────────────────────────────────────
# test.list entries look like: "patient001_frame01" (one per line)
with open(os.path.join(FLAGS.root_path, 'test.list'), 'r') as f:
    image_list = [line.strip() for line in f.readlines() if line.strip()]
# Full path to each patient .h5 file, e.g. ../data/ACDC/data/patient001_frame01.h5
image_list = [os.path.join(FLAGS.root_path, "data", item + ".h5")
              for item in image_list]
print("Total test cases:", len(image_list))
# ── NIfTI saving ─────────────────────────────────────────────────────────────
def save_nii(array, save_path, dtype=np.float32):
    """Save a numpy array (D, H, W) as a .nii.gz file with identity affine."""
    volume = array.astype(dtype)
    nib.save(nib.Nifti1Image(volume, affine=np.eye(4)), save_path)
# ── Metric helpers ────────────────────────────────────────────────────────────
def calculate_metric_percase(pred, gt):
    """Return (dice, jc, hd95, asd) for a single binary foreground mask.

    Conventions: both masks empty -> perfect score; exactly one empty ->
    worst score with the sentinel surface distance of 100.
    """
    pred_has_fg = pred.sum() > 0
    gt_has_fg = gt.sum() > 0
    if pred_has_fg and gt_has_fg:
        return (metric.binary.dc(pred, gt),
                metric.binary.jc(pred, gt),
                metric.binary.hd95(pred, gt),
                metric.binary.asd(pred, gt))
    if not pred_has_fg and not gt_has_fg:
        return 1.0, 1.0, 0.0, 0.0
    return 0.0, 0.0, 100.0, 100.0
# ── Model complexity ──────────────────────────────────────────────────────────
def compute_model_complexity(model, input_size=(1, 1, 256, 256)):
    """
    Print parameter count and FLOPs for a single 2-D slice (B=1, C=1, H, W).
    Uses thop if available, otherwise falls back to manual param counting.
    """
    separator = "=" * 50
    print("\n" + separator)
    print("Model Complexity")
    print("-" * 50)
    # Parameter counts: everything vs. gradient-carrying tensors only.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Total params : {total_params:,} ({total_params / 1e6:.2f} M)")
    print(f" Trainable params: {trainable_params:,} ({trainable_params / 1e6:.2f} M)")
    # FLOPs via thop on a random dummy slice (eval mode, no gradients,
    # tensor on GPU — assumes the model already lives on CUDA).
    sample_input = torch.randn(input_size).cuda()
    model.eval()
    with torch.no_grad():
        try:
            flops, _ = profile(model, inputs=(sample_input,), verbose=False)
            flops_str, _ = clever_format([flops, total_params], "%.3f")
            print(f" FLOPs (per slice): {flops_str} ({flops / 1e9:.3f} GFLOPs)")
        except Exception as e:
            # Best-effort: report the failure instead of aborting the run.
            print(f" FLOPs calculation failed: {e}")
    print(separator + "\n")
# ── Main evaluation ───────────────────────────────────────────────────────────
def test_calculate_metric():
    """Run slice-wise inference on every test case and report mean metrics.

    For each case: reads the h5 volume, predicts slice by slice at 256x256,
    saves image/gt/pred as .nii.gz, and scores the 3 foreground classes.

    Returns:
        np.ndarray of shape (3, 4): per-class mean (Dice, Jaccard, HD95, ASD)
        for RV, Myo and LV.
    """
    # Hoisted out of the per-case loop (was re-executed every iteration).
    from scipy.ndimage import zoom as scipy_zoom

    # Load model
    model = net_factory(net_type=FLAGS.model, in_chns=1,
                        class_num=num_classes, mode="train")
    save_model_path = os.path.join(snapshot_path,
                                   '{}_best_model.pth'.format(FLAGS.model))
    model.load_state_dict(torch.load(save_model_path))
    print("Loaded weights from:", save_model_path)
    model.eval()

    # ── Parameters & FLOPs ───────────────────────────────────────────────────
    compute_model_complexity(model, input_size=(1, 1, 256, 256))

    patch_size = [256, 256]  # network input resolution per slice

    # Accumulate metrics across all test cases (3 foreground classes)
    # Shape: [num_cases, num_fg_classes, 4_metrics]
    all_metrics = []
    for case_path in tqdm(image_list, desc="Testing"):
        # ── Read h5 ──────────────────────────────────────────────────────────
        with h5py.File(case_path, 'r') as f:
            image = f['image'][:]  # assumed (D, H, W) — TODO confirm axis order
            label = f['label'][:]  # same shape, integer class labels

        # ── Save dir ─────────────────────────────────────────────────────────
        case_name = os.path.basename(case_path).replace('.h5', '')
        case_save_dir = os.path.join(test_save_path, case_name)
        os.makedirs(case_save_dir, exist_ok=True)

        # ── Inference: build full prediction volume slice-by-slice ────────────
        prediction = np.zeros_like(label)
        for ind in range(image.shape[0]):
            slc = image[ind, :, :]
            x, y = slc.shape
            slc_resized = scipy_zoom(slc, (patch_size[0] / x, patch_size[1] / y), order=0)
            inp = torch.from_numpy(slc_resized).unsqueeze(0).unsqueeze(0).float().cuda()
            with torch.no_grad():
                output = model(inp)
                # Some nets return (main_output, aux, ...): keep the main head.
                # Checking the container type is robust even for 1-element
                # tuples, unlike `len(output) > 1` which also matches the
                # batch dimension of a plain tensor.
                if isinstance(output, (tuple, list)):
                    output = output[0]
                out = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze(0)
                out = out.cpu().numpy()
                pred_slc = scipy_zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
                prediction[ind] = pred_slc

        # ── Save .nii.gz ─────────────────────────────────────────────────────
        save_nii(image, os.path.join(case_save_dir, "image.nii.gz"), dtype=np.float32)
        save_nii(label, os.path.join(case_save_dir, "gt.nii.gz"), dtype=np.uint8)
        save_nii(prediction, os.path.join(case_save_dir, "pred.nii.gz"), dtype=np.uint8)

        # ── Compute 4 metrics per foreground class ────────────────────────────
        # Reuse the shared helper instead of duplicating its branch logic here.
        metric_i = [calculate_metric_percase(prediction == c, label == c)
                    for c in range(1, num_classes)]
        if FLAGS.detail:
            for c_idx, (dice, jc, hd95, asd) in enumerate(metric_i, start=1):
                print(f" {case_name} class {c_idx}: "
                      f"Dice={dice:.4f} Jaccard={jc:.4f} "
                      f"HD95={hd95:.2f} ASD={asd:.2f}")
        all_metrics.append(metric_i)  # list of 3 tuples

    # ── Aggregate ─────────────────────────────────────────────────────────────
    all_metrics = np.array(all_metrics)      # (N, 3, 4)
    mean_metrics = all_metrics.mean(axis=0)  # (3, 4)
    class_names = ['RV (class 1)', 'Myo (class 2)', 'LV (class 3)']
    print("\n" + "=" * 70)
    print(f"{'Class':<20} {'Dice':>8} {'Jaccard':>10} {'HD95':>8} {'ASD':>8}")
    print("-" * 70)
    for name, m in zip(class_names, mean_metrics):
        print(f"{name:<20} {m[0]:>8.4f} {m[1]:>10.4f} {m[2]:>8.2f} {m[3]:>8.2f}")
    print("=" * 70)
    mean_dice = mean_metrics[:, 0].mean()
    mean_jaccard = mean_metrics[:, 1].mean()
    mean_hd95 = mean_metrics[:, 2].mean()
    mean_asd = mean_metrics[:, 3].mean()
    print(f"{'Mean':<20} {mean_dice:>8.4f} {mean_jaccard:>10.4f} "
          f"{mean_hd95:>8.2f} {mean_asd:>8.2f}")
    return mean_metrics
# Script entry point: run the full evaluation, then echo the (3, 4) matrix of
# per-class mean metrics returned by test_calculate_metric().
if __name__ == '__main__':
    metric_result = test_calculate_metric()
    print("\nFinal mean metrics (Dice / Jaccard / HD95 / ASD):")
    print(metric_result)
# ── Example usage ─────────────────────────────────────────────────────────────
# python ACDC_CPAM_test.py --model unet --labelnum 3 --stage_name self_train --gpu 0
# python ACDC_CPAM_test.py --model unet --labelnum 7 --stage_name self_train --gpu 0 --detail 0
LA Train
https://louaq.io/posts/code/train/
作者
Louaq
发布于
2026-03-10
许可协议
CC BY-NC-SA 4.0

评论区

目录