diff --git a/src/definition-provider/definition-provider-for-schema.test.ts b/src/definition-provider/definition-provider-for-schema.test.ts index 12cecdf..026852d 100644 --- a/src/definition-provider/definition-provider-for-schema.test.ts +++ b/src/definition-provider/definition-provider-for-schema.test.ts @@ -1,6 +1,8 @@ +import { resolve } from 'path' import { Position, Range } from '../diff' import { trimSpaces } from '../util/trim-spaces' import { provideDefinitionForSchema } from './definition-provider-for-schema' +import { readFileSync } from 'fs-extra' const schema = ` type User { @@ -32,9 +34,15 @@ used as primitive type """ type Tweet { tweetId: ID! + status: TweetStatus mentions: [User!] } +enum TweetStatus { + DRAFT + ACTIVE +} + type Query { me: User user(id: ID!): User @@ -49,7 +57,6 @@ type Mutation { type Subscription { onUserChange(id: ID!): User } - ` function getAt(schema: string, range: Range | null) { @@ -58,12 +65,31 @@ function getAt(schema: string, range: Range | null) { return lines.join('\n') } +async function process(position: Position) { + const foundPosition = await provideDefinitionForSchema( + schema, + 'schema.gql', + position, + resolve(__dirname, 'test', '*.resolver.ts'), + resolve(__dirname, 'test', '*.model.ts'), + resolve(__dirname, 'test', '*.enum.ts'), + ) + if (!foundPosition) return '' + const source = foundPosition.path.startsWith('schema.gql') + ? schema + : readFileSync(foundPosition.path, 'utf-8') + return [ + `# ${[foundPosition.path.replace(__dirname + '/', ''), foundPosition.range.start.line, foundPosition.range.start.character].join(':')}`, + getAt(source, foundPosition.range), + ].join('\n') +} + describe('provideDefinitionFromSchema', () => { test('should provide return address type', async () => { - const range = provideDefinitionForSchema(schema, new Position(4, 16)) - const output = getAt(schema, range) + const output = await process(new Position(4, 16)) expect(output).toEqual( trimSpaces(` + # schema.gql:13:0 type Address { id: String address: String @@ -73,10 +99,10 @@ describe('provideDefinitionFromSchema', () => { }) test('should provide user status', async () => { - const range = provideDefinitionForSchema(schema, new Position(5, 16)) - const output = getAt(schema, range) + const output = await process(new Position(5, 16)) expect(output).toEqual( trimSpaces(` + # schema.gql:8:0 enum UserStatus { ACTIVE, DELETED @@ -85,14 +111,72 @@ describe('provideDefinitionFromSchema', () => { }) test('should provide tweet type', async () => { - const range = provideDefinitionForSchema(schema, new Position(36, 22)) - const output = getAt(schema, range) + const output = await process(new Position(42, 22)) expect(output).toEqual( trimSpaces(` + # schema.gql:28:0 type Tweet { tweetId: ID! + status: TweetStatus mentions: [User!] }`), ) }) + + test('should provide tweet status', async () => { + const output = await process(new Position(30, 22)) + expect(output).toEqual( + trimSpaces(` + # schema.gql:34:0 + enum TweetStatus { + DRAFT + ACTIVE + }`), + ) + }) + + test('should provide definition tweet type in typescript', async () => { + const output = await process(new Position(42, 6)) + expect(output).toEqual( + trimSpaces(` + # test/tweet.resolver.ts:6:2 + tweet() {`), + ) + }) + + test('should provide definition user.name in typescript', async () => { + const output = await process(new Position(3, 5)) + expect(output).toEqual( + trimSpaces(` + # test/user.model.ts:6:2 + name?: string`), + ) + }) + + test('should provide definition user id in typescript', async () => { + const output = await process(new Position(2, 3)) + expect(output).toEqual( + trimSpaces(` + # test/base.model.ts:5:2 + id!: string`), + ) + }) + + test('should provide definition of enum in typescript', async () => { + const output = await process(new Position(34, 12)) + expect(output).toEqual( + trimSpaces(` + # test/tweet-status.enum.ts:2:12 + export enum TweetStatus {`), + ) + }) + + test('should provide definition of enum member in typescript', async () => { + const output = await process(new Position(35, 6)) + expect(output).toEqual( + trimSpaces(` + # test/tweet-status.enum.ts:3:2 + DRAFT = 'DRAFT',`), + ) + }) }) diff --git a/src/definition-provider/definition-provider-for-schema.ts b/src/definition-provider/definition-provider-for-schema.ts index 9eac683..6a09734 100644 --- a/src/definition-provider/definition-provider-for-schema.ts +++ b/src/definition-provider/definition-provider-for-schema.ts @@ -1,55 +1,358 @@ +import { globStream } from 'fast-glob' import * as gql from 'graphql' -import { Position } from '../diff' -import { getGQLNodeRange, getGQLNodeRangeWithoutDescription, makeQueryParsable } from '../gql' -import { isPositionWithInRange } from '../position/is-position-within-range' +import path from 'path' +import ts from 'typescript' +import { Location, Position, Range } from '../diff' +import { getGQLNodeRangeWithoutDescription, makeQueryParsable } from '../gql' +import { isInRange } from '../position/is-position-within-range' +import { getDecorator, hasDecorator, readTSFile } from '../ts' +import { SelectedField } from './selected-field.type' -function isInRange(node: gql.ASTNode, position: Position, offset?: Position) { - const nodeRange = getGQLNodeRange(node, offset) - return isPositionWithInRange(position, nodeRange, true) +function getPositionOfMethod( + node: + | ts.EnumDeclaration + | ts.EnumMember + | ts.ClassDeclaration + | ts.MethodDeclaration + | ts.PropertyDeclaration, + sourceFile: ts.SourceFile, +) { + const start = node.name?.getStart() + const end = node.name?.getEnd() + if (!start || !end) return + const startPosition = sourceFile.getLineAndCharacterOfPosition(start) + const endPosition = sourceFile.getLineAndCharacterOfPosition(end) + return new Location( + sourceFile.fileName, + new Range( + new Position(startPosition.line, startPosition.character), + new Position(endPosition.line, endPosition.character), + ), + ) } -export function provideDefinitionForSchema(source: string, position: Position) { +function parseByDecoratorAndName( + member: ts.ClassElement, + sourceFile: ts.SourceFile, + decoratorName: string, + fieldName: string, +) { + if ( + hasDecorator(member, decoratorName) && + (ts.isMethodDeclaration(member) || ts.isPropertyDeclaration(member)) && + ts.isIdentifier(member.name) && + member.name.getText() === fieldName + ) { + return getPositionOfMethod(member, sourceFile) + } +} + +function isResolvingAType(decorator: ts.Decorator, name: string) { + if ( + ts.isCallExpression(decorator.expression) && + !!decorator.expression.arguments.length && + ts.isArrowFunction(decorator.expression.arguments[0]) && + ts.isIdentifier(decorator.expression.arguments[0].body) && + decorator.expression.arguments[0].body.getText() === name + ) { + return true + } +} + +async function processResolvers( + selectedField: SelectedField, + resolverPattern: string, +): Promise { + if (!selectedField || !resolverPattern) return + const stream = globStream(resolverPattern, { onlyFiles: true, ignore: ['**/node_modules/**'] }) + for await (const file of stream) { + const sourceFile = readTSFile(file as string) + for (const classDeclaration of sourceFile.statements.filter(statement => + ts.isClassDeclaration(statement), + )) { + if (!ts.isClassDeclaration(classDeclaration)) continue + const resolverDecorator = getDecorator(classDeclaration, 'Resolver') + if (resolverDecorator && isResolvingAType(resolverDecorator, selectedField.parent)) { + for (const member of classDeclaration.members) { + const location = parseByDecoratorAndName( + member, + sourceFile, + 'ResolveField', + selectedField.name, + ) + if (location) return location + } + } else if (resolverDecorator) { + for (const member of classDeclaration.members) { + const location = parseByDecoratorAndName( + member, + sourceFile, + selectedField.parent, + selectedField.name, + ) + if (location) return location + } + } + } + } +} + +function getParent(classDeclaration: ts.ClassDeclaration, sourceFile: ts.SourceFile) { + if (!classDeclaration?.heritageClauses?.length) return + if ( + ts.isHeritageClause(classDeclaration?.heritageClauses[0]) && + ts.isExpressionWithTypeArguments(classDeclaration?.heritageClauses[0].types?.[0]) && + ts.isIdentifier(classDeclaration?.heritageClauses[0].types?.[0].expression) + ) { + const parentName = classDeclaration?.heritageClauses[0].types?.[0].expression?.getText() + const importDeclaration = sourceFile.statements.find( + statement => + ts.isImportDeclaration(statement) && + ts.isStringLiteral(statement.moduleSpecifier) && + statement.importClause?.namedBindings && + ts.isNamedImports(statement.importClause?.namedBindings) && + statement.importClause.namedBindings.elements + .map(el => el.name.getText()) + .includes(parentName), + ) + if ( + importDeclaration && + ts.isImportDeclaration(importDeclaration) && + ts.isStringLiteral(importDeclaration.moduleSpecifier) + ) { + const parentSourceFile = readTSFile( + `${path.resolve(path.dirname(sourceFile.fileName), importDeclaration.moduleSpecifier.text)}.ts`, + ) + for (const classDeclaration of parentSourceFile.statements.filter(statement => + ts.isClassDeclaration(statement), + )) { + if ( + (hasDecorator(classDeclaration, 'ObjectType') || + hasDecorator(classDeclaration, 'InputType')) && + ts.isClassDeclaration(classDeclaration) && + classDeclaration.name?.getText() === parentName + ) { + return { name: parentName, classDeclaration, sourceFile: parentSourceFile } + } + } + } + } +} + +function processObjectType( + classDeclaration: ts.ClassDeclaration, + sourceFile: ts.SourceFile, + selectedField: SelectedField, +): Location | undefined { + if (!ts.isClassDeclaration(classDeclaration)) return + if ( + (hasDecorator(classDeclaration, 'ObjectType') || hasDecorator(classDeclaration, 'InputType')) && + classDeclaration.name?.getText() === selectedField.parent + ) { + for (const member of classDeclaration.members) { + const location = parseByDecoratorAndName(member, sourceFile, 'Field', selectedField.name) + if (location) return location + } + const parent = getParent(classDeclaration, sourceFile) + if (parent) { + if (ts.isClassDeclaration(parent.classDeclaration)) { + const parentLocation = processObjectType(parent.classDeclaration, parent.sourceFile, { + ...selectedField, + parent: parent.name, + }) + if (parentLocation) return parentLocation + } + } + } +} + +async function processModels( + selectedField: SelectedField, + modelPattern: string, +): Promise { + if (!selectedField || !modelPattern) return + const stream = globStream(modelPattern, { onlyFiles: true, ignore: ['**/node_modules/**'] }) + for await (const file of stream) { + const sourceFile = readTSFile(file as string) + for (const classDeclaration of sourceFile.statements.filter(statement => + ts.isClassDeclaration(statement), + )) { + if (ts.isClassDeclaration(classDeclaration)) { + const location = processObjectType(classDeclaration, sourceFile, selectedField) + if (location) return location + } + } + } +} + +async function processEnum( + selectedEnum: SelectedField, + enumPattern: string, +): Promise { + if (!selectedEnum || !enumPattern) return + const stream = globStream(enumPattern, { onlyFiles: true, ignore: ['**/node_modules/**'] }) + for await (const file of stream) { + const sourceFile = readTSFile(file as string) + for (const enumDeclaration of sourceFile.statements) { + if (ts.isEnumDeclaration(enumDeclaration)) { + if (enumDeclaration?.name?.getText() === selectedEnum.parent) { + for (const value of enumDeclaration.members) { + if (value?.name?.getText() === selectedEnum.name) { + return getPositionOfMethod(value, sourceFile) + } + } + } + } + } + } +} + +async function processEnumForName( + name: string, + enumPattern: string, +): Promise { + if (!name || !enumPattern) return + const stream = globStream(enumPattern, { onlyFiles: true, ignore: ['**/node_modules/**'] }) + for await (const file of stream) { + const sourceFile = readTSFile(file as string) + for (const enumDeclaration of sourceFile.statements) { + if (ts.isEnumDeclaration(enumDeclaration)) { + if (enumDeclaration?.name?.getText() === name) { + return getPositionOfMethod(enumDeclaration, sourceFile) + } + } + } + } +} + +function processFromSchema(type: string, document: gql.DocumentNode, schemaLocation: string) { + let targetNode: gql.ASTNode | undefined + if (!type) return + const processNode = (node: gql.TypeDefinitionNode) => { + if (node.name.value !== type) return + targetNode = node + return gql.BREAK + } + gql.visit(document, { + EnumTypeDefinition(node) { + return processNode(node) + }, + ScalarTypeDefinition(node) { + return processNode(node) + }, + ObjectTypeDefinition(node) { + return processNode(node) + }, + InputObjectTypeDefinition(node) { + return processNode(node) + }, + UnionTypeDefinition(node) { + return processNode(node) + }, + InterfaceTypeDefinition(node) { + return processNode(node) + }, + }) + if (!targetNode) return + return new Location(schemaLocation, getGQLNodeRangeWithoutDescription(targetNode)) +} + +async function processModelForName( + name: string, + modelPattern: string, +): Promise { + if (!name || !modelPattern) return + const stream = globStream(modelPattern, { onlyFiles: true, ignore: ['**/node_modules/**'] }) + for await (const file of stream) { + const sourceFile = readTSFile(file as string) + for (const classDeclaration of sourceFile.statements) { + if (ts.isClassDeclaration(classDeclaration) && classDeclaration.name?.getText() === name) { + const location = getPositionOfMethod(classDeclaration, sourceFile) + console.log(`here location=${location}`) + if (location) return location + } + } + } +} + +export async function provideDefinitionForSchema( + source: string, + schemaLocation: string, + position: Position, + resolverPattern: string, + modelPattern: string, + enumPattern: string, +): Promise { try { const fixed = makeQueryParsable(source) const document = gql.parse(fixed) - let selectedName: string | undefined - let targetNode: gql.ASTNode | undefined + let type: string | undefined + let modelName: string | undefined + let selectedField: SelectedField | undefined + let enumName: string | undefined + let selectedEnum: SelectedField | undefined + const processFields = (node: gql.TypeDefinitionNode) => { + if (!isInRange(node, position)) return + if (isInRange(node.name, position)) { + modelName = node.name.value + return gql.BREAK + } + switch (node.kind) { + case gql.Kind.OBJECT_TYPE_DEFINITION: + case gql.Kind.INPUT_OBJECT_TYPE_DEFINITION: + case gql.Kind.INTERFACE_TYPE_DEFINITION: + for (const field of node.fields ?? []) { + if (isInRange(field.name, position)) { + selectedField = { parent: node.name.value, name: field.name.value } + return gql.BREAK + } + } + } + } gql.visit(document, { NamedType(node) { if (!isInRange(node, position)) return - selectedName = node.name.value - }, - }) - const processNode = (node: gql.TypeDefinitionNode) => { - if (node.name.value !== selectedName) return - targetNode = node - return gql.BREAK - } - if (!selectedName) return null - gql.visit(document, { - EnumTypeDefinition(node) { - return processNode(node) - }, - ScalarTypeDefinition(node) { - return processNode(node) + type = node.name.value + return gql.BREAK }, ObjectTypeDefinition(node) { - return processNode(node) + return processFields(node) }, InputObjectTypeDefinition(node) { - return processNode(node) - }, - UnionTypeDefinition(node) { - return processNode(node) + return processFields(node) }, InterfaceTypeDefinition(node) { - return processNode(node) + return processFields(node) + }, + EnumTypeDefinition(node) { + if (!isInRange(node, position)) return + if (isInRange(node.name, position)) { + enumName = node.name.value + return gql.BREAK + } + for (const field of node.values ?? []) { + if (isInRange(field.name, position)) { + selectedEnum = { parent: node.name.value, name: field.name.value } + return gql.BREAK + } + } }, }) - if (!targetNode) return null - return getGQLNodeRangeWithoutDescription(targetNode) + if (type) { + return processFromSchema(type, document, schemaLocation) + } else if (selectedField) { + return ( + (await processResolvers(selectedField, resolverPattern)) ?? + (await processModels(selectedField, modelPattern)) + ) + } else if (modelName) { + return await processModelForName(modelName, modelPattern) + } else if (selectedEnum) { + return await processEnum(selectedEnum, enumPattern) + } else if (enumName) { + return await processEnumForName(enumName, enumPattern) + } } catch (e) { console.error(e) - return null } } diff --git a/src/definition-provider/definition-provider-for-source.test.ts b/src/definition-provider/definition-provider-for-source.test.ts new file mode 100644 index 0000000..9794bfe --- /dev/null +++ b/src/definition-provider/definition-provider-for-source.test.ts @@ -0,0 +1,258 @@ +import * as gql from 'graphql' +import { config } from '../config' +import { Position, Range } from '../diff' +import { parseTSFile } from '../ts' +import { trimSpaces } from '../util/trim-spaces' +import { provideDefinitionForSource } from './definition-provider-for-source' + +const schema = ` +type User { + id: ID! + name: String + address: Address + status: UserStatus +} + +enum UserStatus { + ACTIVE, + DELETED +} + +type Address { + id: String + address: String + city: City +} + +type City { + name: String + code: String +} + +""" +Tweet object +used as primitive type +""" +type Tweet { + tweetId: ID! + status: TweetStatus + mentions: [User!] +} + +enum TweetStatus { + DRAFT + ACTIVE +} + +type Query { + me: User + user(id: ID!): User + tweet(id: ID!): Tweet +} + +type Mutation { + createUser(input: CreateUserInput!): CreateUserResponse + updateUser(id: ID!, name: String): User +} + +type Subscription { + onUserChange(id: ID!): User +} + +input CreateUserInput { + name: String! +} + +type CreateUserResponse { + user: User! +} +` + +function getAt(schema: string, range: Range | null) { + if (!range) return + const lines = schema.split('\n').slice(range.start.line, range.end.line + 1) + return lines.join('\n') +} + +async function process(path: string, code: string, position: Position) { + const location = provideDefinitionForSource( + parseTSFile(path, trimSpaces(code)), + position, + gql.buildSchema(schema), + 'schema.gql', + config, + ) + if (!location) return '' + return [ + `# ${[location.path.replace(__dirname + '/', ''), location.range.start.line, location.range.start.character].join(':')}`, + getAt(schema ?? '', location.range), + ].join('\n') +} + +describe('provideDefinitionForSource', () => { + test('should generate a location in schema for graphql query', async () => { + const code = ` + const query = gql\` + query tweetQuery($id: ID!) { + tweet(id: $id) { + tweetId + status + } + } + \` + ` + const output = await process('user.gql.ts', code, new Position(2, 4)) + expect(output).toEqual( + trimSpaces(` + # schema.gql:42:2 + tweet(id: ID!): Tweet`), + ) + }) + + test('should generate a location in schema for field definition', async () => { + const code = ` + const query = gql\` + query tweetQuery($id: ID!) { + tweet(id: $id) { + tweetId + status + } + } + \` + ` + const output = await process('user.gql.ts', code, new Position(3, 6)) + expect(output).toEqual( + trimSpaces(` + # schema.gql:29:2 + tweetId: ID!`), + ) + }) + + test('should generate a location in schema for variable definition', async () => { + const code = ` + const query = gql\` + mutation createUserMutation($input: CreateUserInput!) { + createUser(input: $input) { + user { + id + } + } + } + \` + ` + const output = await process('create-user-mutation.gql.ts', code, new Position(1, 50)) + expect(output).toEqual( + trimSpaces(` + # schema.gql:54:0 + input CreateUserInput { + name: String! + }`), + ) + }) + + test('should generate a location in schema for operation definition', async () => { + const code = ` + const query = gql\` + mutation createUserMutation($input: CreateUserInput!) { + createUser(input: $input) { + user { + id + } + } + } + \` + ` + const output = await process('create-user-mutation.gql.ts', code, new Position(1, 3)) + expect(output).toEqual( + trimSpaces(` + # schema.gql:45:0 + type Mutation { + createUser(input: CreateUserInput!): CreateUserResponse + updateUser(id: ID!, name: String): User + }`), + ) + }) + + test('should generate a location in schema for a given model name', async () => { + const code = ` + @ObjectType() + class User { + @Field(() => ID) + id!: string + + @Field(() => String, { nullable: true }) + name?: string + } + + ` + const output = await process('user.model.ts', code, new Position(1, 6)) + expect(output).toEqual( + trimSpaces(` + # schema.gql:1:0 + type User { + id: ID! + name: String + address: Address + status: UserStatus + }`), + ) + }) + + test('should generate a location in schema for a given input model', async () => { + const code = ` + @InputType() + class CreateUserInput { + @Field(() => ID) + id!: string + + @Field(() => String, { nullable: true }) + name?: string + } + + ` + const output = await process('create-user.input.ts', code, new Position(1, 6)) + expect(output).toEqual( + trimSpaces(` + # schema.gql:54:0 + input CreateUserInput { + name: String! + }`), + ) + }) + + test('should generate a location in schema for a given response model', async () => { + const code = ` + @ObjectType() + class CreateUserResponse { + @Field() + user!: User + } + ` + const output = await process('create-user.response.ts', code, new Position(1, 6)) + expect(output).toEqual( + trimSpaces(` + # schema.gql:58:0 + type CreateUserResponse { + user: User! + }`), + ) + }) + + test('should generate a location in schema for a given enum name', async () => { + const code = ` + export enum TweetStatus { + DRAFT = 'DRAFT', + ACTIVE = 'ACTIVE', + } + ` + const output = await process('user.enum.ts', code, new Position(0, 12)) + expect(output).toEqual( + trimSpaces(` + # schema.gql:34:0 + enum TweetStatus { + DRAFT + ACTIVE + }`), + ) + }) +}) diff --git a/src/definition-provider/definition-provider-for-source.ts b/src/definition-provider/definition-provider-for-source.ts new file mode 100644 index 0000000..c49902d --- /dev/null +++ b/src/definition-provider/definition-provider-for-source.ts @@ -0,0 +1,154 @@ +import * as gql from 'graphql' +import ts from 'typescript' +import { GQLAssistConfig } from '../config' +import { Location, Position } from '../diff' +import { isEnum, isHook, isInput, isModel } from '../generator' +import { getGQLNodeRange, getGQLNodeRangeWithoutDescription, makeQueryParsable } from '../gql' +import { isPositionWithInRange } from '../position/is-position-within-range' +import { getGQLContent, getGraphQLQueryVariable, getTSNodeLocationRange, hasDecorator } from '../ts' + +function isInRange(node: gql.ASTNode, position: Position, offset?: Position) { + const nodeRange = getGQLNodeRange(node, offset) + return isPositionWithInRange(position, nodeRange, true) +} + +function provideDefinitionForGraphQL( + sourceFile: ts.SourceFile, + position: Position, + schema: gql.GraphQLSchema, + schemaLocation: string, +) { + const variable = getGraphQLQueryVariable(sourceFile) + if (!variable) return null + + const range = getTSNodeLocationRange(variable, sourceFile) + const query = getGQLContent(variable) + if (!query || query?.trim() === '') return null + const offset = new Position(range.start.line, 0) + + try { + const fixed = makeQueryParsable(query) + const document = gql.parse(fixed) + let targetNode: gql.ASTNode | undefined | null + const typeInfo = new gql.TypeInfo(schema) + gql.visit( + document, + gql.visitWithTypeInfo(typeInfo, { + OperationDefinition(node) { + if (!isInRange(node, position, offset)) return + const type = typeInfo.getType() + targetNode = gql.getNamedType(type)?.astNode + }, + VariableDefinition(node) { + if (!isInRange(node, position, offset)) return + const type = typeInfo.getInputType() + targetNode = gql.getNamedType(type)?.astNode + }, + Field(node) { + if (!isInRange(node, position, offset)) return + const type = typeInfo.getParentType() + if (!type || !gql.isObjectType(type)) return + targetNode = type?.getFields()[node.name.value]?.astNode + }, + }), + ) + if (!targetNode) return null + return new Location(schemaLocation, getGQLNodeRangeWithoutDescription(targetNode)) + } catch (e) { + console.error(e) + return null + } +} + +function processClassDeclaration( + classDeclaration: ts.ClassDeclaration, + sourceFile: ts.SourceFile, + position: Position, + schema: gql.GraphQLSchema, + schemaLocation: string, +) { + if ( + !classDeclaration.name || + (!hasDecorator(classDeclaration, 'ObjectType') && !hasDecorator(classDeclaration, 'InputType')) + ) { + return null + } + const range = getTSNodeLocationRange(classDeclaration.name, sourceFile) + const className = classDeclaration.name.getText() + if (range && isPositionWithInRange(position, range, true)) { + const type = schema.getType(className) + if (type?.astNode) { + return new Location(schemaLocation, getGQLNodeRange(type.astNode)) + } + } +} + +function processEnumDeclaration( + enumDeclaration: ts.EnumDeclaration, + sourceFile: ts.SourceFile, + position: Position, + schema: gql.GraphQLSchema, + schemaLocation: string, +) { + if (!enumDeclaration.name) { + return null + } + const range = getTSNodeLocationRange(enumDeclaration.name, sourceFile) + const enumName = enumDeclaration.name.getText() + if (range && isPositionWithInRange(position, range, true)) { + const type = schema.getType(enumName) + if (type?.astNode) { + return new Location(schemaLocation, getGQLNodeRange(type.astNode)) + } + } +} + +function provideDefinitionForClassAndFields( + sourceFile: ts.SourceFile, + position: Position, + schema: gql.GraphQLSchema, + schemaLocation: string, + config: GQLAssistConfig, +) { + if (isModel(sourceFile, config) || isInput(sourceFile, config)) { + for (const statement of sourceFile.statements) { + if (ts.isClassDeclaration(statement)) { + const location = processClassDeclaration( + statement, + sourceFile, + position, + schema, + schemaLocation, + ) + if (location) return location + } + } + } + if (isEnum(sourceFile, config)) { + for (const statement of sourceFile.statements) { + if (ts.isEnumDeclaration(statement)) { + const location = processEnumDeclaration( + statement, + sourceFile, + position, + schema, + schemaLocation, + ) + if (location) return location + } + } + } +} + +export function provideDefinitionForSource( + sourceFile: ts.SourceFile, + position: Position, + schema: gql.GraphQLSchema, + schemaLocation: string, + config: GQLAssistConfig, +) { + if (isHook(sourceFile, config)) { + return provideDefinitionForGraphQL(sourceFile, position, schema, schemaLocation) + } + return provideDefinitionForClassAndFields(sourceFile, position, schema, schemaLocation, config) +} diff --git a/src/definition-provider/definition-provider-for-source.tsx b/src/definition-provider/definition-provider-for-source.tsx deleted file mode 100644 index 2add3aa..0000000 --- a/src/definition-provider/definition-provider-for-source.tsx +++ /dev/null @@ -1,58 +0,0 @@ -import * as gql from 'graphql' -import ts from 'typescript' -import { Position } from '../diff' -import { getGQLNodeRange, getGQLNodeRangeWithoutDescription, makeQueryParsable } from '../gql' -import { isPositionWithInRange } from '../position/is-position-within-range' -import { getGQLContent, getGraphQLQueryVariable, getTSNodeLocationRange } from '../ts' - -function isInRange(node: gql.ASTNode, position: Position, offset?: Position) { - const nodeRange = getGQLNodeRange(node, offset) - return isPositionWithInRange(position, nodeRange, true) -} - -export function provideDefinitionForSource( - sourceFile: ts.SourceFile, - position: Position, - schema: gql.GraphQLSchema, -) { - const variable = getGraphQLQueryVariable(sourceFile) - if (!variable) return null - - const range = getTSNodeLocationRange(variable, sourceFile) - const query = getGQLContent(variable) - if (!query || query?.trim() === '') return null - const offset = new Position(range.start.line, 0) - - try { - const fixed = makeQueryParsable(query) - const document = gql.parse(fixed) - let targetNode: gql.ASTNode | undefined | null - const typeInfo = new gql.TypeInfo(schema) - gql.visit( - document, - gql.visitWithTypeInfo(typeInfo, { - OperationDefinition(node) { - if (!isInRange(node, position, offset)) return - const type = typeInfo.getType() - targetNode = gql.getNamedType(type)?.astNode - }, - VariableDefinition(node) { - if (!isInRange(node, position, offset)) return - const type = typeInfo.getInputType() - targetNode = gql.getNamedType(type)?.astNode - }, - Field(node) { - if (!isInRange(node, position, offset)) return - const type = typeInfo.getParentType() - if (!type || !gql.isObjectType(type)) return - targetNode = type?.getFields()[node.name.value]?.astNode - }, - }), - ) - if (!targetNode) return null - return getGQLNodeRangeWithoutDescription(targetNode) - } catch (e) { - console.error(e) - return null - } -} diff --git a/src/definition-provider/reference-provider-for-schema.ts b/src/definition-provider/reference-provider-for-schema.ts index c7c7233..30e4034 100644 --- a/src/definition-provider/reference-provider-for-schema.ts +++ b/src/definition-provider/reference-provider-for-schema.ts @@ -4,11 +4,7 @@ import { Location, Position } from '../diff' import { getGQLNodeRange, getGQLNodeRangeWithoutDescription, makeQueryParsable } from '../gql' import { isPositionWithInRange } from '../position/is-position-within-range' import { parseGraphQLDocumentFromTS, readTSFile } from '../ts' - -interface SelectedField { - parent: string - name: string -} +import { SelectedField } from './selected-field.type' function isInRange(node: gql.ASTNode, position: Position, offset?: Position) { const nodeRange = getGQLNodeRange(node, offset) diff --git a/src/definition-provider/selected-field.type.ts b/src/definition-provider/selected-field.type.ts new file mode 100644 index 0000000..cdccfc8 --- /dev/null +++ b/src/definition-provider/selected-field.type.ts @@ -0,0 +1,4 @@ +export interface SelectedField { + parent: string + name: string +} diff --git a/src/definition-provider/test/base.model.ts b/src/definition-provider/test/base.model.ts new file mode 100644 index 0000000..3439a90 --- /dev/null +++ b/src/definition-provider/test/base.model.ts @@ -0,0 +1,7 @@ +import { Field, ID, ObjectType } from '@nestjs/graphql' + +@ObjectType() +export class Model { + @Field(() => ID) + id!: string +} diff --git a/src/definition-provider/test/custom.d.ts b/src/definition-provider/test/custom.d.ts new file mode 100644 index 0000000..1b60c84 --- /dev/null +++ b/src/definition-provider/test/custom.d.ts @@ -0,0 +1 @@ +declare module '@nestjs/graphql' diff --git a/src/definition-provider/test/tweet-status.enum.ts b/src/definition-provider/test/tweet-status.enum.ts new file mode 100644 index 0000000..6659e8c --- /dev/null +++ b/src/definition-provider/test/tweet-status.enum.ts @@ -0,0 +1,7 @@ +import { registerEnumType } from '@nestjs/graphql' + +export enum TweetStatus { + DRAFT = 'DRAFT', + ACTIVE = 'ACTIVE', +} +registerEnumType(TweetStatus, { name: 'TweetStatus' }) diff --git a/src/definition-provider/test/tweet.model.ts b/src/definition-provider/test/tweet.model.ts new file mode 100644 index 0000000..18ec120 --- /dev/null +++ b/src/definition-provider/test/tweet.model.ts @@ -0,0 +1,10 @@ +import { Field, ID, ObjectType } from '@nestjs/graphql' +import { User } from './user.model' + +@ObjectType('Tweet') +export class Tweet { + @Field(() => ID) + id!: string + + mentions: [User] +} diff --git a/src/definition-provider/test/tweet.resolver.ts b/src/definition-provider/test/tweet.resolver.ts new file mode 100644 index 0000000..3b18a87 --- /dev/null +++ b/src/definition-provider/test/tweet.resolver.ts @@ -0,0 +1,10 @@ +import { Query, Resolver } from '@nestjs/graphql' +import { Tweet } from './tweet.model' + +@Resolver() +export class TweetResolver { + @Query(() => Tweet) + tweet() { + return null + } +} diff --git a/src/definition-provider/test/user.model.ts b/src/definition-provider/test/user.model.ts new file mode 100644 index 0000000..5adbbd5 --- /dev/null +++ b/src/definition-provider/test/user.model.ts @@ -0,0 +1,8 @@ +import { Field, ID, ObjectType } from '@nestjs/graphql' +import { Model } from './base.model' + +@ObjectType() +export class User extends Model { + @Field({ nullable: true }) + name?: string +} diff --git a/src/diff/token.ts b/src/diff/token.ts index cd723ad..86fa6dd 100644 --- a/src/diff/token.ts +++ b/src/diff/token.ts @@ -22,7 +22,6 @@ export class Range { constructor( public start: Position, public end: Position, - public path?: string, ) {} setStart(start: Position) { diff --git a/src/position/is-position-within-range.ts b/src/position/is-position-within-range.ts index 6aab8f7..6b00a3e 100644 --- a/src/position/is-position-within-range.ts +++ b/src/position/is-position-within-range.ts @@ -1,18 +1,29 @@ +import * as gql from 'graphql' import { Position, Range } from '../diff' +import { getGQLNodeRange } from '../gql' export function isPositionWithInRange(position: Position, range: Range, includeEdges = false) { if (position.line > range.start.line && position.line < range.end.line) { return true } - if (position.line == range.start.line) { + if (position.line === range.start.line && position.line === range.end.line) { + if (includeEdges) { + return ( + position.character >= range.start.character && position.character <= range.end.character + ) + } else { + return position.character > range.start.character && position.character < range.end.character + } + } + if (position.line === range.start.line) { if (includeEdges) { return position.character >= range.start.character } else { return position.character > range.start.character } } - if (position.line == range.end.line) { + if (position.line === range.end.line) { if (includeEdges) { return position.character <= range.end.character } else { @@ -21,3 +32,13 @@ export function isPositionWithInRange(position: Position, range: Range, includeE } return false } + +export function isInRange( + node: gql.ASTNode, + position: Position, + offset?: Position, + includeEdges = true, +) { + const nodeRange = getGQLNodeRange(node, offset) + return isPositionWithInRange(position, nodeRange, includeEdges) +}