Harshith Reddy
commited on
Commit
Β·
af587ff
1
Parent(s):
7fa06ba
Fixed preprocessing
Browse files- processing.py +49 -4
processing.py
CHANGED
|
@@ -58,6 +58,16 @@ def preprocess_nifti(file_path, device=None):
|
|
| 58 |
raw_data = nifti_img.get_fdata(dtype=np.float32)
|
| 59 |
print(f" β Raw data stats: min={raw_data.min():.4f}, max={raw_data.max():.4f}, mean={raw_data.mean():.4f}, std={raw_data.std():.4f}")
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
use_enhanced_preprocessing = os.environ.get("USE_ENHANCED_PREPROCESSING", "false").lower() == "true"
|
| 62 |
|
| 63 |
if use_enhanced_preprocessing:
|
|
@@ -80,15 +90,28 @@ def preprocess_nifti(file_path, device=None):
|
|
| 80 |
transform = transforms.Compose([
|
| 81 |
transforms.LoadImaged(keys=["image"]),
|
| 82 |
transforms.EnsureChannelFirstD(keys=["image"], channel_dim="no_channel"),
|
| 83 |
-
transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
|
| 84 |
transforms.ToTensord(keys=["image"])
|
| 85 |
])
|
| 86 |
print(" β Using training-matched preprocessing (for optimal accuracy)")
|
| 87 |
|
| 88 |
data = {"image": file_path}
|
| 89 |
print("Applying transforms...")
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
if not isinstance(image_data, torch.Tensor):
|
| 94 |
image_data = torch.from_numpy(np.array(image_data))
|
|
@@ -98,8 +121,30 @@ def preprocess_nifti(file_path, device=None):
|
|
| 98 |
|
| 99 |
img_np = image_data.numpy() if not hasattr(image_data, 'device') or image_data.device.type == 'cpu' else image_data.cpu().numpy()
|
| 100 |
vmin, vmax = float(img_np.min()), float(img_np.max())
|
|
|
|
| 101 |
if vmax - vmin < 1e-6:
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
print(f" β After transforms: min={vmin:.4f}, max={vmax:.4f}, mean={img_np.mean():.4f}, std={img_np.std():.4f}")
|
| 104 |
|
| 105 |
if device is not None and device.type == 'cuda':
|
|
|
|
| 58 |
raw_data = nifti_img.get_fdata(dtype=np.float32)
|
| 59 |
print(f" β Raw data stats: min={raw_data.min():.4f}, max={raw_data.max():.4f}, mean={raw_data.mean():.4f}, std={raw_data.std():.4f}")
|
| 60 |
|
| 61 |
+
if raw_data.max() - raw_data.min() < 1e-6:
|
| 62 |
+
raise ValueError(f"Input NIfTI file contains constant values (min=max={raw_data.min():.4f}). Cannot process.")
|
| 63 |
+
|
| 64 |
+
nonzero_mask = raw_data > 1e-6
|
| 65 |
+
nonzero_count = nonzero_mask.sum()
|
| 66 |
+
total_count = raw_data.size
|
| 67 |
+
nonzero_ratio = nonzero_count / total_count if total_count > 0 else 0.0
|
| 68 |
+
|
| 69 |
+
print(f" β Non-zero voxels: {nonzero_count:,} / {total_count:,} ({100*nonzero_ratio:.2f}%)")
|
| 70 |
+
|
| 71 |
use_enhanced_preprocessing = os.environ.get("USE_ENHANCED_PREPROCESSING", "false").lower() == "true"
|
| 72 |
|
| 73 |
if use_enhanced_preprocessing:
|
|
|
|
| 90 |
transform = transforms.Compose([
|
| 91 |
transforms.LoadImaged(keys=["image"]),
|
| 92 |
transforms.EnsureChannelFirstD(keys=["image"], channel_dim="no_channel"),
|
| 93 |
+
transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True, eps=1e-8),
|
| 94 |
transforms.ToTensord(keys=["image"])
|
| 95 |
])
|
| 96 |
print(" β Using training-matched preprocessing (for optimal accuracy)")
|
| 97 |
|
| 98 |
data = {"image": file_path}
|
| 99 |
print("Applying transforms...")
|
| 100 |
+
|
| 101 |
+
try:
|
| 102 |
+
augmented = transform(data)
|
| 103 |
+
image_data = augmented["image"]
|
| 104 |
+
except Exception as e:
|
| 105 |
+
print(f" β Transform failed: {e}. Trying fallback preprocessing...")
|
| 106 |
+
try:
|
| 107 |
+
raw_data_norm = (raw_data - raw_data.min()) / (raw_data.max() - raw_data.min() + 1e-8)
|
| 108 |
+
if raw_data_norm.std() < 1e-6:
|
| 109 |
+
raise ValueError("Normalized data is still constant")
|
| 110 |
+
image_data = torch.from_numpy(raw_data_norm).float()
|
| 111 |
+
image_data = image_data.unsqueeze(0)
|
| 112 |
+
print(" β Used fallback normalization (min-max scaling)")
|
| 113 |
+
except Exception as e2:
|
| 114 |
+
raise ValueError(f"Both standard and fallback preprocessing failed: {e2}")
|
| 115 |
|
| 116 |
if not isinstance(image_data, torch.Tensor):
|
| 117 |
image_data = torch.from_numpy(np.array(image_data))
|
|
|
|
| 121 |
|
| 122 |
img_np = image_data.numpy() if not hasattr(image_data, 'device') or image_data.device.type == 'cpu' else image_data.cpu().numpy()
|
| 123 |
vmin, vmax = float(img_np.min()), float(img_np.max())
|
| 124 |
+
|
| 125 |
if vmax - vmin < 1e-6:
|
| 126 |
+
print(f" β WARNING: Preprocessing produced near-constant image (min={vmin:.6f}, max={vmax:.6f}). Trying alternative preprocessing...")
|
| 127 |
+
try:
|
| 128 |
+
if nonzero_ratio > 0.01:
|
| 129 |
+
nonzero_mean = raw_data[nonzero_mask].mean()
|
| 130 |
+
nonzero_std = raw_data[nonzero_mask].std() + 1e-8
|
| 131 |
+
raw_data_norm = np.zeros_like(raw_data)
|
| 132 |
+
raw_data_norm[nonzero_mask] = (raw_data[nonzero_mask] - nonzero_mean) / nonzero_std
|
| 133 |
+
raw_data_norm = (raw_data_norm - raw_data_norm.min()) / (raw_data_norm.max() - raw_data_norm.min() + 1e-8)
|
| 134 |
+
else:
|
| 135 |
+
raw_data_norm = (raw_data - raw_data.min()) / (raw_data.max() - raw_data.min() + 1e-8)
|
| 136 |
+
|
| 137 |
+
if raw_data_norm.std() < 1e-6:
|
| 138 |
+
raise ValueError("Alternative normalization also produced constant data")
|
| 139 |
+
|
| 140 |
+
image_data = torch.from_numpy(raw_data_norm).float()
|
| 141 |
+
image_data = image_data.unsqueeze(0)
|
| 142 |
+
img_np = image_data.numpy()
|
| 143 |
+
vmin, vmax = float(img_np.min()), float(img_np.max())
|
| 144 |
+
print(f" β Alternative preprocessing successful: min={vmin:.4f}, max={vmax:.4f}, mean={img_np.mean():.4f}, std={img_np.std():.4f}")
|
| 145 |
+
except Exception as e3:
|
| 146 |
+
raise ValueError(f"Preprocessing produced near-constant image: min={vmin:.6f}, max={vmax:.6f}. Alternative preprocessing also failed: {e3}")
|
| 147 |
+
|
| 148 |
print(f" β After transforms: min={vmin:.4f}, max={vmax:.4f}, mean={img_np.mean():.4f}, std={img_np.std():.4f}")
|
| 149 |
|
| 150 |
if device is not None and device.type == 'cuda':
|