Spaces:
Sleeping
Sleeping
lihong2303
commited on
Commit
·
ba1b871
1
Parent(s):
cbd2e9c
update
Browse files- .DS_Store +0 -0
- README copy.md +26 -0
- app.py +665 -0
- pangea.py +260 -0
- preprocess_ocl_data.py +207 -0
- requirements.txt +1 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
README copy.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Association
|
| 3 |
+
emoji: 💬
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.36.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: cc-by-4.0
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
- Association-Demo
|
| 14 |
+
|
| 15 |
+
- HMDB51
|
| 16 |
+
- OCL-3D
|
| 17 |
+
- data
|
| 18 |
+
- resources
|
| 19 |
+
- OCL_selected_test_affordance.pkl
|
| 20 |
+
- OCL_annot_test.pkl
|
| 21 |
+
|
| 22 |
+
- 26000: /home/lihong/workspace/pangea
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
app.py
ADDED
|
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import random
|
| 3 |
+
import os
|
| 4 |
+
import pickle
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Process:
|
| 9 |
+
gt_image= ""
|
| 10 |
+
gt_image_idx = (0, 0)
|
| 11 |
+
raw_image_path = ""
|
| 12 |
+
candidate_image1_path = ""
|
| 13 |
+
candidate_image2_path = ""
|
| 14 |
+
candidata_image1_idx = (0, 0)
|
| 15 |
+
candidate_image2_idx = (0, 0)
|
| 16 |
+
candidate_image1_group = "negative"
|
| 17 |
+
candidate_image2_group = "negative"
|
| 18 |
+
concept_choices = None
|
| 19 |
+
pkl_data = None
|
| 20 |
+
positive_cand = []
|
| 21 |
+
negative_cand = []
|
| 22 |
+
positive1_cand = []
|
| 23 |
+
positive2_cand = []
|
| 24 |
+
positive_common_cand = []
|
| 25 |
+
schedule = 0
|
| 26 |
+
idx_to_chain = {}
|
| 27 |
+
|
| 28 |
+
global process
|
| 29 |
+
process = Process()
|
| 30 |
+
|
| 31 |
+
def load_data_and_produce_list(dataset,exp_mode, concept_choices):
|
| 32 |
+
|
| 33 |
+
if dataset == "ocl_attribute":
|
| 34 |
+
#TODO
|
| 35 |
+
attr_name = ['wooden', 'metal', 'flying', 'ripe', 'fresh', 'natural', 'cooked', 'painted', 'rusty', 'furry']
|
| 36 |
+
attr2idx = {item:idx for idx,item in enumerate(attr_name)}
|
| 37 |
+
idx_2_attr = {value:key for key,value in attr2idx.items()}
|
| 38 |
+
pkl_path = "Data/OCL_data/OCL_selected_test_attribute_refined.pkl"
|
| 39 |
+
image_dir = "Data/OCL_data/data"
|
| 40 |
+
|
| 41 |
+
with open(pkl_path,"rb") as f:
|
| 42 |
+
data = pickle.load(f)
|
| 43 |
+
|
| 44 |
+
with open('Data/OCL_data/OCL_annot_test.pkl', "rb") as f:
|
| 45 |
+
process.pkl_data = pickle.load(f)
|
| 46 |
+
|
| 47 |
+
if exp_mode == "One concept":
|
| 48 |
+
process.positive_cand = data['selected_individual_pkl'][process.idx_to_chain[concept_choices]]
|
| 49 |
+
process.negative_cand = data['negative_pkl']
|
| 50 |
+
|
| 51 |
+
else:
|
| 52 |
+
|
| 53 |
+
selected_concept_group = process.idx_to_chain[concept_choices].split("-")
|
| 54 |
+
selected_paired_pkl = data['selected_paired_pkl'][process.idx_to_chain[concept_choices]]
|
| 55 |
+
process.positive1_cand = selected_paired_pkl[selected_concept_group[0]]
|
| 56 |
+
process.positive2_cand = selected_paired_pkl[selected_concept_group[1]]
|
| 57 |
+
process.positive_common_cand = selected_paired_pkl[process.idx_to_chain[concept_choices]]
|
| 58 |
+
process.negative_cand = data['negative_pkl']
|
| 59 |
+
|
| 60 |
+
elif dataset == "ocl_affordance":
|
| 61 |
+
aff_name = ['break', 'carry', 'clean','cut','push','sit','write']
|
| 62 |
+
aff2idx = {item:idx for idx,item in enumerate(aff_name)}
|
| 63 |
+
idx_2_attr = {value:key for key,value in aff2idx.items()}
|
| 64 |
+
pkl_path = "Data/OCL_data/OCL_selected_test_affordance_refined.pkl"
|
| 65 |
+
image_dir = "Data/OCL_data/data"
|
| 66 |
+
|
| 67 |
+
with open(pkl_path,"rb") as f:
|
| 68 |
+
data = pickle.load(f)
|
| 69 |
+
|
| 70 |
+
with open('Data/OCL_data/OCL_annot_test.pkl', "rb") as f:
|
| 71 |
+
process.pkl_data = pickle.load(f)
|
| 72 |
+
if exp_mode == "One concept":
|
| 73 |
+
process.positive_cand = data['selected_individual_pkl'][process.idx_to_chain[concept_choices]]
|
| 74 |
+
process.negative_cand = data['negative_pkl']
|
| 75 |
+
else:
|
| 76 |
+
selected_concept_group = process.idx_to_chain[concept_choices].split("-")
|
| 77 |
+
selected_paired_pkl = data['selected_paired_pkl'][process.idx_to_chain[concept_choices]]
|
| 78 |
+
process.positive1_cand = selected_paired_pkl[selected_concept_group[0]]
|
| 79 |
+
process.positive2_cand = selected_paired_pkl[selected_concept_group[1]]
|
| 80 |
+
process.positive_common_cand = selected_paired_pkl[process.idx_to_chain[concept_choices]]
|
| 81 |
+
process.negative_cand = data['negative_pkl']
|
| 82 |
+
elif dataset == "Pangea":
|
| 83 |
+
attr_name = ["hit-18.1","run-51.3.2","dress-41.1.1-1-1","drive-11.5","cooking-45.3","build-26.1","shake-22.3-2","cut-21.1-1"]
|
| 84 |
+
attr2idx = {item:idx for idx,item in enumerate(attr_name)}
|
| 85 |
+
idx_2_attr = {value:key for key,value in attr2idx.items()}
|
| 86 |
+
pkl_path = "Data/pangea/pangea_test_refined.pkl"
|
| 87 |
+
image_dir = "Data/pangea/pangea"
|
| 88 |
+
with open(pkl_path,"rb") as f:
|
| 89 |
+
data = pickle.load(f)
|
| 90 |
+
|
| 91 |
+
with open("Data/pangea/B123_test_KIN-FULL_with_node.pkl", "rb") as f:
|
| 92 |
+
process.pkl_data = pickle.load(f)
|
| 93 |
+
|
| 94 |
+
if exp_mode == "One concept":
|
| 95 |
+
process.positive_cand = data['selected_pkl'][process.idx_to_chain[concept_choices]]
|
| 96 |
+
process.negative_cand = data['negative_pkl']
|
| 97 |
+
else:
|
| 98 |
+
selected_concept_group = process.idx_to_chain[concept_choices].split("_")
|
| 99 |
+
selected_paired_pkl = data['selected_paired_pkl'][process.idx_to_chain[concept_choices]]
|
| 100 |
+
process.positive1_cand = selected_paired_pkl[selected_concept_group[0]]
|
| 101 |
+
process.positive2_cand = selected_paired_pkl[selected_concept_group[1]]
|
| 102 |
+
process.positive_common_cand = selected_paired_pkl[process.idx_to_chain[concept_choices]]
|
| 103 |
+
process.negative_cand = data['negative_pkl']
|
| 104 |
+
|
| 105 |
+
elif dataset == "hmdb":
|
| 106 |
+
attr_name = ['brush_hair','clap', 'dive', 'shake_hands','hug' ,'sit','smoke','eat']
|
| 107 |
+
attr2idx = {key:item for key,item in enumerate(attr_name)}
|
| 108 |
+
image_dir = "Data/refined_HMDB"
|
| 109 |
+
pkl_path = "Data/refined_HMDB.pkl"
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
with open(pkl_path,"rb") as f:
|
| 113 |
+
data = pickle.load(f)
|
| 114 |
+
|
| 115 |
+
if exp_mode == "One concept":
|
| 116 |
+
positive_cand = []
|
| 117 |
+
negative_cand = []
|
| 118 |
+
for each_data in data:
|
| 119 |
+
each_data['name'] = os.path.join(image_dir,each_data['name'])
|
| 120 |
+
if process.idx_to_chain[concept_choices] in each_data["label"]:
|
| 121 |
+
positive_cand.append(each_data)
|
| 122 |
+
else:
|
| 123 |
+
negative_cand.append(each_data)
|
| 124 |
+
|
| 125 |
+
if len(positive_cand) > 30 and len(negative_cand) > 100:
|
| 126 |
+
break
|
| 127 |
+
|
| 128 |
+
process.positive_cand = positive_cand
|
| 129 |
+
process.negative_cand = negative_cand
|
| 130 |
+
|
| 131 |
+
else:
|
| 132 |
+
|
| 133 |
+
negative_cand = []
|
| 134 |
+
positive1_cand = []
|
| 135 |
+
positive2_cand = []
|
| 136 |
+
positive_common_cand = []
|
| 137 |
+
|
| 138 |
+
for each_data in data:
|
| 139 |
+
each_data['name'] = os.path.join(image_dir,each_data['name'])
|
| 140 |
+
|
| 141 |
+
selected_concept_group = process.idx_to_chain[concept_choices].split("-")
|
| 142 |
+
|
| 143 |
+
if selected_concept_group[0] in each_data["name"] and selected_concept_group[1] in each_data["name"]:
|
| 144 |
+
positive_common_cand.append(each_data)
|
| 145 |
+
elif selected_concept_group[0] in each_data["name"]:
|
| 146 |
+
positive1_cand.append(each_data)
|
| 147 |
+
elif selected_concept_group[1] in each_data["name"]:
|
| 148 |
+
positive2_cand.append(each_data)
|
| 149 |
+
else:
|
| 150 |
+
if len(negative_cand) <= 100:
|
| 151 |
+
negative_cand.append(each_data)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
process.positive1_cand = positive1_cand
|
| 155 |
+
process.positive2_cand = positive2_cand
|
| 156 |
+
process.positive_common_cand = positive_common_cand
|
| 157 |
+
process.negative_cand = negative_cand
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
TARGET_SIZE = (200,200)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def load_images(dataset, raw_image_path, candidate_image1_path, candidate_image2_path):
|
| 164 |
+
if dataset == "ocl_attribute" or dataset == "ocl_affordance":
|
| 165 |
+
image_dir = "Data/OCL_data/data"
|
| 166 |
+
raw_data = process.pkl_data[raw_image_path[0]]
|
| 167 |
+
|
| 168 |
+
img_path = os.path.join(image_dir,raw_data["name"])
|
| 169 |
+
raw_image = Image.open(img_path).crop(raw_data['objects'][raw_image_path[1]]['box']).resize(TARGET_SIZE)
|
| 170 |
+
|
| 171 |
+
candidate_data1 = process.pkl_data[candidate_image1_path[0]]
|
| 172 |
+
cand1_img_path = os.path.join(image_dir,candidate_data1["name"])
|
| 173 |
+
candidate_image1 = Image.open(cand1_img_path).crop(candidate_data1['objects'][candidate_image1_path[1]]['box']).resize(TARGET_SIZE)
|
| 174 |
+
|
| 175 |
+
candidate_data2 = process.pkl_data[candidate_image2_path[0]]
|
| 176 |
+
cand2_img_path = os.path.join(image_dir,candidate_data2["name"])
|
| 177 |
+
candidate_image2 = Image.open(cand2_img_path).crop(candidate_data2['objects'][candidate_image2_path[1]]['box']).resize(TARGET_SIZE)
|
| 178 |
+
elif dataset == "Pangea":
|
| 179 |
+
mapping_dataset_directory = {'ActvityNet_hico_style_batch1':'ActivityNet_hico_batch1','charadesEgo_hico_style':'charadesego_frame', 'HAG_hico_style_new':'hag_frame','HACS_hico_style':'hacs_frame','kinetics_hico_style':'kinetics_dataset/k700-2020/train'}
|
| 180 |
+
image_dir = "Data/pangea/pangea"
|
| 181 |
+
raw_data = process.pkl_data[raw_image_path]
|
| 182 |
+
img_path = os.path.join(image_dir,mapping_dataset_directory[raw_data[0]], raw_data[1])
|
| 183 |
+
raw_image = Image.open(img_path).resize(TARGET_SIZE)
|
| 184 |
+
|
| 185 |
+
candidate_data1 = process.pkl_data[candidate_image1_path]
|
| 186 |
+
cand1_img_path = os.path.join(image_dir,mapping_dataset_directory[candidate_data1[0]], candidate_data1[1])
|
| 187 |
+
candidate_image1 = Image.open(cand1_img_path).resize(TARGET_SIZE)
|
| 188 |
+
|
| 189 |
+
candidate_data2 = process.pkl_data[candidate_image2_path]
|
| 190 |
+
cand2_img_path = os.path.join(image_dir,mapping_dataset_directory[candidate_data2[0]], candidate_data2[1])
|
| 191 |
+
candidate_image2 = Image.open(cand2_img_path).resize(TARGET_SIZE)
|
| 192 |
+
else:
|
| 193 |
+
raw_image = Image.open(raw_image_path['name']).resize(TARGET_SIZE)
|
| 194 |
+
candidate_image1 = Image.open(candidate_image1_path['name']).resize(TARGET_SIZE)
|
| 195 |
+
candidate_image2 = Image.open(candidate_image2_path['name']).resize(TARGET_SIZE)
|
| 196 |
+
|
| 197 |
+
return raw_image, candidate_image1, candidate_image2
|
| 198 |
+
|
| 199 |
+
def load_candidate_images(dataset, cand_image,candidate_image1_path,candidate_image2_path):
|
| 200 |
+
raw_image = cand_image
|
| 201 |
+
if dataset == "ocl_attribute" or dataset == "ocl_affordance":
|
| 202 |
+
image_dir = "Data/OCL_data/data"
|
| 203 |
+
candidate_data1 = process.pkl_data[candidate_image1_path[0]]
|
| 204 |
+
cand1_img_path = os.path.join(image_dir, candidate_data1["name"])
|
| 205 |
+
candidate_image1 = Image.open(cand1_img_path).crop(candidate_data1['objects'][candidate_image1_path[1]]['box']).resize(TARGET_SIZE)
|
| 206 |
+
|
| 207 |
+
candidate_data2 = process.pkl_data[candidate_image2_path[0]]
|
| 208 |
+
cand2_img_path = os.path.join(image_dir, candidate_data2["name"])
|
| 209 |
+
candidate_image2 = Image.open(cand2_img_path).crop(candidate_data2['objects'][candidate_image2_path[1]]['box']).resize(TARGET_SIZE)
|
| 210 |
+
elif dataset == "Pangea":
|
| 211 |
+
mapping_dataset_directory = {'ActvityNet_hico_style_batch1':'ActivityNet_hico_batch1','charadesEgo_hico_style':'charadesego_frame', 'HAG_hico_style_new':'hag_frame','HACS_hico_style':'hacs_frame','kinetics_hico_style':'kinetics_dataset/k700-2020/train'}
|
| 212 |
+
image_dir = "Data/pangea/pangea"
|
| 213 |
+
candidate_data1 = process.pkl_data[candidate_image1_path]
|
| 214 |
+
cand1_img_path = os.path.join(image_dir,mapping_dataset_directory[candidate_data1[0]],candidate_data1[1])
|
| 215 |
+
candidate_image1 = Image.open(cand1_img_path).resize(TARGET_SIZE)
|
| 216 |
+
|
| 217 |
+
candidate_data2 = process.pkl_data[candidate_image2_path]
|
| 218 |
+
cand2_img_path = os.path.join(image_dir,mapping_dataset_directory[candidate_data2[0]],candidate_data2[1])
|
| 219 |
+
candidate_image2 = Image.open(cand2_img_path).resize(TARGET_SIZE)
|
| 220 |
+
else:
|
| 221 |
+
candidate_image1 = Image.open(candidate_image1_path['name']).resize(TARGET_SIZE)
|
| 222 |
+
candidate_image2 = Image.open(candidate_image2_path['name']).resize(TARGET_SIZE)
|
| 223 |
+
|
| 224 |
+
return raw_image,candidate_image1,candidate_image2
|
| 225 |
+
|
| 226 |
+
class InferenceDemo(object):
|
| 227 |
+
def __init__(self,args,dataset,exp_mode,concept_choices):
|
| 228 |
+
print("init success")
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def get_concept_choices(dataset,exp_mode):
|
| 232 |
+
# if dataset == "ocl":
|
| 233 |
+
if dataset == "ocl_affordance":
|
| 234 |
+
if exp_mode == "One concept":
|
| 235 |
+
choices = [f"Chain_{i}" for i in range(8)]
|
| 236 |
+
else:
|
| 237 |
+
choices = [f"Chain_{i}" for i in range(4)]
|
| 238 |
+
elif dataset == "Pangea":
|
| 239 |
+
if exp_mode == "One concept":
|
| 240 |
+
choices = [f"Chain_{i}" for i in range(8)]
|
| 241 |
+
else:
|
| 242 |
+
choices = [f"Chain_{i}" for i in range(4)]
|
| 243 |
+
else:
|
| 244 |
+
if exp_mode == "One concept":
|
| 245 |
+
choices = [f"Chain_{i}" for i in range(8)]
|
| 246 |
+
else:
|
| 247 |
+
choices = [f"Chain_{i}" for i in range(4)]
|
| 248 |
+
|
| 249 |
+
return gr.update(choices=choices)
|
| 250 |
+
|
| 251 |
+
def load_images_and_concepts(dataset,exp_mode,concept_choices):
|
| 252 |
+
|
| 253 |
+
process.concept_choices = concept_choices
|
| 254 |
+
idx_2_chain = {}
|
| 255 |
+
if dataset == "ocl_attribute":
|
| 256 |
+
if exp_mode == "One concept":
|
| 257 |
+
concept = ["furry","metal","fresh","cooked","natural","ripe","painted","rusty"]
|
| 258 |
+
for idx in range(8):
|
| 259 |
+
idx_2_chain[f"Chain_{idx}"] = concept[idx]
|
| 260 |
+
else:
|
| 261 |
+
concept = ["furry-metal","fresh-cooked","natural-ripe","painted-rusty"]
|
| 262 |
+
for idx in range(4):
|
| 263 |
+
idx_2_chain[f"Chain_{idx}"] = concept[idx]
|
| 264 |
+
elif dataset == "ocl_affordance":
|
| 265 |
+
if exp_mode == "One concept":
|
| 266 |
+
concept = ['break', 'carry', 'clean','cut','open','push','sit','write']
|
| 267 |
+
for idx in range(8):
|
| 268 |
+
idx_2_chain[f"Chain_{idx}"] = concept[idx]
|
| 269 |
+
else:
|
| 270 |
+
concept = ['sit-write','push-carry','cut-clean','open-break']
|
| 271 |
+
for idx in range(4):
|
| 272 |
+
idx_2_chain[f"Chain_{idx}"] = concept[idx]
|
| 273 |
+
elif dataset == "Pangea":
|
| 274 |
+
if exp_mode == "One concept":
|
| 275 |
+
concept = ["hit-18.1","run-51.3.2","dress-41.1.1-1-1","drive-11.5","cooking-45.3","build-26.1","shake-22.3-2","cut-21.1-1"]
|
| 276 |
+
for idx in range(8):
|
| 277 |
+
idx_2_chain[f"Chain_{idx}"] = concept[idx]
|
| 278 |
+
else:
|
| 279 |
+
concept = ['run-51.3.2_hit-18.1', 'drive-11.5_dress-41.1.1-1-1', 'cooking-45.3_build-26.1','shake-22.3-2_cut-21.1-1']
|
| 280 |
+
for idx in range(4):
|
| 281 |
+
idx_2_chain[f"Chain_{idx}"] = concept[idx]
|
| 282 |
+
else:
|
| 283 |
+
if exp_mode == "One concept":
|
| 284 |
+
concept = ["brush_hair","dive","clap","hug","shake_hands","sit","smoke","eat"]
|
| 285 |
+
for idx in range(8):
|
| 286 |
+
idx_2_chain[f"Chain_{idx}"] = concept[idx]
|
| 287 |
+
else:
|
| 288 |
+
concept = ["brush_hair-dive","clap-hug","shake_hands-sit","smoke-eat"]
|
| 289 |
+
for idx in range(4):
|
| 290 |
+
idx_2_chain[f"Chain_{idx}"] = concept[idx]
|
| 291 |
+
process.idx_to_chain = idx_2_chain
|
| 292 |
+
|
| 293 |
+
load_data_and_produce_list(dataset,exp_mode,concept_choices)
|
| 294 |
+
|
| 295 |
+
if exp_mode == "One concept":
|
| 296 |
+
if random.random() < 0.5:
|
| 297 |
+
process.raw_image_path = random.choice(process.positive_cand)
|
| 298 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.positive_cand)
|
| 299 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice(process.negative_cand)
|
| 300 |
+
process.candidate_image1_group, process.candidate_image2_group = "positive", "negative"
|
| 301 |
+
process.gt_image = "Image1"
|
| 302 |
+
process.gt_image_idx = process.candidata_image1_idx
|
| 303 |
+
else:
|
| 304 |
+
process.raw_image_path = random.choice(process.positive_cand)
|
| 305 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.negative_cand)
|
| 306 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice(process.positive_cand)
|
| 307 |
+
process.candidate_image1_group, process.candidate_image2_group = "negative", "positive"
|
| 308 |
+
process.gt_image = "Image2"
|
| 309 |
+
process.gt_image_idx = process.candidate_image2_idx
|
| 310 |
+
else:
|
| 311 |
+
if random.random() < 0.5:
|
| 312 |
+
process.raw_image_path = random.choice(process.positive1_cand)
|
| 313 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.positive1_cand)
|
| 314 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice(process.negative_cand)
|
| 315 |
+
process.candidate_image1_group, process.candidate_image2_group = "positive1", "negative"
|
| 316 |
+
process.gt_image = "Image1"
|
| 317 |
+
process.gt_image_idx = process.candidata_image1_idx
|
| 318 |
+
else:
|
| 319 |
+
process.raw_image_path = random.choice(process.positive1_cand)
|
| 320 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.negative_cand)
|
| 321 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice(process.positive1_cand)
|
| 322 |
+
process.candidate_image1_group, process.candidate_image2_group = "negative", "positive1"
|
| 323 |
+
process.gt_image = "Image2"
|
| 324 |
+
process.gt_image_idx = process.candidate_image2_idx
|
| 325 |
+
raw_image,candidate_image1,candidate_image2 = load_images(dataset, process.raw_image_path,process.candidate_image1_path,process.candidate_image2_path)
|
| 326 |
+
if dataset == "Pangea":
|
| 327 |
+
concept = ["hit", "run", "dress", "drive", "cooking", "build", "shake", "cut"]
|
| 328 |
+
elif dataset == "ocl_attribute":
|
| 329 |
+
concept = ["furry","metal","fresh","cooked","natural","ripe","painted","rusty"]
|
| 330 |
+
elif dataset == "ocl_affordance":
|
| 331 |
+
concept = ['break', 'carry', 'clean','cut','open','push','sit','write']
|
| 332 |
+
return raw_image,candidate_image1,candidate_image2, str(concept)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def count_and_reload_images(dataset,exp_mode, select_input,show_result, steps,raw_image,candidate_image1,candidate_image2):
|
| 336 |
+
|
| 337 |
+
if select_input != None:
|
| 338 |
+
if select_input == process.gt_image or int(steps) < 6 or select_input == 'Uncertain':
|
| 339 |
+
if select_input == 'Uncertain':
|
| 340 |
+
if process.gt_image == 'Image1':
|
| 341 |
+
negative_sample = 'Image2'
|
| 342 |
+
else:
|
| 343 |
+
negative_sample = 'Image1'
|
| 344 |
+
filter_images(dataset, exp_mode, process.concept_choices, negative_sample)
|
| 345 |
+
if select_input == process.gt_image:
|
| 346 |
+
show_result = "Success!"
|
| 347 |
+
elif select_input == 'Uncertain':
|
| 348 |
+
show_result = 'Skip'
|
| 349 |
+
else:
|
| 350 |
+
show_result = "Error!"
|
| 351 |
+
|
| 352 |
+
if exp_mode == "One concept":
|
| 353 |
+
if process.gt_image == "Image1":
|
| 354 |
+
candidate_image = candidate_image1
|
| 355 |
+
else:
|
| 356 |
+
candidate_image = candidate_image2
|
| 357 |
+
if random.random() < 0.5:
|
| 358 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice([x for x in process.positive_cand if x!=process.gt_image_idx])
|
| 359 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice(process.negative_cand)
|
| 360 |
+
process.candidate_image1_group, process.candidate_image2_group = "positive", "negative"
|
| 361 |
+
process.gt_image = "Image1"
|
| 362 |
+
process.gt_image_idx = process.candidate_image1_idx
|
| 363 |
+
else:
|
| 364 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.negative_cand)
|
| 365 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice([x for x in process.positive_cand if x!=process.gt_image_idx])
|
| 366 |
+
process.candidate_image1_group, process.candidate_image2_group = "negative", "positive"
|
| 367 |
+
process.gt_image = "Image2"
|
| 368 |
+
process.gt_image_idx = process.candidate_image2_idx
|
| 369 |
+
|
| 370 |
+
raw_image,candidate_image1,candidate_image2 = load_candidate_images(dataset,candidate_image,process.candidate_image1_path,process.candidate_image2_path)
|
| 371 |
+
else:
|
| 372 |
+
if process.gt_image == "Image1":
|
| 373 |
+
candidate_image = candidate_image1
|
| 374 |
+
else:
|
| 375 |
+
candidate_image = candidate_image2
|
| 376 |
+
|
| 377 |
+
if random.random() < 0.5:
|
| 378 |
+
if process.schedule < 3:
|
| 379 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice([x for x in process.positive1_cand if x!=process.gt_image_idx])
|
| 380 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice(process.negative_cand)
|
| 381 |
+
process.candidate_image1_group, process.candidate_image2_group = "positive1", "negative"
|
| 382 |
+
raw_image,candidate_image1,candidate_image2 = load_candidate_images(dataset,candidate_image,process.candidate_image1_path,process.candidate_image2_path)
|
| 383 |
+
process.schedule += 1
|
| 384 |
+
elif process.schedule == 3:
|
| 385 |
+
if len(process.positive_common_cand) != 0:
|
| 386 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice([x for x in process.positive_common_cand if x!=process.gt_image_idx])
|
| 387 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice(process.negative_cand)
|
| 388 |
+
process.candidate_image1_group, process.candidate_image2_group = "positive_com", "negative"
|
| 389 |
+
raw_image,candidate_image1,candidate_image2 = load_candidate_images(dataset,candidate_image,process.candidate_image1_path,process.candidate_image2_path)
|
| 390 |
+
else:
|
| 391 |
+
process.raw_image_path = random.choice(process.positive2_cand)
|
| 392 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice([x for x in process.positive2_cand if x!=process.gt_image_idx])
|
| 393 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice(process.negative_cand)
|
| 394 |
+
process.candidate_image1_group, process.candidate_image2_group = "positive2", "negative"
|
| 395 |
+
raw_image,candidate_image1,candidate_image2 = load_images(dataset,process.raw_image_path,process.candidate_image1_path,process.candidate_image2_path)
|
| 396 |
+
process.schedule += 1
|
| 397 |
+
elif process.schedule < 7:
|
| 398 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice([x for x in process.positive2_cand if x!=process.gt_image_idx])
|
| 399 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice(process.negative_cand)
|
| 400 |
+
process.candidate_image1_group, process.candidate_image2_group = "positive2", "negative"
|
| 401 |
+
raw_image,candidate_image1,candidate_image2 = load_candidate_images(dataset,candidate_image,process.candidate_image1_path,process.candidate_image2_path)
|
| 402 |
+
process.schedule += 1
|
| 403 |
+
elif process.schedule == 7:
|
| 404 |
+
if len(process.positive_common_cand) != 0:
|
| 405 |
+
process.candidate_image1_path = random.choice([x for x in process.positive_common_cand if x!=process.gt_image_idx])
|
| 406 |
+
process.candidate_image2_path = random.choice(process.negative_cand)
|
| 407 |
+
raw_image,candidate_image1,candidate_image2 = load_candidate_images(dataset,candidate_image,process.candidate_image1_path,process.candidate_image2_path)
|
| 408 |
+
else:
|
| 409 |
+
process.raw_image_path = random.choice(process.positive1_cand)
|
| 410 |
+
process.candidate_image1_path = random.choice([x for x in process.positive1_cand if x!=process.gt_image_idx])
|
| 411 |
+
process.candidate_image2_path = random.choice(process.negative_cand)
|
| 412 |
+
raw_image,candidate_image1,candidate_image2 = load_images(dataset,process.raw_image_path,process.candidate_image1_path,process.candidate_image2_path)
|
| 413 |
+
process.schedule = 0
|
| 414 |
+
process.gt_image = "Image1"
|
| 415 |
+
process.gt_image_idx = process.candidate_image1_idx
|
| 416 |
+
else:
|
| 417 |
+
if process.schedule < 3:
|
| 418 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice([x for x in process.positive1_cand if x!=process.gt_image_idx])
|
| 419 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.negative_cand)
|
| 420 |
+
process.candidate_image1_group, process.candidate_image2_group = "negative", "positive1"
|
| 421 |
+
raw_image,candidate_image1,candidate_image2 = load_candidate_images(dataset,candidate_image,process.candidate_image1_path,process.candidate_image2_path)
|
| 422 |
+
process.schedule += 1
|
| 423 |
+
elif process.schedule == 3:
|
| 424 |
+
if len(process.positive_common_cand) != 0:
|
| 425 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice([x for x in process.positive_common_cand if x!=process.gt_image_idx])
|
| 426 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.negative_cand)
|
| 427 |
+
process.candidate_image1_group, process.candidate_image2_group = "negative", "positive_com"
|
| 428 |
+
raw_image,candidate_image1,candidate_image2 = load_candidate_images(dataset,candidate_image,process.candidate_image1_path,process.candidate_image2_path)
|
| 429 |
+
else:
|
| 430 |
+
process.raw_image_path = random.choice(process.positive2_cand)
|
| 431 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice([x for x in process.positive2_cand if x!=process.gt_image_idx])
|
| 432 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.negative_cand)
|
| 433 |
+
process.candidate_image1_group, process.candidate_image2_group = "negative", "positive2"
|
| 434 |
+
raw_image,candidate_image1,candidate_image2 = load_images(dataset,process.raw_image_path,process.candidate_image1_path,process.candidate_image2_path)
|
| 435 |
+
process.schedule += 1
|
| 436 |
+
elif process.schedule < 7:
|
| 437 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice([x for x in process.positive2_cand if x!=process.gt_image_idx])
|
| 438 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.negative_cand)
|
| 439 |
+
process.candidate_image1_group, process.candidate_image2_group = "negative", "positive2"
|
| 440 |
+
raw_image,candidate_image1,candidate_image2 = load_candidate_images(dataset,candidate_image,process.candidate_image1_path,process.candidate_image2_path)
|
| 441 |
+
process.schedule += 1
|
| 442 |
+
elif process.schedule == 7:
|
| 443 |
+
if len(process.positive_common_cand) != 0:
|
| 444 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice([x for x in process.positive_common_cand if x!=process.gt_image_idx])
|
| 445 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.negative_cand)
|
| 446 |
+
process.candidate_image1_group, process.candidate_image2_group = "negative", "positive_com"
|
| 447 |
+
raw_image,candidate_image1,candidate_image2 = load_candidate_images(dataset,candidate_image,process.candidate_image1_path,process.candidate_image2_path)
|
| 448 |
+
else:
|
| 449 |
+
process.raw_image_path = random.choice(process.positive1_cand)
|
| 450 |
+
process.candidate_image2_idx = process.candidate_image2_path = random.choice([x for x in process.positive1_cand if x!=process.gt_image_idx])
|
| 451 |
+
process.candidate_image1_idx = process.candidate_image1_path = random.choice(process.negative_cand)
|
| 452 |
+
process.candidate_image1_group, process.candidate_image2_group = "negative", "positive1"
|
| 453 |
+
raw_image,candidate_image1,candidate_image2 = load_images(dataset,process.raw_image_path,process.candidate_image1_path,process.candidate_image2_path)
|
| 454 |
+
process.schedule = 0
|
| 455 |
+
|
| 456 |
+
process.gt_image = "Image2"
|
| 457 |
+
process.gt_image_idx = process.candidate_image2_idx
|
| 458 |
+
|
| 459 |
+
if select_input != 'Uncertain':
|
| 460 |
+
steps = int(steps) + 1
|
| 461 |
+
select_input = None
|
| 462 |
+
|
| 463 |
+
else:
|
| 464 |
+
show_result = "Error, Please reset!"
|
| 465 |
+
process.gt_image = None
|
| 466 |
+
|
| 467 |
+
return select_input,show_result, steps,raw_image,candidate_image1,candidate_image2
|
| 468 |
+
|
| 469 |
+
def filter_images(dataset, exp_mode, concept_choices, image_filtered):
|
| 470 |
+
if image_filtered == None:
|
| 471 |
+
return None
|
| 472 |
+
if dataset == "ocl_attribute" or dataset == "ocl_affordance":
|
| 473 |
+
if dataset == "ocl_attribute":
|
| 474 |
+
pkl_path = "Data/OCL_data/OCL_selected_test_attribute_refined.pkl"
|
| 475 |
+
else:
|
| 476 |
+
pkl_path = "Data/OCL_data/OCL_selected_test_affordance_refined.pkl"
|
| 477 |
+
with open(pkl_path,"rb") as f:
|
| 478 |
+
data = pickle.load(f)
|
| 479 |
+
if exp_mode == "One concept":
|
| 480 |
+
if image_filtered == "Image1":
|
| 481 |
+
print(process.candidate_image1_idx)
|
| 482 |
+
if process.candidate_image1_group == "positive":
|
| 483 |
+
if process.candidate_image1_idx in data['selected_individual_pkl'][process.idx_to_chain[concept_choices]]:
|
| 484 |
+
data['selected_individual_pkl'][process.idx_to_chain[concept_choices]].remove(process.candidate_image1_idx)
|
| 485 |
+
elif process.candidate_image1_group == "negative":
|
| 486 |
+
if process.candidate_image1_idx in data["negative_pkl"]:
|
| 487 |
+
data["negative_pkl"].remove(process.candidate_image1_idx)
|
| 488 |
+
else:
|
| 489 |
+
print('Error')
|
| 490 |
+
else:
|
| 491 |
+
print(process.candidate_image2_idx)
|
| 492 |
+
if process.candidate_image2_group == "positive":
|
| 493 |
+
if process.candidate_image2_idx in data['selected_individual_pkl'][process.idx_to_chain[concept_choices]]:
|
| 494 |
+
data['selected_individual_pkl'][process.idx_to_chain[concept_choices]].remove(process.candidate_image2_idx)
|
| 495 |
+
elif process.candidate_image2_group == "negative":
|
| 496 |
+
if process.candidate_image2_idx in data["negative_pkl"]:
|
| 497 |
+
data["negative_pkl"].remove(process.candidate_image2_idx)
|
| 498 |
+
else:
|
| 499 |
+
print('Error')
|
| 500 |
+
else:
|
| 501 |
+
selected_concept_group = process.idx_to_chain[concept_choices].split("_")
|
| 502 |
+
selected_paired_pkl = data['selected_paired_pkl'][process.idx_to_chain[concept_choices]]
|
| 503 |
+
if image_filtered == "Image1":
|
| 504 |
+
print(process.candidate_image1_idx)
|
| 505 |
+
if process.candidate_image1_group == "positive1":
|
| 506 |
+
if process.candidate_image1_idx in selected_paired_pkl[selected_concept_group[0]]:
|
| 507 |
+
selected_paired_pkl[selected_concept_group[0]].remove(process.candidate_image1_idx)
|
| 508 |
+
elif process.candidate_image1_group == "positive2":
|
| 509 |
+
if process.candidate_image1_idx in selected_paired_pkl[selected_concept_group[1]]:
|
| 510 |
+
selected_paired_pkl[selected_concept_group[1]].remove(process.candidate_image1_idx)
|
| 511 |
+
elif process.candidate_image1_group == "positive_com":
|
| 512 |
+
if process.candidate_image1_idx in selected_paired_pkl[process.idx_to_chain[concept_choices]]:
|
| 513 |
+
selected_paired_pkl[process.idx_to_chain[concept_choices]].remove(process.candidate_image1_idx)
|
| 514 |
+
elif process.candidate_image1_group == "negative":
|
| 515 |
+
if process.candidate_image1_idx in data["negative_pkl"]:
|
| 516 |
+
data["negative_pkl"].remove(process.candidate_image1_idx)
|
| 517 |
+
else:
|
| 518 |
+
print('Error')
|
| 519 |
+
else:
|
| 520 |
+
print(process.candidate_image2_idx)
|
| 521 |
+
if process.candidate_image2_group == "positive1":
|
| 522 |
+
if process.candidate_image2_idx in selected_paired_pkl[selected_concept_group[0]]:
|
| 523 |
+
selected_paired_pkl[selected_concept_group[0]].remove(process.candidate_image2_idx)
|
| 524 |
+
elif process.candidate_image2_group == "positive2":
|
| 525 |
+
if process.candidate_image2_idx in selected_paired_pkl[selected_concept_group[1]]:
|
| 526 |
+
selected_paired_pkl[selected_concept_group[1]].remove(process.candidate_image2_idx)
|
| 527 |
+
elif process.candidate_image2_group == "positive_com":
|
| 528 |
+
if process.candidate_image2_idx in selected_paired_pkl[process.idx_to_chain[concept_choices]]:
|
| 529 |
+
selected_paired_pkl[process.idx_to_chain[concept_choices]].remove(process.candidate_image2_idx)
|
| 530 |
+
elif process.candidate_image2_group == "negative":
|
| 531 |
+
if process.candidate_image2_idx in data["negative_pkl"]:
|
| 532 |
+
data["negative_pkl"].remove(process.candidate_image2_idx)
|
| 533 |
+
else:
|
| 534 |
+
print('Error')
|
| 535 |
+
with open(pkl_path, "wb") as f:
|
| 536 |
+
pickle.dump(data, f)
|
| 537 |
+
|
| 538 |
+
elif dataset == "Pangea":
|
| 539 |
+
pkl_path = "Data/pangea/pangea_test_refined.pkl"
|
| 540 |
+
with open(pkl_path,"rb") as f:
|
| 541 |
+
data = pickle.load(f)
|
| 542 |
+
if exp_mode == "One concept":
|
| 543 |
+
if image_filtered == "Image1":
|
| 544 |
+
print(process.candidate_image1_idx)
|
| 545 |
+
if process.candidate_image1_group == "positive":
|
| 546 |
+
if process.candidate_image1_idx in data['selected_pkl'][process.idx_to_chain[concept_choices]]:
|
| 547 |
+
data['selected_pkl'][process.idx_to_chain[concept_choices]].remove(process.candidate_image1_idx)
|
| 548 |
+
elif process.candidate_image1_group == "negative":
|
| 549 |
+
if process.candidate_image1_idx in data["negative_pkl"]:
|
| 550 |
+
data["negative_pkl"].remove(process.candidate_image1_idx)
|
| 551 |
+
else:
|
| 552 |
+
print('Error')
|
| 553 |
+
else:
|
| 554 |
+
print(process.candidate_image2_idx)
|
| 555 |
+
if process.candidate_image2_group == "positive":
|
| 556 |
+
if process.candidate_image2_idx in data['selected_pkl'][process.idx_to_chain[concept_choices]]:
|
| 557 |
+
data['selected_pkl'][process.idx_to_chain[concept_choices]].remove(process.candidate_image2_idx)
|
| 558 |
+
elif process.candidate_image2_group == "negative":
|
| 559 |
+
if process.candidate_image2_idx in data["negative_pkl"]:
|
| 560 |
+
data["negative_pkl"].remove(process.candidate_image2_idx)
|
| 561 |
+
else:
|
| 562 |
+
print('Error')
|
| 563 |
+
else:
|
| 564 |
+
selected_concept_group = process.idx_to_chain[concept_choices].split("-")
|
| 565 |
+
selected_paired_pkl = data['selected_paired_pkl'][process.idx_to_chain[concept_choices]]
|
| 566 |
+
if image_filtered == "Image1":
|
| 567 |
+
print(process.candidate_image1_idx)
|
| 568 |
+
if process.candidate_image1_group == "positive1":
|
| 569 |
+
if process.candidate_image1_idx in selected_paired_pkl[selected_concept_group[0]]:
|
| 570 |
+
selected_paired_pkl[selected_concept_group[0]].remove(process.candidate_image1_idx)
|
| 571 |
+
elif process.candidate_image1_group == "positive2":
|
| 572 |
+
if process.candidate_image1_idx in selected_paired_pkl[selected_concept_group[1]]:
|
| 573 |
+
selected_paired_pkl[selected_concept_group[1]].remove(process.candidate_image1_idx)
|
| 574 |
+
elif process.candidate_image1_group == "positive_com":
|
| 575 |
+
if process.candidate_image1_idx in selected_paired_pkl[process.idx_to_chain[concept_choices]]:
|
| 576 |
+
selected_paired_pkl[process.idx_to_chain[concept_choices]].remove(process.candidate_image1_idx)
|
| 577 |
+
elif process.candidate_image1_group == "negative":
|
| 578 |
+
if process.candidate_image1_idx in data["negative_pkl"]:
|
| 579 |
+
data["negative_pkl"].remove(process.candidate_image1_idx)
|
| 580 |
+
else:
|
| 581 |
+
print('Error')
|
| 582 |
+
else:
|
| 583 |
+
print(process.candidate_image2_idx)
|
| 584 |
+
if process.candidate_image2_group == "positive1":
|
| 585 |
+
if process.candidate_image2_idx in selected_paired_pkl[selected_concept_group[0]]:
|
| 586 |
+
selected_paired_pkl[selected_concept_group[0]].remove(process.candidate_image2_idx)
|
| 587 |
+
elif process.candidate_image2_group == "positive2":
|
| 588 |
+
if process.candidate_image2_idx in selected_paired_pkl[selected_concept_group[1]]:
|
| 589 |
+
selected_paired_pkl[selected_concept_group[1]].remove(process.candidate_image2_idx)
|
| 590 |
+
elif process.candidate_image2_group == "positive_com":
|
| 591 |
+
if process.candidate_image2_idx in selected_paired_pkl[process.idx_to_chain[concept_choices]]:
|
| 592 |
+
selected_paired_pkl[process.idx_to_chain[concept_choices]].remove(process.candidate_image2_idx)
|
| 593 |
+
elif process.candidate_image2_group == "negative":
|
| 594 |
+
if process.candidate_image2_idx in data["negative_pkl"]:
|
| 595 |
+
data["negative_pkl"].remove(process.candidate_image2_idx)
|
| 596 |
+
else:
|
| 597 |
+
print('Error')
|
| 598 |
+
with open(pkl_path, "wb") as f:
|
| 599 |
+
pickle.dump(data, f)
|
| 600 |
+
else:
|
| 601 |
+
print("Error")
|
| 602 |
+
|
| 603 |
+
return None
|
| 604 |
+
|
| 605 |
+
with gr.Blocks() as demo:
|
| 606 |
+
|
| 607 |
+
title_markdown = ("""
|
| 608 |
+
# MLLM Associstion
|
| 609 |
+
[[Paper]](https://mvig-rhos.com) [[Code]](https://github.com/lihong2303/MLLMs_Association)
|
| 610 |
+
""")
|
| 611 |
+
# 
|
| 612 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 613 |
+
gr.Markdown(title_markdown)
|
| 614 |
+
|
| 615 |
+
with gr.Row():
|
| 616 |
+
with gr.Column():
|
| 617 |
+
raw_image = gr.Image(label="Raw Image",interactive=False)
|
| 618 |
+
with gr.Column():
|
| 619 |
+
candidate_image1 = gr.Image(label="Candidate Image 1",interactive=False)
|
| 620 |
+
with gr.Column():
|
| 621 |
+
candidate_image2 = gr.Image(label="Candidate Image 2",interactive=False)
|
| 622 |
+
|
| 623 |
+
with gr.Row():
|
| 624 |
+
candidate_concepts = gr.Label(value="", label="Candidate Concepts")
|
| 625 |
+
filter_Images = gr.Radio(choices=["Image1", "Image2"],label="Filter low quality image")
|
| 626 |
+
|
| 627 |
+
with gr.Row():
|
| 628 |
+
dataset = gr.Dropdown(choices=["ocl_attribute","ocl_affordance","hmdb", "Pangea"],label="Select a dataset",interactive=True)
|
| 629 |
+
exp_mode = gr.Dropdown(choices=["One concept","Two concepts"],label="Select a test mode",interactive=True)
|
| 630 |
+
concept_choices = gr.Dropdown(choices=[],label = "Select the chain",interactive=True)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
with gr.Row():
|
| 634 |
+
select_input = gr.Radio(choices=["Image1","Image2","Uncertain"],label="Select candidate image")
|
| 635 |
+
steps = gr.Label(value="0",label="Steps")
|
| 636 |
+
show_result = gr.Label(value="",label="Selected Result")
|
| 637 |
+
# reset_button = gr.Button(text="Reset")
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
exp_mode.change(fn=get_concept_choices,inputs=[dataset,exp_mode],outputs=concept_choices)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
concept_choices.change(fn=load_images_and_concepts,
|
| 644 |
+
inputs=[dataset,exp_mode,concept_choices],
|
| 645 |
+
outputs=[raw_image,candidate_image1,candidate_image2, candidate_concepts])
|
| 646 |
+
filter_Images.change(fn=filter_images, inputs=[dataset, exp_mode, concept_choices, filter_Images], outputs=[filter_Images])
|
| 647 |
+
select_input.change(fn=count_and_reload_images,inputs=[dataset,exp_mode,select_input,show_result,steps,raw_image,candidate_image1,candidate_image2],outputs=[select_input,show_result,steps,raw_image,candidate_image1,candidate_image2])
|
| 648 |
+
|
| 649 |
+
demo.queue()
|
| 650 |
+
|
| 651 |
+
if __name__ == "__main__":
|
| 652 |
+
demo.launch()
|
| 653 |
+
# demo.launch(server_port=6126)
|
| 654 |
+
# import argparse
|
| 655 |
+
# argparser = argparse.ArgumentParser()
|
| 656 |
+
# argparser.add_argument("--server_name", default="0.0.0.0", type=str)
|
| 657 |
+
# argparser.add_argument("--port", default="6123", type=str)
|
| 658 |
+
# args = argparser.parse_args()
|
| 659 |
+
# try:
|
| 660 |
+
# demo.launch(server_name=args.server_name, server_port=int(args.port),share=False)
|
| 661 |
+
# except Exception as e:
|
| 662 |
+
# args.port=int(args.port)+1
|
| 663 |
+
# print(f"Port {args.port} is occupied, try port {args.port}")
|
| 664 |
+
# demo.launch(server_name=args.server_name, server_port=int(args.port),share=False)
|
| 665 |
+
|
pangea.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import shutil
|
| 7 |
+
|
| 8 |
+
mapping_dataset_directory = {'ActvityNet_hico_style_batch1':'ActivityNet_hico_batch1','charadesEgo_hico_style':'charadesego_frame', 'HAG_hico_style_new':'hag_frame','HACS_hico_style':'hacs_frame','kinetics_hico_style':'kinetics_dataset/k700-2020/train'}
|
| 9 |
+
|
| 10 |
+
train_pkl = "/home/lihong/chenyuanjie/Sandwich/Data/B123_train_KIN-FULL_with_node.pkl"
|
| 11 |
+
split_test_path = "/home/lihong/chenyuanjie/Sandwich/Data/B123_test_KIN-FULL_with_node.pkl"
|
| 12 |
+
|
| 13 |
+
with open(split_test_path, 'rb') as f:
|
| 14 |
+
data = pickle.load(f)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# split_test_pkl = []
|
| 18 |
+
|
| 19 |
+
# action_num_dict = {}
|
| 20 |
+
|
| 21 |
+
# for data_idx, data_item in enumerate(data):
|
| 22 |
+
# if data_item[0] in mapping_dataset_directory.keys():
|
| 23 |
+
# dataset = mapping_dataset_directory[data_item[0]]
|
| 24 |
+
# else:
|
| 25 |
+
# dataset = data_item[0]
|
| 26 |
+
|
| 27 |
+
# orig_label = data_item[2]
|
| 28 |
+
# node_labels = data_item[3]
|
| 29 |
+
|
| 30 |
+
# for nod_lab in node_labels:
|
| 31 |
+
# if nod_lab in action_num_dict.keys():
|
| 32 |
+
# action_num_dict[nod_lab] += 1
|
| 33 |
+
# else:
|
| 34 |
+
# action_num_dict[nod_lab] = 1
|
| 35 |
+
|
| 36 |
+
# if data_idx %1000 == 0:
|
| 37 |
+
# print(len(data),data_idx)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# current_action_num_dict = {}
|
| 41 |
+
# dataset_list = []
|
| 42 |
+
# for data_idx, data_item in enumerate(data):
|
| 43 |
+
# if data_item[0] in mapping_dataset_directory.keys():
|
| 44 |
+
# dataset = mapping_dataset_directory[data_item[0]]
|
| 45 |
+
# else:
|
| 46 |
+
# dataset = data_item[0]
|
| 47 |
+
# image_path = '/data/xiaoqian/Images/' + dataset + '/' + data_item[1]
|
| 48 |
+
# if not os.path.isfile(image_path):
|
| 49 |
+
# if not dataset in dataset_list:
|
| 50 |
+
# dataset_list.append(dataset)
|
| 51 |
+
# continue
|
| 52 |
+
|
| 53 |
+
# orig_label = data_item[2]
|
| 54 |
+
# node_labels = data_item[3]
|
| 55 |
+
# flag = False
|
| 56 |
+
# for nod_lab in node_labels:
|
| 57 |
+
# if nod_lab in current_action_num_dict.keys():
|
| 58 |
+
# if current_action_num_dict[nod_lab] < action_num_dict[nod_lab] * 0.1 and not flag:
|
| 59 |
+
# split_test_pkl.append(data_item)
|
| 60 |
+
# flag = True
|
| 61 |
+
# current_action_num_dict[nod_lab] += 1
|
| 62 |
+
# else:
|
| 63 |
+
# split_test_pkl.append(data_item)
|
| 64 |
+
# flag = True
|
| 65 |
+
# current_action_num_dict[nod_lab] = 1
|
| 66 |
+
|
| 67 |
+
# if data_idx % 1000 == 0:
|
| 68 |
+
# print(len(data),data_idx)
|
| 69 |
+
|
| 70 |
+
# print(action_num_dict, current_action_num_dict)
|
| 71 |
+
# with open(split_test_path, 'wb') as f:
|
| 72 |
+
# pickle.dump(split_test_pkl, f)
|
| 73 |
+
# exit()
|
| 74 |
+
|
| 75 |
+
## mapping node to idx
|
| 76 |
+
mapping_node_index = pickle.load(open("/home/lihong/chenyuanjie/Sandwich/Data/mapping_node_index.pkl", "rb"))
|
| 77 |
+
verbnet_topology = pickle.load(open("/home/lihong/chenyuanjie/Sandwich/Data/verbnet_topology_898.pkl", "rb"))
|
| 78 |
+
|
| 79 |
+
Father2Son, objects = verbnet_topology["Father2Son"], verbnet_topology["objects"]
|
| 80 |
+
|
| 81 |
+
objects = np.array(objects)
|
| 82 |
+
objects_290 = objects[mapping_node_index]
|
| 83 |
+
|
| 84 |
+
object_to_idx = {obj: idx for idx, obj in enumerate(objects_290)}
|
| 85 |
+
|
| 86 |
+
# filtered_objects = [obj.split("-")[0] for obj in objects_290]
|
| 87 |
+
|
| 88 |
+
selected_list = ["hit", "push","run","dress","drive","cook","throw","build","shake","cut"]
|
| 89 |
+
true_selected_list = ["hit-18.1","push-12","run-51.3.2","dress-41.1.1-1-1","drive-11.5","cooking-45.3","throw-17.1-1","build-26.1","shake-22.3-2","cut-21.1-1"]
|
| 90 |
+
true_selected_list_id = [object_to_idx[node] for node in true_selected_list]
|
| 91 |
+
true_selected_paired_list = ['run-51.3.2_hit-18.1', 'drive-11.5_dress-41.1.1-1-1', 'cooking-45.3_build-26.1','shake-22.3-2_cut-21.1-1'] #,'throw-17.1-1_push-12'
|
| 92 |
+
true_label = {}
|
| 93 |
+
|
| 94 |
+
# {'hit-18.1cut-21.1-1': 86808, 'hit-18.1drive-11.5': 14935, 'hit-18.1run-51.3.2': 34237}
|
| 95 |
+
# {'run-51.3.2run-51.3.2': 341324, 'run-51.3.2hit-18.1': 34237, 'run-51.3.2cut-21.1-1': 20389}
|
| 96 |
+
# {'dress-41.1.1-1-1dress-41.1.1-1-1': 470063, 'dress-41.1.1-1-1run-51.3.2': 63862, 'dress-41.1.1-1-1cut-21.1-1': 47727, 'dress-41.1.1-1-1drive-11.5': 24965, 'dress-41.1.1-1-1hit-18.1': 23118, 'dress-41.1.1-1-1push-12': 11982, 'dress-41.1.1-1-1cooking-45.3': 469, 'dress-41.1.1-1-1build-26.1': 306}
|
| 97 |
+
# {'drive-11.5drive-11.5': 238175, 'drive-11.5build-26.1': 15223, 'drive-11.5hit-18.1': 14935, 'drive-11.5cut-21.1-1': 30031, 'drive-11.5dress-41.1.1-1-1': 24965}
|
| 98 |
+
# {'cooking-45.3cooking-45.3': 68577, 'cooking-45.3build-26.1': 37668, 'cooking-45.3cut-21.1-1': 15072, 'cooking-45.3dress-41.1.1-1-1': 469}
|
| 99 |
+
# {'throw-17.1-1throw-17.1-1': 394887, 'throw-17.1-1hit-18.1': 92553, 'throw-17.1-1drive-11.5': 30348, 'throw-17.1-1dress-41.1.1-1-1': 97911, 'throw-17.1-1run-51.3.2': 30097, 'throw-17.1-1push-12': 20854, 'throw-17.1-1cut-21.1-1': 20714}
|
| 100 |
+
# {'build-26.1cooking-45.3': 37668, 'build-26.1build-26.1': 95743, 'build-26.1drive-11.5': 15223, 'build-26.1cut-21.1-1': 23454, 'build-26.1shake-22.3-2': 23015, 'build-26.1dress-41.1.1-1-1': 306}
|
| 101 |
+
# {'shake-22.3-2build-26.1': 23015, 'shake-22.3-2shake-22.3-2': 23015, 'shake-22.3-2cut-21.1-1': 13005}
|
| 102 |
+
# {'cut-21.1-1cut-21.1-1': 553752, 'cut-21.1-1hit-18.1': 86808, 'cut-21.1-1drive-11.5': 30031, 'cut-21.1-1dress-41.1.1-1-1': 47727, 'cut-21.1-1build-26.1': 23454, 'cut-21.1-1cooking-45.3': 15072, 'cut-21.1-1run-51.3.2': 20389, 'cut-21.1-1throw-17.1-1': 20714, 'cut-21.1-1shake-22.3-2': 13005}
|
| 103 |
+
|
| 104 |
+
selected_pkl = {}
|
| 105 |
+
selected_paired_pkl = {}
|
| 106 |
+
pangea_pkl = {}
|
| 107 |
+
|
| 108 |
+
negative_pkl = []
|
| 109 |
+
|
| 110 |
+
dataset_list = []
|
| 111 |
+
num_images = 0
|
| 112 |
+
|
| 113 |
+
for data_idx, data_item in enumerate(data):
|
| 114 |
+
save_flag = False
|
| 115 |
+
if data_item[0] in mapping_dataset_directory.keys():
|
| 116 |
+
dataset = mapping_dataset_directory[data_item[0]]
|
| 117 |
+
else:
|
| 118 |
+
dataset = data_item[0]
|
| 119 |
+
|
| 120 |
+
image_path = '/data/xiaoqian/Images/' + dataset + '/' + data_item[1]
|
| 121 |
+
|
| 122 |
+
if not os.path.isfile(image_path):
|
| 123 |
+
if not dataset in dataset_list:
|
| 124 |
+
dataset_list.append(dataset)
|
| 125 |
+
continue
|
| 126 |
+
|
| 127 |
+
print(data_item)
|
| 128 |
+
exit()
|
| 129 |
+
orig_label = data_item[2]
|
| 130 |
+
node_labels = data_item[3]
|
| 131 |
+
node_labels_id = [object_to_idx[node] for node in node_labels]
|
| 132 |
+
|
| 133 |
+
co_objects = list(set(node_labels_id).intersection(set(true_selected_list_id)))
|
| 134 |
+
|
| 135 |
+
if len(co_objects) > 0:
|
| 136 |
+
|
| 137 |
+
for sel_paired_objects in true_selected_list:
|
| 138 |
+
if object_to_idx[sel_paired_objects] in co_objects:
|
| 139 |
+
|
| 140 |
+
if sel_paired_objects not in selected_pkl.keys():
|
| 141 |
+
selected_pkl[sel_paired_objects] = [data_idx]
|
| 142 |
+
save_flag = True
|
| 143 |
+
else:
|
| 144 |
+
if len(selected_pkl[sel_paired_objects]) < 2000:
|
| 145 |
+
save_flag = True
|
| 146 |
+
selected_pkl[sel_paired_objects].append(data_idx)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
for sel_paired_objects in true_selected_paired_list:
|
| 150 |
+
sel_obj1, sel_obj2 = sel_paired_objects.split("_")
|
| 151 |
+
|
| 152 |
+
if object_to_idx[sel_obj1] in co_objects and object_to_idx[sel_obj2] in co_objects:
|
| 153 |
+
|
| 154 |
+
if sel_paired_objects not in selected_paired_pkl.keys():
|
| 155 |
+
selected_paired_pkl[sel_paired_objects] = {}
|
| 156 |
+
if sel_paired_objects not in selected_paired_pkl[sel_paired_objects].keys():
|
| 157 |
+
save_flag = True
|
| 158 |
+
selected_paired_pkl[sel_paired_objects][sel_paired_objects] = [data_idx]
|
| 159 |
+
else:
|
| 160 |
+
if len(selected_paired_pkl[sel_paired_objects][sel_paired_objects]) < 2000:
|
| 161 |
+
save_flag = True
|
| 162 |
+
selected_paired_pkl[sel_paired_objects][sel_paired_objects].append(data_idx)
|
| 163 |
+
elif object_to_idx[sel_obj1] in co_objects:
|
| 164 |
+
|
| 165 |
+
if sel_paired_objects not in selected_paired_pkl.keys():
|
| 166 |
+
selected_paired_pkl[sel_paired_objects] = {}
|
| 167 |
+
if sel_obj1 not in selected_paired_pkl[sel_paired_objects].keys():
|
| 168 |
+
save_flag = True
|
| 169 |
+
selected_paired_pkl[sel_paired_objects][sel_obj1] = [data_idx]
|
| 170 |
+
else:
|
| 171 |
+
if len(selected_paired_pkl[sel_paired_objects][sel_obj1]) < 2000:
|
| 172 |
+
save_flag = True
|
| 173 |
+
selected_paired_pkl[sel_paired_objects][sel_obj1].append(data_idx)
|
| 174 |
+
elif object_to_idx[sel_obj2] in co_objects:
|
| 175 |
+
|
| 176 |
+
if sel_paired_objects not in selected_paired_pkl.keys():
|
| 177 |
+
selected_paired_pkl[sel_paired_objects] = {}
|
| 178 |
+
if sel_obj2 not in selected_paired_pkl[sel_paired_objects].keys():
|
| 179 |
+
save_flag = True
|
| 180 |
+
selected_paired_pkl[sel_paired_objects][sel_obj2] = [data_idx]
|
| 181 |
+
else:
|
| 182 |
+
if len(selected_paired_pkl[sel_paired_objects][sel_obj2]) < 2000:
|
| 183 |
+
save_flag = True
|
| 184 |
+
selected_paired_pkl[sel_paired_objects][sel_obj2].append(data_idx)
|
| 185 |
+
else:
|
| 186 |
+
if len(negative_pkl) < 3000:
|
| 187 |
+
|
| 188 |
+
neg_flag = False
|
| 189 |
+
for sel_list in selected_list:
|
| 190 |
+
for nod_lab in node_labels:
|
| 191 |
+
if sel_list in nod_lab:
|
| 192 |
+
neg_flag = True
|
| 193 |
+
break
|
| 194 |
+
if neg_flag:
|
| 195 |
+
break
|
| 196 |
+
|
| 197 |
+
if not neg_flag:
|
| 198 |
+
save_flag = True
|
| 199 |
+
negative_pkl.append(data_idx)
|
| 200 |
+
if save_flag:
|
| 201 |
+
num_images += 1
|
| 202 |
+
if not os.path.exists(os.path.dirname(os.path.join("/home/lihong/workspace/pangea/pangea", dataset, data_item[1]))):
|
| 203 |
+
os.makedirs(os.path.dirname(os.path.join("/home/lihong/workspace/pangea/pangea", dataset, data_item[1])))
|
| 204 |
+
shutil.copy(image_path, os.path.join("/home/lihong/workspace/pangea/pangea", dataset, data_item[1]))
|
| 205 |
+
|
| 206 |
+
if data_idx % 1000 == 0:
|
| 207 |
+
print(len(data),data_idx)
|
| 208 |
+
|
| 209 |
+
for name in selected_pkl.keys():
|
| 210 |
+
print(f"selected {name} affordance has {len(selected_pkl[name])} objects")
|
| 211 |
+
|
| 212 |
+
for name in selected_paired_pkl.keys():
|
| 213 |
+
for sub_name in selected_paired_pkl[name].keys():
|
| 214 |
+
print(f"selected {name} paired actions {sub_name} has {len(selected_paired_pkl[name][sub_name])} objects")
|
| 215 |
+
|
| 216 |
+
print("negative_pkl has {} objects".format(len(negative_pkl)))
|
| 217 |
+
print("num_images has {} objects".format(num_images))
|
| 218 |
+
|
| 219 |
+
pangea_pkl["selected_pkl"] = selected_pkl
|
| 220 |
+
pangea_pkl["selected_paired_pkl"] = selected_paired_pkl
|
| 221 |
+
pangea_pkl["negative_pkl"] = negative_pkl
|
| 222 |
+
|
| 223 |
+
with open(os.path.join("/home/lihong/workspace/pangea","pangea_test.pkl"),"wb") as fp:
|
| 224 |
+
pickle.dump(pangea_pkl,fp)
|
| 225 |
+
|
| 226 |
+
print(dataset_list)
|
| 227 |
+
|
| 228 |
+
exit()
|
| 229 |
+
selected_images = {}
|
| 230 |
+
|
| 231 |
+
save_data = []
|
| 232 |
+
|
| 233 |
+
for data_idx, data_item in enumerate(data):
|
| 234 |
+
if data_item[0] in mapping_dataset_directory.keys():
|
| 235 |
+
dataset = mapping_dataset_directory[data_item[0]]
|
| 236 |
+
else:
|
| 237 |
+
dataset = data_item[0]
|
| 238 |
+
image_path = '/data/xiaoqian/Images/' + dataset + '/' + data_item[1]
|
| 239 |
+
|
| 240 |
+
if not os.path.isfile(image_path):
|
| 241 |
+
if not dataset in dataset_list:
|
| 242 |
+
dataset_list.append(dataset)
|
| 243 |
+
continue
|
| 244 |
+
|
| 245 |
+
orig_label = data_item[2]
|
| 246 |
+
node_labels = data_item[3]
|
| 247 |
+
if true_selected_list[9] in node_labels:
|
| 248 |
+
for i in range(len(true_selected_list)):
|
| 249 |
+
if true_selected_list[i] in node_labels:
|
| 250 |
+
if true_selected_list[9]+ true_selected_list[i] in selected_images.keys():
|
| 251 |
+
selected_images[true_selected_list[9]+ true_selected_list[i]] += 1
|
| 252 |
+
else:
|
| 253 |
+
selected_images[true_selected_list[9]+ true_selected_list[i]] = 1
|
| 254 |
+
|
| 255 |
+
if data_idx %1000 == 0:
|
| 256 |
+
print(data_idx)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
print(selected_images)
|
| 260 |
+
|
preprocess_ocl_data.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import pickle
|
| 5 |
+
|
| 6 |
+
import shutil
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_pickle_and_assign_split(pkl_dir,split):
|
| 10 |
+
pkl_path = os.path.join(pkl_dir, f"OCL_annot_{split}.pkl")
|
| 11 |
+
|
| 12 |
+
with open(pkl_path, 'rb') as fp:
|
| 13 |
+
pkl = pickle.load(fp)
|
| 14 |
+
for x in pkl:
|
| 15 |
+
x['split'] = split
|
| 16 |
+
return pkl
|
| 17 |
+
def load_class_json(name):
|
| 18 |
+
with open(os.path.join("/home/lixingyu/workspace_lxy2/Data/OCL_data/data/resources",f"OCL_class_{name}.json"),"r") as fp:
|
| 19 |
+
return json.load(fp)
|
| 20 |
+
|
| 21 |
+
pkl_dir = "/home/lixingyu/workspace_lxy2/Data/OCL_data/data/resources"
|
| 22 |
+
|
| 23 |
+
attrs = load_class_json("attribute")
|
| 24 |
+
aff_dict = load_class_json("affordance")
|
| 25 |
+
|
| 26 |
+
aff = []
|
| 27 |
+
for aff_item in aff_dict:
|
| 28 |
+
# aff.append(aff_item["word"])
|
| 29 |
+
if aff_item["word"][0] not in aff:
|
| 30 |
+
aff.append(aff_item["word"][0])
|
| 31 |
+
else:
|
| 32 |
+
if len(aff_item["word"]) > 1:
|
| 33 |
+
random_aff = random.choice(aff_item["word"])
|
| 34 |
+
while random_aff in aff and random_aff == aff_item["word"][0]:
|
| 35 |
+
random_aff = random.choice(aff_item["word"])
|
| 36 |
+
aff.append(random_aff)
|
| 37 |
+
else:
|
| 38 |
+
if len(aff_dict[aff.index(aff_item["word"][0])]["word"]) > 1:
|
| 39 |
+
random_aff = random.choice(aff_dict[aff.index(aff_item["word"][0])]["word"])
|
| 40 |
+
while random_aff in aff and random_aff == aff_item["word"][0]:
|
| 41 |
+
random_aff = random.choice(aff_dict[aff.index(aff_item["word"][0])]["word"])
|
| 42 |
+
|
| 43 |
+
aff[aff.index(aff_item["word"][0])] = random_aff
|
| 44 |
+
aff.append(aff_item["word"][0])
|
| 45 |
+
else:
|
| 46 |
+
aff.append(aff_item["word"][0] + '_1')
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
print(len(attrs))
|
| 50 |
+
exit()
|
| 51 |
+
|
| 52 |
+
attrs_2_idx = {attr_item:idx for idx,attr_item in enumerate(attrs)}
|
| 53 |
+
|
| 54 |
+
aff_2_idx = {aff_item:idx for idx,aff_item in enumerate(aff)}
|
| 55 |
+
|
| 56 |
+
selected_affordance = ['break', 'carry', 'clean','close','cut','eat','open','push','sit','write']
|
| 57 |
+
selected_paired_affs = ['sit-write','push-carry','cut-clean','open-break', 'cut-close']
|
| 58 |
+
selected_aff_id = [aff_2_idx[aff_item] for aff_item in selected_affordance]
|
| 59 |
+
ocl_test_aff_pkl = {}
|
| 60 |
+
selected_pkl = {}
|
| 61 |
+
selected_paired_pkl = {}
|
| 62 |
+
negative_aff = []
|
| 63 |
+
|
| 64 |
+
for aff_idx, aff_item in enumerate(aff_dict):
|
| 65 |
+
|
| 66 |
+
aff_flag = False
|
| 67 |
+
for sel_aff in selected_affordance:
|
| 68 |
+
if sel_aff in aff_item["word"]:
|
| 69 |
+
aff_flag = True
|
| 70 |
+
break
|
| 71 |
+
if not aff_flag:
|
| 72 |
+
negative_aff.append(aff[aff_idx])
|
| 73 |
+
|
| 74 |
+
negative_aff_id = [aff_2_idx[aff_item] for aff_item in negative_aff]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
negative_pkl = []
|
| 78 |
+
for data_idx,each_item in enumerate(test_pkl):
|
| 79 |
+
if os.path.exists(os.path.join("/home/lixingyu/workspace_lxy2/Data/OCL_data/data",each_item["name"])):
|
| 80 |
+
for obj_idx,each_object in enumerate(each_item['objects']):
|
| 81 |
+
if (abs(each_object['box'][0] - each_object['box'][2])) * (abs(each_object['box'][1] - each_object['box'][3])) > 50000:
|
| 82 |
+
if len(list(set(each_object['aff']).intersection(set(selected_aff_id)))) > 0:
|
| 83 |
+
for each_sel_aff in selected_affordance:
|
| 84 |
+
if aff_2_idx[each_sel_aff] in each_object['aff']:
|
| 85 |
+
if each_sel_aff not in selected_pkl.keys():
|
| 86 |
+
selected_pkl[each_sel_aff] = [[data_idx, obj_idx]]
|
| 87 |
+
else:
|
| 88 |
+
selected_pkl[each_sel_aff].append([data_idx, obj_idx])
|
| 89 |
+
|
| 90 |
+
for each_paired_sel_aff in selected_paired_affs:
|
| 91 |
+
aff1,aff2 = each_paired_sel_aff.split("-")
|
| 92 |
+
if aff_2_idx[aff1] in each_object['aff'] and aff_2_idx[aff2] in each_object['aff']:
|
| 93 |
+
if each_paired_sel_aff not in selected_paired_pkl.keys():
|
| 94 |
+
selected_paired_pkl[each_paired_sel_aff] = {}
|
| 95 |
+
if each_paired_sel_aff not in selected_paired_pkl[each_paired_sel_aff].keys():
|
| 96 |
+
selected_paired_pkl[each_paired_sel_aff][each_paired_sel_aff] = [[data_idx, obj_idx]]
|
| 97 |
+
else:
|
| 98 |
+
selected_paired_pkl[each_paired_sel_aff][each_paired_sel_aff].append([data_idx, obj_idx])
|
| 99 |
+
elif aff_2_idx[aff1] in each_object['aff']:
|
| 100 |
+
if each_paired_sel_aff not in selected_paired_pkl.keys():
|
| 101 |
+
selected_paired_pkl[each_paired_sel_aff] = {}
|
| 102 |
+
if aff1 not in selected_paired_pkl[each_paired_sel_aff].keys():
|
| 103 |
+
selected_paired_pkl[each_paired_sel_aff][aff1] = [[data_idx, obj_idx]]
|
| 104 |
+
else:
|
| 105 |
+
selected_paired_pkl[each_paired_sel_aff][aff1].append([data_idx, obj_idx])
|
| 106 |
+
elif aff_2_idx[aff2] in each_object['aff']:
|
| 107 |
+
if each_paired_sel_aff not in selected_paired_pkl.keys():
|
| 108 |
+
selected_paired_pkl[each_paired_sel_aff] = {}
|
| 109 |
+
if aff2 not in selected_paired_pkl[each_paired_sel_aff].keys():
|
| 110 |
+
selected_paired_pkl[each_paired_sel_aff][aff2] = [[data_idx, obj_idx]]
|
| 111 |
+
else:
|
| 112 |
+
selected_paired_pkl[each_paired_sel_aff][aff2].append([data_idx, obj_idx])
|
| 113 |
+
|
| 114 |
+
else:
|
| 115 |
+
if len(list(set(each_object['aff']).intersection(set(negative_aff_id)))) == len(each_object['aff']):
|
| 116 |
+
negative_pkl.append([data_idx,obj_idx])
|
| 117 |
+
if not os.path.exists(os.path.dirname(os.path.join("/home/lixingyu/workspace_lxy2/Data/OCL_data/data/saved_test_data",each_item["name"]))):
|
| 118 |
+
os.makedirs(os.path.dirname(os.path.join("/home/lixingyu/workspace_lxy2/Data/OCL_data/data/saved_test_data",each_item["name"])))
|
| 119 |
+
shutil.copy(os.path.join("/home/lixingyu/workspace_lxy2/Data/OCL_data/data",each_item["name"]),os.path.join("/home/lixingyu/workspace_lxy2/Data/OCL_data/data/saved_test_data",each_item["name"]))
|
| 120 |
+
|
| 121 |
+
print("negative_pkl has {} objects".format(len(negative_pkl)))
|
| 122 |
+
|
| 123 |
+
for name in selected_pkl.keys():
|
| 124 |
+
print(f"selected {name} affordance has {len(selected_pkl[name])} objects")
|
| 125 |
+
|
| 126 |
+
for name in selected_paired_pkl.keys():
|
| 127 |
+
for sub_name in selected_paired_pkl[name].keys():
|
| 128 |
+
print(f"selected {name} paired affordance {sub_name} has {len(selected_paired_pkl[name][sub_name])} objects")
|
| 129 |
+
|
| 130 |
+
ocl_test_aff_pkl["selected_individual_pkl"] = selected_pkl
|
| 131 |
+
ocl_test_aff_pkl["selected_paired_pkl"] = selected_paired_pkl
|
| 132 |
+
ocl_test_aff_pkl["negative_pkl"] = negative_pkl
|
| 133 |
+
|
| 134 |
+
with open(os.path.join(pkl_dir,"OCL_selected_test_affordance.pkl"),"wb") as fp:
|
| 135 |
+
pickle.dump(ocl_test_aff_pkl,fp)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
selected_attrs = ['wooden', 'metal', 'flying', 'ripe', 'fresh', 'natural', 'cooked', 'painted', 'rusty', 'furry']
|
| 139 |
+
selected_attrs_id = [attrs_2_idx[attr_item] for attr_item in selected_attrs]
|
| 140 |
+
selected_paired_attrs = ["furry-metal","fresh-cooked","natural-ripe","painted-rusty"]
|
| 141 |
+
|
| 142 |
+
selected_affs = []
|
| 143 |
+
selected_paired_affs = []
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
ocl_test_attr_pkl = {}
|
| 147 |
+
|
| 148 |
+
selected_pkl = {}
|
| 149 |
+
selected_paired_pkl = {}
|
| 150 |
+
|
| 151 |
+
negative_pkl = []
|
| 152 |
+
for data_idx,each_item in enumerate(test_pkl):
|
| 153 |
+
if os.path.exists(os.path.join("/home/lixingyu/workspace_lxy2/Data/OCL_data/data",each_item["name"])):
|
| 154 |
+
for obj_idx,each_object in enumerate(each_item['objects']):
|
| 155 |
+
if (abs(each_object['box'][0] - each_object['box'][2])) * (abs(each_object['box'][1] - each_object['box'][3])) > 50000:
|
| 156 |
+
if len(list(set(each_object['attr']).intersection(set(selected_attrs_id)))) > 0:
|
| 157 |
+
for each_sel_attr in selected_attrs:
|
| 158 |
+
if attrs_2_idx[each_sel_attr] in each_object['attr']:
|
| 159 |
+
if each_sel_attr not in selected_pkl.keys():
|
| 160 |
+
selected_pkl[each_sel_attr] = [[data_idx, obj_idx]]
|
| 161 |
+
else:
|
| 162 |
+
selected_pkl[each_sel_attr].append([data_idx, obj_idx])
|
| 163 |
+
|
| 164 |
+
for each_paired_sel_attr in selected_paired_attrs:
|
| 165 |
+
attr1,attr2 = each_paired_sel_attr.split("-")
|
| 166 |
+
if attrs_2_idx[attr1] in each_object['attr'] and attrs_2_idx[attr2] in each_object['attr']:
|
| 167 |
+
if each_paired_sel_attr not in selected_paired_pkl.keys():
|
| 168 |
+
selected_paired_pkl[each_paired_sel_attr] = {}
|
| 169 |
+
if each_paired_sel_attr not in selected_paired_pkl[each_paired_sel_attr].keys():
|
| 170 |
+
selected_paired_pkl[each_paired_sel_attr][each_paired_sel_attr] = [[data_idx, obj_idx]]
|
| 171 |
+
else:
|
| 172 |
+
selected_paired_pkl[each_paired_sel_attr][each_paired_sel_attr].append([data_idx, obj_idx])
|
| 173 |
+
elif attrs_2_idx[attr1] in each_object['attr']:
|
| 174 |
+
if each_paired_sel_attr not in selected_paired_pkl.keys():
|
| 175 |
+
selected_paired_pkl[each_paired_sel_attr] = {}
|
| 176 |
+
if attr1 not in selected_paired_pkl[each_paired_sel_attr].keys():
|
| 177 |
+
selected_paired_pkl[each_paired_sel_attr][attr1] = [[data_idx, obj_idx]]
|
| 178 |
+
else:
|
| 179 |
+
selected_paired_pkl[each_paired_sel_attr][attr1].append([data_idx, obj_idx])
|
| 180 |
+
elif attrs_2_idx[attr2] in each_object['attr']:
|
| 181 |
+
if each_paired_sel_attr not in selected_paired_pkl.keys():
|
| 182 |
+
selected_paired_pkl[each_paired_sel_attr] = {}
|
| 183 |
+
if attr2 not in selected_paired_pkl[each_paired_sel_attr].keys():
|
| 184 |
+
selected_paired_pkl[each_paired_sel_attr][attr2] = [[data_idx, obj_idx]]
|
| 185 |
+
else:
|
| 186 |
+
selected_paired_pkl[each_paired_sel_attr][attr2].append([data_idx, obj_idx])
|
| 187 |
+
|
| 188 |
+
else:
|
| 189 |
+
negative_pkl.append([data_idx,obj_idx])
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
print("finshed!!!")
|
| 193 |
+
print("negative_pkl has {} objects".format(len(negative_pkl)))
|
| 194 |
+
for name in selected_pkl.keys():
|
| 195 |
+
print(f"selected {name} attribute has {len(selected_pkl[name])} objects")
|
| 196 |
+
|
| 197 |
+
for name in selected_paired_pkl.keys():
|
| 198 |
+
for sub_name in selected_paired_pkl[name].keys():
|
| 199 |
+
print(f"selected {name} paired attribute {sub_name} has {len(selected_paired_pkl[name][sub_name])} objects")
|
| 200 |
+
|
| 201 |
+
ocl_test_attr_pkl["selected_individual_pkl"] = selected_pkl
|
| 202 |
+
ocl_test_attr_pkl["selected_paired_pkl"] = selected_paired_pkl
|
| 203 |
+
ocl_test_attr_pkl["negative_pkl"] = negative_pkl
|
| 204 |
+
with open(os.path.join(pkl_dir,"OCL_selected_test_attribute.pkl"),"wb") as fp:
|
| 205 |
+
pickle.dump(ocl_test_attr_pkl,fp)
|
| 206 |
+
|
| 207 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
huggingface_hub==0.22.2
|