Skip to content

Commit

Permalink
feat: added capability to select model in vision methods (#156)
Browse files Browse the repository at this point in the history
* feat: added capability to select model in vision methods

* chore: added changeset

* fix: package lock file

* fix: updated return type for query method

* fix: updated query return type
  • Loading branch information
pushpam5 authored Oct 16, 2024
1 parent 5626697 commit 44f9a4a
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 13 deletions.
5 changes: 5 additions & 0 deletions .changeset/cyan-crabs-poke.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"appwright": patch
---

feat: added capability to select model in vision methods
20 changes: 15 additions & 5 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"license": "Apache-2.0",
"description": "E2E mobile app testing done right, with the Playwright test runner",
"dependencies": {
"@empiricalrun/llm": "^0.9.5",
"@empiricalrun/llm": "^0.9.9",
"@playwright/test": "^1.47.1",
"appium": "^2.6.0",
"appium-uiautomator2-driver": "^3.8.0",
Expand Down
19 changes: 16 additions & 3 deletions src/device/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
// @ts-ignore ts not able to identify the import is just an interface
import type { Client as WebDriverClient } from "webdriver";
import { Locator } from "../locator";
import { AppwrightLocator, Platform, TestInfoOptions } from "../types";
import {
AppwrightLocator,
ExtractType,
Platform,
TestInfoOptions,
} from "../types";
import { AppwrightVision, VisionProvider } from "../vision";
import { boxedStep, longestDeterministicGroup } from "../utils";
import { uploadImageToBS } from "../providers/browserstack/utils";
import { uploadImageToLambdaTest } from "../providers/lambdatest/utils";
import { z } from "zod";
import { LLMModel } from "@empiricalrun/llm";

export class Device {
constructor(
Expand Down Expand Up @@ -42,8 +49,14 @@ export class Device {
await this.vision().tap(prompt);
},

query: async (prompt: string): Promise<string> => {
return await this.vision().query(prompt);
query: async <T extends z.ZodType>(
prompt: string,
options?: {
responseFormat?: T;
model?: LLMModel;
},
): Promise<ExtractType<T>> => {
return await this.vision().query(prompt, options);
},
};

Expand Down
3 changes: 3 additions & 0 deletions src/types/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { Device } from "../device";
import { z } from "zod";

export type ExtractType<T> = T extends z.ZodType ? z.infer<T> : never;

export type WaitUntilOptions = {
/**
Expand Down
24 changes: 20 additions & 4 deletions src/vision/index.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import { getBoundingBox, query } from "@empiricalrun/llm/vision";
import { query } from "@empiricalrun/llm/vision";
import { getBoundingBox } from "@empiricalrun/llm/vision/bbox";
import fs from "fs";
// @ts-ignore ts not able to identify the import is just an interface
import { Client as WebDriverClient } from "webdriver";
import { Device } from "../device";
import test from "@playwright/test";
import { boxedStep } from "../utils";
import { z } from "zod";
import { LLMModel } from "@empiricalrun/llm";
import { ExtractType } from "../types";

export interface AppwrightVision {
/**
Expand All @@ -19,7 +23,13 @@ export interface AppwrightVision {
* @param prompt that defines the specific area or context from which text should be extracted.
* @returns
*/
query(prompt: string): Promise<string>;
query<T extends z.ZodType>(
prompt: string,
options?: {
responseFormat?: T;
model?: LLMModel;
},
): Promise<ExtractType<T>>;

/**
* Performs a tap action on the screen based on the provided prompt.
Expand All @@ -42,13 +52,19 @@ export class VisionProvider {
) {}

@boxedStep
async query(prompt: string): Promise<string> {
async query<T extends z.ZodType>(
prompt: string,
options?: {
responseFormat?: T;
model?: LLMModel;
},
): Promise<ExtractType<T>> {
test.skip(
!process.env.OPENAI_API_KEY,
"LLM vision based extract text is not enabled. Set the OPENAI_API_KEY environment variable to enable it",
);
const base64Screenshot = await this.webDriverClient.takeScreenshot();
return await query(base64Screenshot, prompt);
return await query(base64Screenshot, prompt, options);
}

@boxedStep
Expand Down

0 comments on commit 44f9a4a

Please sign in to comment.