Spaces:
Sleeping
Sleeping
| import cv2 | |
| import numpy as np | |
| import streamlit as st | |
| import tensorflow as tf | |
| from utils import _get_retina_bb, _pad_to_square | |
def load_model(model_file):
    """Load a saved Keras model for inference and announce it on stdout.

    Args:
        model_file: Location of the saved model, handed straight to
            ``tf.keras.models.load_model``.

    Returns:
        The loaded model. ``compile=False`` is passed, so the model is
        not compiled after loading.
    """
    loaded = tf.keras.models.load_model(model_file, compile=False)
    print(f'Model {model_file} Loaded!')
    return loaded
def load_gatekeeper():
    """Load the gatekeeper model from its fixed checkpoint path.

    Returns:
        The Keras model stored at ``checkpoints/ResNetV2-EyeQ-QA.tf``.
    """
    checkpoint = 'checkpoints/ResNetV2-EyeQ-QA.tf'
    gatekeeper_net = tf.keras.models.load_model(checkpoint)
    print('Gatekeeper Model Loaded!')
    return gatekeeper_net
def parse_function(image):
    """Resize ``image`` to 512x512 and convert it to ``tf.float32``.

    Args:
        image: Image tensor/array accepted by ``tf.image.resize``.

    Returns:
        The resized image as a ``tf.float32`` tensor.
    """
    resized = tf.image.resize(image, [512, 512])
    # NOTE(review): tf.image.resize with its default method is documented to
    # return float32 already, so this conversion presumably does not rescale
    # the values - confirm this matches how the consuming model was trained.
    return tf.image.convert_image_dtype(resized, tf.float32)
def _read_uploaded_image(uploaded_file):
    """Decode an uploaded file into an RGB array; return None if undecodable."""
    file_bytes = np.frombuffer(uploaded_file.read(), dtype=np.uint8)
    if file_bytes.size == 0:
        # cv2.imdecode raises on an empty buffer instead of returning None.
        return None
    decoded = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if decoded is None:
        return None
    return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)


def _is_poor_quality(gatekeeper_model, image):
    """Return True when the gatekeeper's highest-scoring class is index 2."""
    scores = gatekeeper_model(parse_function(image[None, ...]))
    return np.argmax(scores) == 2


def _centre_retina(image):
    """Crop to the retina bounding box, pad to a square, resize to 1024x1024."""
    x, y, w, h, _ = _get_retina_bb(image)
    cropped = image[y:y + h, x:x + w, :]
    squared = _pad_to_square(cropped, border=0)
    return cv2.resize(squared, (1024, 1024))


def _apply_clahe(image):
    """Apply CLAHE to the first LAB channel of an RGB image; return RGB."""
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(16, 16))
    lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    lab[:, :, 0] = clahe.apply(lab[:, :, 0])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)


def main():
    """Run the Streamlit retina-segmentation page.

    The sidebar selects the segmentation model ('Vessels' or
    'Lesions (BETA)') and whether the gatekeeper check is enabled. Once an
    image is uploaded it is decoded, optionally screened by the gatekeeper,
    cropped/padded/resized to 1024x1024, CLAHE-enhanced, and passed to the
    selected model; the input and the predictions are shown in columns.

    Returns:
        None. The function returns early when no file is uploaded, when the
        upload cannot be decoded, or when the gatekeeper rejects the image.
    """
    st.title('Retina Segmentation')
    st.sidebar.title('Segmentation Model')
    options = st.sidebar.selectbox('Select Option:', ('Vessels', 'Lesions (BETA)'))
    gatekeeper = st.sidebar.radio("Gatekeeper:", ('Enabled', 'Disabled'))

    # Only pay for loading/running the gatekeeper when it is actually used.
    gatekeeper_enabled = gatekeeper == 'Enabled'
    gatekeeper_model = load_gatekeeper() if gatekeeper_enabled else None

    if options == 'Vessels':
        segment_lesions = False
        checkpoint = 'checkpoints/DeeplabV3Plus_DRIVE.tf'
    elif options == 'Lesions (BETA)':
        segment_lesions = True
        st.write('```--- WARNING: This model is highly experimental ---```')
        checkpoint = 'checkpoints/DeeplabV3Plus_FGADR.tf'
    else:
        return

    st.set_option('deprecation.showfileUploaderEncoding', False)
    uploaded_file = st.file_uploader('Choose an image...', type=('png', 'jpg', 'jpeg'))
    model = load_model(checkpoint)
    if not uploaded_file:
        return

    columns = st.columns(3 if segment_lesions else 2)

    # Load image
    image = _read_uploaded_image(uploaded_file)
    if image is None:
        st.error('Could not decode the uploaded file as an image.')
        return

    # Check image
    if gatekeeper_enabled and _is_poor_quality(gatekeeper_model, image):
        st.image(image)
        st.info('Image is of poor quality')
        return

    # Localise and center retina image
    image = _centre_retina(image)
    with columns[0]:
        st.subheader("Uploaded Image")
        st.image(image)

    # Apply CLAHE pre-processing
    image = _apply_clahe(image)

    # Run model on input
    if segment_lesions:
        # The RGB image already has a channel axis, so only a batch axis is
        # added. Presumably the lesion model expects (batch, H, W, 3) - confirm.
        net_input = tf.image.convert_image_dtype(image, tf.float32)[None, ...]
    else:
        # The grayscale image is 2-D: add both a batch and a channel axis.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        net_input = tf.image.convert_image_dtype(gray, tf.float32)[None, ..., None]
    y_pred = model(net_input)[0].numpy()

    if not segment_lesions:
        with columns[1]:
            st.subheader("Predicted Vessel")
            st.image(y_pred)
        return

    # Output channels 1..5 are laid out across the three columns in this
    # order so the second row lines up under the first.
    col1, col2, col3 = columns
    panels = (
        (col2, 'MA', 1),
        (col3, 'HE', 2),
        (col1, 'EX', 3),
        (col2, 'SE', 4),
        (col3, 'OD', 5),
    )
    for column, title, channel in panels:
        with column:
            st.subheader(title)
            st.image(y_pred[..., channel])
if __name__ == "__main__":
    # Expose no GPU devices to TensorFlow before any model is loaded.
    tf.config.set_visible_devices([], device_type='GPU')
    main()