Harshith Reddy commited on
Commit
af587ff
Β·
1 Parent(s): 7fa06ba

Fixed preprocessing

Browse files
Files changed (1) hide show
  1. processing.py +49 -4
processing.py CHANGED
@@ -58,6 +58,16 @@ def preprocess_nifti(file_path, device=None):
58
  raw_data = nifti_img.get_fdata(dtype=np.float32)
59
  print(f" β†’ Raw data stats: min={raw_data.min():.4f}, max={raw_data.max():.4f}, mean={raw_data.mean():.4f}, std={raw_data.std():.4f}")
60
 
 
 
 
 
 
 
 
 
 
 
61
  use_enhanced_preprocessing = os.environ.get("USE_ENHANCED_PREPROCESSING", "false").lower() == "true"
62
 
63
  if use_enhanced_preprocessing:
@@ -80,15 +90,28 @@ def preprocess_nifti(file_path, device=None):
80
  transform = transforms.Compose([
81
  transforms.LoadImaged(keys=["image"]),
82
  transforms.EnsureChannelFirstD(keys=["image"], channel_dim="no_channel"),
83
- transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
84
  transforms.ToTensord(keys=["image"])
85
  ])
86
  print(" β†’ Using training-matched preprocessing (for optimal accuracy)")
87
 
88
  data = {"image": file_path}
89
  print("Applying transforms...")
90
- augmented = transform(data)
91
- image_data = augmented["image"]
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  if not isinstance(image_data, torch.Tensor):
94
  image_data = torch.from_numpy(np.array(image_data))
@@ -98,8 +121,30 @@ def preprocess_nifti(file_path, device=None):
98
 
99
  img_np = image_data.numpy() if not hasattr(image_data, 'device') or image_data.device.type == 'cpu' else image_data.cpu().numpy()
100
  vmin, vmax = float(img_np.min()), float(img_np.max())
 
101
  if vmax - vmin < 1e-6:
102
- raise ValueError(f"Preprocessing produced near-constant image: min={vmin:.6f}, max={vmax:.6f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  print(f" β†’ After transforms: min={vmin:.4f}, max={vmax:.4f}, mean={img_np.mean():.4f}, std={img_np.std():.4f}")
104
 
105
  if device is not None and device.type == 'cuda':
 
58
  raw_data = nifti_img.get_fdata(dtype=np.float32)
59
  print(f" β†’ Raw data stats: min={raw_data.min():.4f}, max={raw_data.max():.4f}, mean={raw_data.mean():.4f}, std={raw_data.std():.4f}")
60
 
61
+ if raw_data.max() - raw_data.min() < 1e-6:
62
+ raise ValueError(f"Input NIfTI file contains constant values (min=max={raw_data.min():.4f}). Cannot process.")
63
+
64
+ nonzero_mask = raw_data > 1e-6
65
+ nonzero_count = nonzero_mask.sum()
66
+ total_count = raw_data.size
67
+ nonzero_ratio = nonzero_count / total_count if total_count > 0 else 0.0
68
+
69
+ print(f" β†’ Non-zero voxels: {nonzero_count:,} / {total_count:,} ({100*nonzero_ratio:.2f}%)")
70
+
71
  use_enhanced_preprocessing = os.environ.get("USE_ENHANCED_PREPROCESSING", "false").lower() == "true"
72
 
73
  if use_enhanced_preprocessing:
 
90
  transform = transforms.Compose([
91
  transforms.LoadImaged(keys=["image"]),
92
  transforms.EnsureChannelFirstD(keys=["image"], channel_dim="no_channel"),
93
+ transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True, eps=1e-8),
94
  transforms.ToTensord(keys=["image"])
95
  ])
96
  print(" β†’ Using training-matched preprocessing (for optimal accuracy)")
97
 
98
  data = {"image": file_path}
99
  print("Applying transforms...")
100
+
101
+ try:
102
+ augmented = transform(data)
103
+ image_data = augmented["image"]
104
+ except Exception as e:
105
+ print(f" ⚠ Transform failed: {e}. Trying fallback preprocessing...")
106
+ try:
107
+ raw_data_norm = (raw_data - raw_data.min()) / (raw_data.max() - raw_data.min() + 1e-8)
108
+ if raw_data_norm.std() < 1e-6:
109
+ raise ValueError("Normalized data is still constant")
110
+ image_data = torch.from_numpy(raw_data_norm).float()
111
+ image_data = image_data.unsqueeze(0)
112
+ print(" β†’ Used fallback normalization (min-max scaling)")
113
+ except Exception as e2:
114
+ raise ValueError(f"Both standard and fallback preprocessing failed: {e2}")
115
 
116
  if not isinstance(image_data, torch.Tensor):
117
  image_data = torch.from_numpy(np.array(image_data))
 
121
 
122
  img_np = image_data.numpy() if not hasattr(image_data, 'device') or image_data.device.type == 'cpu' else image_data.cpu().numpy()
123
  vmin, vmax = float(img_np.min()), float(img_np.max())
124
+
125
  if vmax - vmin < 1e-6:
126
+ print(f" ⚠ WARNING: Preprocessing produced near-constant image (min={vmin:.6f}, max={vmax:.6f}). Trying alternative preprocessing...")
127
+ try:
128
+ if nonzero_ratio > 0.01:
129
+ nonzero_mean = raw_data[nonzero_mask].mean()
130
+ nonzero_std = raw_data[nonzero_mask].std() + 1e-8
131
+ raw_data_norm = np.zeros_like(raw_data)
132
+ raw_data_norm[nonzero_mask] = (raw_data[nonzero_mask] - nonzero_mean) / nonzero_std
133
+ raw_data_norm = (raw_data_norm - raw_data_norm.min()) / (raw_data_norm.max() - raw_data_norm.min() + 1e-8)
134
+ else:
135
+ raw_data_norm = (raw_data - raw_data.min()) / (raw_data.max() - raw_data.min() + 1e-8)
136
+
137
+ if raw_data_norm.std() < 1e-6:
138
+ raise ValueError("Alternative normalization also produced constant data")
139
+
140
+ image_data = torch.from_numpy(raw_data_norm).float()
141
+ image_data = image_data.unsqueeze(0)
142
+ img_np = image_data.numpy()
143
+ vmin, vmax = float(img_np.min()), float(img_np.max())
144
+ print(f" β†’ Alternative preprocessing successful: min={vmin:.4f}, max={vmax:.4f}, mean={img_np.mean():.4f}, std={img_np.std():.4f}")
145
+ except Exception as e3:
146
+ raise ValueError(f"Preprocessing produced near-constant image: min={vmin:.6f}, max={vmax:.6f}. Alternative preprocessing also failed: {e3}")
147
+
148
  print(f" β†’ After transforms: min={vmin:.4f}, max={vmax:.4f}, mean={img_np.mean():.4f}, std={img_np.std():.4f}")
149
 
150
  if device is not None and device.type == 'cuda':