| | import pytest |
| | from openai import OpenAI |
| | from utils import * |
| |
|
# Absolute tolerance for floating-point comparisons on embedding components
# and on the squared norm of returned vectors.
EPSILON = 1e-3

# Module-level default so that `server` is bound at import time; the
# module-scoped autouse fixture `create_server` rebinds it to a fresh preset.
server = ServerPreset.bert_bge_small()
| |
|
@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Rebind the module-global ``server`` to a fresh ``bert_bge_small`` preset.

    Declared module-scoped and autouse, so pytest runs it once for this module
    before its tests, without any test having to request it explicitly.
    """
    global server
    fresh_preset = ServerPreset.bert_bge_small()
    server = fresh_preset
| |
|
| |
|
def test_embedding_single():
    """A single string input yields exactly one non-trivial, unit-norm embedding."""
    global server
    server.start()
    payload = {"input": "I believe the meaning of life is"}
    response = server.make_request("POST", "/embeddings", data=payload)
    assert response.status_code == 200

    items = response.body['data']
    assert len(items) == 1
    first_item = items[0]
    assert 'embedding' in first_item
    vector = first_item['embedding']
    assert len(vector) > 1

    # The returned vector is expected to be normalized: the sum of its
    # squared components must equal 1 within EPSILON.
    squared_norm = sum(component ** 2 for component in vector)
    assert abs(squared_norm - 1) < EPSILON
| |
|
| |
|
def test_embedding_multiple():
    """A batch of four prompts yields four embeddings, each of length > 1."""
    global server
    server.start()
    prompts = [
        "I believe the meaning of life is",
        "Write a joke about AI from a very long prompt which will not be truncated",
        "This is a test",
        "This is another test",
    ]
    response = server.make_request("POST", "/embeddings", data={"input": prompts})
    assert response.status_code == 200

    results = response.body['data']
    # One result entry per input prompt.
    assert len(results) == len(prompts)
    for entry in results:
        assert 'embedding' in entry
        assert len(entry['embedding']) > 1
| |
|
| |
|
def test_embedding_openai_library_single():
    """The OpenAI Python client can fetch a single embedding from the server."""
    global server
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}"
    client = OpenAI(api_key="dummy", base_url=base_url)
    response = client.embeddings.create(
        model="text-embedding-3-small",
        input="I believe the meaning of life is",
    )
    assert len(response.data) == 1
    assert len(response.data[0].embedding) > 1
| |
|
| |
|
def test_embedding_openai_library_multiple():
    """The OpenAI Python client receives one embedding per prompt in a batch."""
    global server
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}"
    client = OpenAI(api_key="dummy", base_url=base_url)
    prompts = [
        "I believe the meaning of life is",
        "Write a joke about AI from a very long prompt which will not be truncated",
        "This is a test",
        "This is another test",
    ]
    response = client.embeddings.create(model="text-embedding-3-small", input=prompts)
    assert len(response.data) == 4
    for item in response.data:
        assert len(item.embedding) > 1
| |
|
| |
|
def test_embedding_error_prompt_too_long():
    """An oversized input (a short phrase repeated 512 times) must be rejected.

    The request must fail (non-200 status) and the error message must
    mention "too large".
    """
    global server
    server.start()
    oversized_input = "This is a test " * 512
    response = server.make_request("POST", "/embeddings", data={"input": oversized_input})
    assert response.status_code != 200
    error_message = response.body["error"]["message"]
    assert "too large" in error_message
| |
|
| |
|
def test_same_prompt_give_same_result():
    """Identical prompts in one batch must produce (near-)identical embeddings.

    Every embedding after the first is compared against the first one: the
    vectors must have the same length and every component must agree within
    ``EPSILON``.
    """
    global server
    server.start()
    prompt = "I believe the meaning of life is"
    res = server.make_request("POST", "/embeddings", data={
        "input": [prompt] * 5,
    })
    assert res.status_code == 200
    data = res.body['data']
    assert len(data) == 5

    # The reference vector does not change across iterations; fetch it once.
    reference = data[0]['embedding']
    for other in data[1:]:
        candidate = other['embedding']
        # zip() stops at the shorter sequence, so a truncated vector would
        # otherwise pass the element-wise check; compare lengths explicitly.
        assert len(candidate) == len(reference)
        for x, y in zip(reference, candidate):
            assert abs(x - y) < EPSILON
| |
|