| | import torch |
| | import xml.etree.ElementTree as ET |
| | from PIL import Image |
| | import numpy as np |
| | from pathlib import Path |
| | import json |
| |
|
| | from train import SimpleTransformer, flat_corners_from_mockup |
| |
|
| | |
| | |
| | |
def order_points_clockwise(pts):
    """Order four quadrilateral corners as top-left, top-right, bottom-right, bottom-left.

    The corners are sorted by their angle around the centroid. This gives a
    consistent cyclic traversal for any convex quadrilateral, including
    strongly rotated or skewed ones. Image coordinates have y pointing down,
    so ascending angle is clockwise on screen. The cycle is then rotated to
    start at the corner with the smallest ``x + y``, i.e. the one nearest the
    image origin, which is labelled top-left.

    Why not split by y: putting the two smallest-y points in a "top" pair
    breaks when, for example, the top-right corner lies below the bottom-left
    one. The result is then a self-intersecting (bow-tie) polygon.

    Args:
        pts: Four ``(x, y)`` points in any order; anything convertible to a
            ``(4, 2)`` float array.

    Returns:
        ``np.ndarray`` of shape ``(4, 2)``, dtype ``float32``, holding the
        corners in TL, TR, BR, BL order.

    Raises:
        ValueError: if ``pts`` does not form a ``(4, 2)`` array.
    """
    pts = np.asarray(pts, dtype="float32")
    if pts.shape != (4, 2):
        raise ValueError(f"expected 4 (x, y) points, got array of shape {pts.shape}")

    center = pts.mean(axis=0)
    # Image y grows downward, so increasing atan2 angle walks clockwise on screen.
    angles = np.arctan2(pts[:, 1] - center[1], pts[:, 0] - center[0])
    cyclic = pts[np.argsort(angles)]

    # Start the cycle at the corner closest to the origin (top-left).
    start = int(np.argmin(cyclic.sum(axis=1)))
    return np.roll(cyclic, -start, axis=0)
