Add custom_objects.py for tensorflowtools compatibility
Browse files- custom_objects.py +57 -0
custom_objects.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras.saving import register_keras_serializable
|
| 3 |
+
from tensorflow.keras import layers, models, backend as K
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
@register_keras_serializable()
|
| 7 |
+
def cold_temp_penalty(inputs):
|
| 8 |
+
temp = inputs[:, 0]
|
| 9 |
+
penalty = tf.where(
|
| 10 |
+
temp > 295.0,
|
| 11 |
+
1.0,
|
| 12 |
+
tf.where(
|
| 13 |
+
temp < 290.0,
|
| 14 |
+
0.0,
|
| 15 |
+
(temp - 290.0) / 5.0
|
| 16 |
+
)
|
| 17 |
+
)
|
| 18 |
+
return penalty[:, None]
|
| 19 |
+
|
| 20 |
+
@register_keras_serializable()
|
| 21 |
+
def fire_risk_booster(inputs):
|
| 22 |
+
temp = inputs[:, 0]
|
| 23 |
+
humidity = inputs[:, 1]
|
| 24 |
+
wind = inputs[:, 2]
|
| 25 |
+
veg = inputs[:, 3]
|
| 26 |
+
|
| 27 |
+
# Boost ranges
|
| 28 |
+
temp_boost = tf.sigmoid((temp - 305.0) * 1.2)
|
| 29 |
+
humidity_boost = tf.sigmoid((20.0 - humidity) * 0.5)
|
| 30 |
+
wind_boost = tf.sigmoid((wind - 15.0) * 0.8)
|
| 31 |
+
veg_boost = tf.sigmoid((veg - 70.0) * 0.5)
|
| 32 |
+
|
| 33 |
+
# Combine and scale
|
| 34 |
+
combined = temp_boost * humidity_boost * wind_boost * veg_boost
|
| 35 |
+
boost = 1.0 + 0.3 * combined # Up to 30% increase in fire score
|
| 36 |
+
return boost[:, None]
|
| 37 |
+
|
| 38 |
+
@register_keras_serializable()
|
| 39 |
+
def fire_suppression_mask(inputs):
|
| 40 |
+
temp = inputs[:, 0]
|
| 41 |
+
humidity = inputs[:, 1]
|
| 42 |
+
wind = inputs[:, 2]
|
| 43 |
+
|
| 44 |
+
# Suppress if warm but humid and still
|
| 45 |
+
temp_flag = tf.sigmoid((temp - 293.0) * 1.2)
|
| 46 |
+
humid_flag = tf.sigmoid((humidity - 50.0) * 0.4)
|
| 47 |
+
wind_flag = 1 - tf.sigmoid((wind - 5.0) * 0.8)
|
| 48 |
+
|
| 49 |
+
suppression = temp_flag * humid_flag * wind_flag
|
| 50 |
+
penalty = 1.0 - 0.3 * suppression # Max 30% suppression
|
| 51 |
+
return penalty[:, None]
|
| 52 |
+
|
| 53 |
+
CUSTOM_OBJECTS = {
|
| 54 |
+
"cold_temp_penalty": cold_temp_penalty,
|
| 55 |
+
"fire_risk_booster": fire_risk_booster,
|
| 56 |
+
"fire_suppression_mask": fire_suppression_mask
|
| 57 |
+
}
|