Charlie Li
commited on
Commit
·
17f8269
1
Parent(s):
396546f
update page
Browse files
app.py
CHANGED
|
@@ -39,10 +39,7 @@ captions = [
|
|
| 39 |
"you",
|
| 40 |
"letter",
|
| 41 |
]
|
| 42 |
-
gif_base64_strings = {
|
| 43 |
-
caption: get_base64_encoded_gif(f"gifs/{name}")
|
| 44 |
-
for caption, name in zip(captions, gif_filenames)
|
| 45 |
-
}
|
| 46 |
|
| 47 |
sketches = [
|
| 48 |
"bird.gif",
|
|
@@ -50,21 +47,23 @@ sketches = [
|
|
| 50 |
"coffee.gif",
|
| 51 |
"penguin.gif",
|
| 52 |
]
|
| 53 |
-
sketches_base64_strings = {
|
| 54 |
-
name: get_base64_encoded_gif(f"sketches/{name}") for name in sketches
|
| 55 |
-
}
|
| 56 |
|
| 57 |
if not pre_generate:
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
| 68 |
else:
|
| 69 |
pregenerate_videos(video_cache_dir=video_cache_dir)
|
| 70 |
print("Videos cached.")
|
|
@@ -143,14 +142,21 @@ def demo(Dataset, Model):
|
|
| 143 |
|
| 144 |
with gr.Blocks() as app:
|
| 145 |
gr.HTML(org_content)
|
| 146 |
-
gr.Markdown(
|
| 147 |
-
"# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write"
|
| 148 |
-
)
|
| 149 |
gr.HTML(
|
| 150 |
"""
|
| 151 |
-
<div style="display: flex;
|
| 152 |
-
<a href="https://arxiv.org/
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
</a>
|
| 155 |
</div>
|
| 156 |
"""
|
|
@@ -163,9 +169,7 @@ with gr.Blocks() as app:
|
|
| 163 |
"""
|
| 164 |
)
|
| 165 |
with gr.Row():
|
| 166 |
-
dataset = gr.Dropdown(
|
| 167 |
-
["IAM", "IMGUR5K", "HierText"], label="Dataset", value="IAM"
|
| 168 |
-
)
|
| 169 |
model = gr.Dropdown(
|
| 170 |
["Small-i", "Large-i", "Small-p"],
|
| 171 |
label="InkSight Model Variant",
|
|
@@ -179,18 +183,12 @@ with gr.Blocks() as app:
|
|
| 179 |
# vanilla_img = gr.Image(label="Vanilla")
|
| 180 |
|
| 181 |
with gr.Row():
|
| 182 |
-
d_t_text = gr.Textbox(
|
| 183 |
-
label="OCR recognition input to the model", interactive=False
|
| 184 |
-
)
|
| 185 |
r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
|
| 186 |
vanilla_text = gr.Textbox(label="Vanilla", interactive=False)
|
| 187 |
with gr.Row():
|
| 188 |
-
d_t_vid = gr.Video(
|
| 189 |
-
|
| 190 |
-
)
|
| 191 |
-
r_d_vid = gr.Video(
|
| 192 |
-
label="Recognize and Derender (Click to stop/play)", autoplay=True
|
| 193 |
-
)
|
| 194 |
vanilla_vid = gr.Video(label="Vanilla (Click to stop/play)", autoplay=True)
|
| 195 |
|
| 196 |
with gr.Row():
|
|
|
|
| 39 |
"you",
|
| 40 |
"letter",
|
| 41 |
]
|
| 42 |
+
gif_base64_strings = {caption: get_base64_encoded_gif(f"gifs/{name}") for caption, name in zip(captions, gif_filenames)}
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
sketches = [
|
| 45 |
"bird.gif",
|
|
|
|
| 47 |
"coffee.gif",
|
| 48 |
"penguin.gif",
|
| 49 |
]
|
| 50 |
+
sketches_base64_strings = {name: get_base64_encoded_gif(f"sketches/{name}") for name in sketches}
|
|
|
|
|
|
|
| 51 |
|
| 52 |
if not pre_generate:
|
| 53 |
+
# Check if the file already exists
|
| 54 |
+
if not (video_cache_dir / "gdrive_file.zip").exists():
|
| 55 |
+
print("Downloading pre-generated videos from Google Drive.")
|
| 56 |
+
# Download from Google Drive using gdown
|
| 57 |
+
gdown.download(
|
| 58 |
+
"https://drive.google.com/uc?id=1oT6zw1EbWg3lavBMXsL28piULGNmqJzA",
|
| 59 |
+
str(video_cache_dir / "gdrive_file.zip"),
|
| 60 |
+
quiet=False,
|
| 61 |
+
)
|
| 62 |
|
| 63 |
+
# Unzip the file to video_cache_dir
|
| 64 |
+
unzip_file(str(video_cache_dir / "gdrive_file.zip"))
|
| 65 |
+
else:
|
| 66 |
+
print("File already exists. Skipping download.")
|
| 67 |
else:
|
| 68 |
pregenerate_videos(video_cache_dir=video_cache_dir)
|
| 69 |
print("Videos cached.")
|
|
|
|
| 142 |
|
| 143 |
with gr.Blocks() as app:
|
| 144 |
gr.HTML(org_content)
|
| 145 |
+
gr.Markdown("# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write")
|
|
|
|
|
|
|
| 146 |
gr.HTML(
|
| 147 |
"""
|
| 148 |
+
<div style="display: flex; gap: 10px; justify-content: left;">
|
| 149 |
+
<a href="https://arxiv.org/abs/2402.05804">
|
| 150 |
+
<img src="https://img.shields.io/badge/📄_Read_the_Paper-4CAF50?style=for-the-badge&logo=arxiv&logoColor=white" alt="Read the Paper">
|
| 151 |
+
</a>
|
| 152 |
+
<a href="https://github.com/google-research/inksight">
|
| 153 |
+
<img src="https://img.shields.io/badge/View_on_GitHub-181717?style=for-the-badge&logo=github&logoColor=white" alt="View on GitHub">
|
| 154 |
+
</a>
|
| 155 |
+
<a href="https://research.google/blog/a-return-to-hand-written-notes-by-learning-to-read-write/">
|
| 156 |
+
<img src="https://img.shields.io/badge/🌐_Google_Research_Blog-333333?style=for-the-badge&logo=google&logoColor=white" alt="Google Research Blog">
|
| 157 |
+
</a>
|
| 158 |
+
<a href="https://charlieleee.github.io/publication/inksight/">
|
| 159 |
+
<img src="https://img.shields.io/badge/ℹ️_Info-FFA500?style=for-the-badge&logo=info&logoColor=white" alt="Info">
|
| 160 |
</a>
|
| 161 |
</div>
|
| 162 |
"""
|
|
|
|
| 169 |
"""
|
| 170 |
)
|
| 171 |
with gr.Row():
|
| 172 |
+
dataset = gr.Dropdown(["IAM", "IMGUR5K", "HierText"], label="Dataset", value="IAM")
|
|
|
|
|
|
|
| 173 |
model = gr.Dropdown(
|
| 174 |
["Small-i", "Large-i", "Small-p"],
|
| 175 |
label="InkSight Model Variant",
|
|
|
|
| 183 |
# vanilla_img = gr.Image(label="Vanilla")
|
| 184 |
|
| 185 |
with gr.Row():
|
| 186 |
+
d_t_text = gr.Textbox(label="OCR recognition input to the model", interactive=False)
|
|
|
|
|
|
|
| 187 |
r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
|
| 188 |
vanilla_text = gr.Textbox(label="Vanilla", interactive=False)
|
| 189 |
with gr.Row():
|
| 190 |
+
d_t_vid = gr.Video(label="Derender with Text (Click to stop/play)", autoplay=True)
|
| 191 |
+
r_d_vid = gr.Video(label="Recognize and Derender (Click to stop/play)", autoplay=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
vanilla_vid = gr.Video(label="Vanilla (Click to stop/play)", autoplay=True)
|
| 193 |
|
| 194 |
with gr.Row():
|
utils.py
CHANGED
|
@@ -32,6 +32,8 @@ def get_svg_content(svg_path):
|
|
| 32 |
|
| 33 |
|
| 34 |
def download_file(url, filename):
|
|
|
|
|
|
|
| 35 |
response = requests.get(url)
|
| 36 |
with open(filename, "wb") as f:
|
| 37 |
f.write(response.content)
|
|
@@ -84,22 +86,15 @@ def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="whit
|
|
| 84 |
base_color = base_colors(len(ink.strokes) - 1 - i)
|
| 85 |
hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
|
| 86 |
|
| 87 |
-
darker_color = colorsys.hsv_to_rgb(
|
| 88 |
-
|
| 89 |
-
)
|
| 90 |
-
colors = [
|
| 91 |
-
mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x)))
|
| 92 |
-
for j in range(len(x))
|
| 93 |
-
]
|
| 94 |
|
| 95 |
points = np.array([x, y]).T.reshape(-1, 1, 2)
|
| 96 |
segments = np.concatenate([points[:-1], points[1:]], axis=1)
|
| 97 |
|
| 98 |
lc = LineCollection(segments, colors=colors, linewidth=lw)
|
| 99 |
if with_path:
|
| 100 |
-
lc.set_path_effects(
|
| 101 |
-
[withStroke(linewidth=lw * 1.25, foreground=path_color)]
|
| 102 |
-
)
|
| 103 |
ax.add_collection(lc)
|
| 104 |
|
| 105 |
ax.set_xlim(0, 224)
|
|
@@ -107,9 +102,7 @@ def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="whit
|
|
| 107 |
ax.invert_yaxis()
|
| 108 |
|
| 109 |
|
| 110 |
-
def plot_ink_to_video(
|
| 111 |
-
ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30
|
| 112 |
-
):
|
| 113 |
fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
|
| 114 |
|
| 115 |
if input_image is not None:
|
|
@@ -143,26 +136,16 @@ def plot_ink_to_video(
|
|
| 143 |
|
| 144 |
base_color = base_colors(len(ink.strokes) - 1 - stroke_index)
|
| 145 |
hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
|
| 146 |
-
darker_color = colorsys.hsv_to_rgb(
|
| 147 |
-
|
| 148 |
-
)
|
| 149 |
-
visible_segments = (
|
| 150 |
-
segments[: frame - points_drawn]
|
| 151 |
-
if frame - points_drawn < len(segments)
|
| 152 |
-
else segments
|
| 153 |
-
)
|
| 154 |
colors = [
|
| 155 |
-
mcolors.to_rgba(
|
| 156 |
-
darker_color, alpha=1 - (0.5 * j / len(visible_segments))
|
| 157 |
-
)
|
| 158 |
for j in range(len(visible_segments))
|
| 159 |
]
|
| 160 |
|
| 161 |
if len(visible_segments) > 0:
|
| 162 |
lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
|
| 163 |
-
lc.set_path_effects(
|
| 164 |
-
[withStroke(linewidth=lw * 1.25, foreground=path_color)]
|
| 165 |
-
)
|
| 166 |
ax.add_collection(lc)
|
| 167 |
|
| 168 |
points_drawn += len(segments)
|
|
@@ -254,13 +237,9 @@ def pregenerate_videos(video_cache_dir):
|
|
| 254 |
if not os.path.exists(path):
|
| 255 |
continue
|
| 256 |
samples = os.listdir(path)
|
| 257 |
-
for name in tqdm(
|
| 258 |
-
samples, desc=f"Generating {Model}-{Dataset}-{mode} videos"
|
| 259 |
-
):
|
| 260 |
example_id = name.strip(".png")
|
| 261 |
-
inkml_file = os.path.join(
|
| 262 |
-
inkml_path_base, mode, f"{example_id}.inkml"
|
| 263 |
-
)
|
| 264 |
if not os.path.exists(inkml_file):
|
| 265 |
continue
|
| 266 |
video_filename = f"{Model}_{Dataset}_{mode}_{example_id}.mp4"
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def download_file(url, filename):
|
| 35 |
+
if os.path.exists(filename):
|
| 36 |
+
return
|
| 37 |
response = requests.get(url)
|
| 38 |
with open(filename, "wb") as f:
|
| 39 |
f.write(response.content)
|
|
|
|
| 86 |
base_color = base_colors(len(ink.strokes) - 1 - i)
|
| 87 |
hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
|
| 88 |
|
| 89 |
+
darker_color = colorsys.hsv_to_rgb(hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65))
|
| 90 |
+
colors = [mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x))) for j in range(len(x))]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
points = np.array([x, y]).T.reshape(-1, 1, 2)
|
| 93 |
segments = np.concatenate([points[:-1], points[1:]], axis=1)
|
| 94 |
|
| 95 |
lc = LineCollection(segments, colors=colors, linewidth=lw)
|
| 96 |
if with_path:
|
| 97 |
+
lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
|
|
|
|
|
|
|
| 98 |
ax.add_collection(lc)
|
| 99 |
|
| 100 |
ax.set_xlim(0, 224)
|
|
|
|
| 102 |
ax.invert_yaxis()
|
| 103 |
|
| 104 |
|
| 105 |
+
def plot_ink_to_video(ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30):
|
|
|
|
|
|
|
| 106 |
fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
|
| 107 |
|
| 108 |
if input_image is not None:
|
|
|
|
| 136 |
|
| 137 |
base_color = base_colors(len(ink.strokes) - 1 - stroke_index)
|
| 138 |
hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
|
| 139 |
+
darker_color = colorsys.hsv_to_rgb(hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65))
|
| 140 |
+
visible_segments = segments[: frame - points_drawn] if frame - points_drawn < len(segments) else segments
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
colors = [
|
| 142 |
+
mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(visible_segments)))
|
|
|
|
|
|
|
| 143 |
for j in range(len(visible_segments))
|
| 144 |
]
|
| 145 |
|
| 146 |
if len(visible_segments) > 0:
|
| 147 |
lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
|
| 148 |
+
lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
|
|
|
|
|
|
|
| 149 |
ax.add_collection(lc)
|
| 150 |
|
| 151 |
points_drawn += len(segments)
|
|
|
|
| 237 |
if not os.path.exists(path):
|
| 238 |
continue
|
| 239 |
samples = os.listdir(path)
|
| 240 |
+
for name in tqdm(samples, desc=f"Generating {Model}-{Dataset}-{mode} videos"):
|
|
|
|
|
|
|
| 241 |
example_id = name.strip(".png")
|
| 242 |
+
inkml_file = os.path.join(inkml_path_base, mode, f"{example_id}.inkml")
|
|
|
|
|
|
|
| 243 |
if not os.path.exists(inkml_file):
|
| 244 |
continue
|
| 245 |
video_filename = f"{Model}_{Dataset}_{mode}_{example_id}.mp4"
|