| |
|
| | |
| | |
| | |
def save_prediction_xml(pred_pts, out_path, img_w, img_h):
    """Write four predicted corners to a ``<visualization>`` XML document.

    The corners are first ordered with ``order_points_clockwise``. They are
    written three times:
    - as a ``FourPoint`` transform with TopLeft/TopRight/BottomRight/BottomLeft points,
    - as an overlay polyline of ``Next`` points,
    - as a ruler running from top-left to bottom-right.

    All other attributes (effects, background colour, tech/pos codes, sizes)
    are fixed template values.

    Args:
        pred_pts: Four ``(x, y)`` corner coordinates in any order.
        out_path: Destination file path, or a writable file object (passed
            through to ``ElementTree.write``). For a path, missing parent
            directories are created.
        img_w: Image width, written to the ``background`` element.
        img_h: Image height, written to the ``background`` element.
    """
    ordered = order_points_clockwise(pred_pts)
    tl, tr, br, bl = ordered

    def _fmt(value):
        # Serialize every coordinate the same way; numpy float32 scalars
        # stringify differently from Python floats otherwise.
        return str(float(value))

    root = ET.Element("visualization", version="1.0")
    ET.SubElement(root, "effects", surfacecolor="", iswood="0")
    ET.SubElement(root, "background",
                  width=str(img_w), height=str(img_h),
                  color1="#C4CDE4", color2="", color3="")

    transforms_node = ET.SubElement(root, "transforms")
    transform = ET.SubElement(transforms_node, "transform",
                              type="FourPoint", offsetX="0", offsetY="0", offsetZ="0.0",
                              rotationX="0.0", rotationY="0.0", rotationZ="0.0",
                              name="Region", posCode="REGION", posName="Region",
                              posDef="0", techCode="EMBF03", techName="Embroidery Fixed",
                              techDef="0", areaWidth="100", areaHeight="100",
                              maxColors="12", defaultLogoSize="100", sizeX="100", sizeY="100")

    corners = {"TopLeft": tl, "TopRight": tr, "BottomRight": br, "BottomLeft": bl}
    for ptype, (x, y) in corners.items():
        ET.SubElement(transform, "point",
                      type=ptype, x=_fmt(x), y=_fmt(y),
                      z="0.0", warp="0", warpShift="0")

    overlays = ET.SubElement(root, "overlays")
    overlay = ET.SubElement(overlays, "overlay")
    for x, y in ordered:
        ET.SubElement(overlay, "point", type="Next", x=_fmt(x), y=_fmt(y), z="0.0")

    ET.SubElement(root, "ruler",
                  startX=_fmt(tl[0]), startY=_fmt(tl[1]),
                  stopX=_fmt(br[0]), stopY=_fmt(br[1]), value="100")

    # Same check ElementTree uses: only paths (not file objects) need a directory.
    if not hasattr(out_path, "write"):
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    ET.ElementTree(root).write(out_path, encoding="utf-8", xml_declaration=True)
| |
|
| |
|
| | |
| | |
| | |
def predict_one(mockup_json, pers_img_path, model_ckpt, out_path="prediction.xml"):
    """Run the trained model on one mockup and save the predicted corners as XML.

    Args:
        mockup_json: Passed to ``flat_corners_from_mockup`` (from ``train``);
            its second result is flattened and fed to the model.
        pers_img_path: Image file whose pixel width/height are used to scale
            the model output. Only the size is read, not the pixels.
        model_ckpt: Checkpoint file. It is either a bare state dict or a dict
            holding the weights under ``"model_state"``.
        out_path: Destination XML file, passed to ``save_prediction_xml``.

    Returns:
        ``np.ndarray`` of shape ``(4, 2)``: the model output scaled by
        ``(width, height)``, in the order the model produced it.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only the dimensions are needed. ``.size`` is available from the header
    # without decoding/converting the pixels, and ``with`` closes the file.
    with Image.open(pers_img_path) as img:
        orig_w, orig_h = img.size

    _, flat_norm = flat_corners_from_mockup(mockup_json)
    flat_in = torch.tensor(flat_norm.flatten(), dtype=torch.float32).unsqueeze(0).to(device)

    model = SimpleTransformer().to(device)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects; only load checkpoints from a trusted source.
    state = torch.load(model_ckpt, map_location=device, weights_only=False)
    model.load_state_dict(state["model_state"] if "model_state" in state else state)
    model.eval()

    with torch.no_grad():
        # reshape (not view) also copes with a non-contiguous output tensor.
        pred = model(flat_in).reshape(4, 2).cpu().numpy()

    # Presumably the model predicts coordinates normalised to [0, 1] of the
    # image size (TODO confirm against training); scale x by width, y by height.
    pred_px = pred * np.array([orig_w, orig_h], dtype=pred.dtype)

    save_prediction_xml(pred_px, out_path, orig_w, orig_h)
    print(f"Saved prediction -> {out_path}")
    return pred_px
| |
|
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # Command-line entry point. Each option defaults to the sample paths used
    # so far, so running without arguments behaves exactly as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run predict_one on a single mockup/image pair and write the XML result."
    )
    parser.add_argument("--mockup-json",
                        default="Transformer/test/100847_TD/front/LAS02/mockup.json",
                        help="Mockup JSON passed to flat_corners_from_mockup.")
    parser.add_argument("--pers-img",
                        default="Transformer/test/100847_TD/front/LAS02/4BC13E58-1D8A-4E5D-8A40-C1F4B1248893_visual.jpg",
                        help="Image whose size is used to scale the prediction.")
    parser.add_argument("--model-ckpt",
                        default="Transformer/transformer_model.pth",
                        help="Model checkpoint file.")
    parser.add_argument("--out",
                        default="Transformer/Prediction/pred3.xml",
                        help="Destination XML file.")
    args = parser.parse_args()

    predict_one(args.mockup_json, args.pers_img, args.model_ckpt, out_path=args.out)