Add support for the gqlgen library

This commit is contained in:
Alvaro Muñoz
2023-06-28 15:05:25 +02:00
parent 656b4fc1aa
commit fe4ddab7e4
487 changed files with 174523 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
lgtm,codescanning
* Support for the gqlgen has been added.

View File

@@ -41,6 +41,7 @@ import semmle.go.frameworks.Encoding
import semmle.go.frameworks.Gin
import semmle.go.frameworks.Glog
import semmle.go.frameworks.GoRestfulHttp
import semmle.go.frameworks.Gqlgen
import semmle.go.frameworks.K8sIoApimachineryPkgRuntime
import semmle.go.frameworks.K8sIoApiCoreV1
import semmle.go.frameworks.K8sIoClientGo

View File

@@ -0,0 +1,44 @@
/** Provides models of commonly used functions and types in the gqlgen packages. */
import go
/** Provides models of commonly used functions and types in the gqlgen packages. */
module Gqlgen {
class GqlgenGeneratedFile extends File {
GqlgenGeneratedFile() {
exists(DataFlow::CallNode call |
call.getReceiver().getType().hasQualifiedName("github.com/99designs/gqlgen/graphql", _) and
call.getFile() = this
)
}
}
class ResolverInterface extends Type {
ResolverInterface() {
this.getQualifiedName().matches("%Resolver") and
this.getEntity().getDeclaration().getFile() instanceof GqlgenGeneratedFile
}
}
class ResolverInterfaceMethod extends Method {
ResolverInterfaceMethod() {
this.getReceiver().getType() instanceof ResolverInterface
}
}
class ResolverImplementationMethod extends Method {
ResolverImplementationMethod() { this.implements(any(ResolverInterfaceMethod r)) }
Parameter getAnUntrustedParameter() {
result.getFunction() = this.getFuncDecl() and
not result.getType().hasQualifiedName("context", "Context") and
result.getIndex() > 0
}
}
class ResolverParameter extends UntrustedFlowSource::Range instanceof DataFlow::ParameterNode {
ResolverParameter() {
this.asParameter() = any(ResolverImplementationMethod h).getAnUntrustedParameter()
}
}
}

View File

@@ -0,0 +1,24 @@
module pwntester/gqlgen-todos
go 1.19
require (
github.com/99designs/gqlgen v0.17.34
github.com/vektah/gqlparser/v2 v2.5.4
)
require (
github.com/agnivade/levenshtein v1.1.1 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.3 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/urfave/cli/v2 v2.25.5 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
golang.org/x/mod v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/tools v0.9.3 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -0,0 +1,5 @@
import go
import semmle.go.frameworks.Gqlgen
from Gqlgen::ResolverParameter p
select p

View File

@@ -0,0 +1,87 @@
# Where are all the schema files located? globs are supported eg src/**/*.graphqls
schema:
- graph/*.graphqls
# Where should the generated server code go?
exec:
filename: graph/generated.go
package: graph
# Uncomment to enable federation
# federation:
# filename: graph/federation.go
# package: graph
# Where should any generated models go?
model:
filename: graph/model/models_gen.go
package: model
# Where should the resolver implementations go?
resolver:
layout: follow-schema
dir: graph
package: graph
filename_template: "{name}.resolvers.go"
# Optional: turn on to not generate template comments above resolvers
# omit_template_comment: false
# Optional: turn on use ` + "`" + `gqlgen:"fieldName"` + "`" + ` tags in your models
# struct_tag: json
# Optional: turn on to use []Thing instead of []*Thing
# omit_slice_element_pointers: false
# Optional: turn on to omit Is<Name>() methods to interface and unions
# omit_interface_checks : true
# Optional: turn on to skip generation of ComplexityRoot struct content and Complexity function
# omit_complexity: false
# Optional: turn on to not generate any file notice comments in generated files
# omit_gqlgen_file_notice: false
# Optional: turn on to exclude the gqlgen version in the generated file notice. No effect if `omit_gqlgen_file_notice` is true.
# omit_gqlgen_version_in_file_notice: false
# Optional: turn off to make struct-type struct fields not use pointers
# e.g. type Thing struct { FieldA OtherThing } instead of { FieldA *OtherThing }
# struct_fields_always_pointers: true
# Optional: turn off to make resolvers return values instead of pointers for structs
# resolvers_always_return_pointers: true
# Optional: turn on to return pointers instead of values in unmarshalInput
# return_pointers_in_unmarshalinput: false
# Optional: wrap nullable input fields with Omittable
# nullable_input_omittable: true
# Optional: set to speed up generation time by not performing a final validation pass.
# skip_validation: true
# Optional: set to skip running `go mod tidy` when generating server code
# skip_mod_tidy: true
# gqlgen will search for any type names in the schema in these go packages
# if they match it will use them, otherwise it will generate them.
autobind:
# - "pwntester/gqlgen-todos/graph/model"
# This section declares type mapping between the GraphQL and go type systems
#
# The first line in each type will be used as defaults for resolver arguments and
# modelgen, the others will be allowed when binding to fields. Configure them to
# your liking
models:
ID:
model:
- github.com/99designs/gqlgen/graphql.ID
- github.com/99designs/gqlgen/graphql.Int
- github.com/99designs/gqlgen/graphql.Int64
- github.com/99designs/gqlgen/graphql.Int32
Int:
model:
- github.com/99designs/gqlgen/graphql.Int
- github.com/99designs/gqlgen/graphql.Int64
- github.com/99designs/gqlgen/graphql.Int32

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.
package model
type NewTodo struct {
Text string `json:"text"`
UserID string `json:"userId"`
}
type Todo struct {
ID string `json:"id"`
Text string `json:"text"`
Done bool `json:"done"`
User *User `json:"user"`
}
type User struct {
ID string `json:"id"`
Name string `json:"name"`
}

View File

@@ -0,0 +1,7 @@
package graph
// This file will not be regenerated automatically.
//
// It serves as dependency injection for your app, add any dependencies you require here.
type Resolver struct{}

View File

@@ -0,0 +1,28 @@
# GraphQL schema example
#
# https://gqlgen.com/getting-started/
type Todo {
id: ID!
text: String!
done: Boolean!
user: User!
}
type User {
id: ID!
name: String!
}
type Query {
todos: [Todo!]!
}
input NewTodo {
text: String!
userId: String!
}
type Mutation {
createTodo(input: NewTodo!): Todo!
}

View File

@@ -0,0 +1,30 @@
package graph
// This file will be automatically regenerated based on the schema, any resolver implementations
// will be copied through when generating and any unknown code will be moved to the end.
// Code generated by github.com/99designs/gqlgen version v0.17.34
import (
"context"
"fmt"
"pwntester/gqlgen-todos/graph/model"
)
// CreateTodo is the resolver for the createTodo field.
func (r *mutationResolver) CreateTodo(ctx context.Context, input model.NewTodo) (*model.Todo, error) {
panic(fmt.Errorf("not implemented: CreateTodo - createTodo"))
}
// Todos is the resolver for the todos field.
func (r *queryResolver) Todos(ctx context.Context) ([]*model.Todo, error) {
panic(fmt.Errorf("not implemented: Todos - todos"))
}
// Mutation returns MutationResolver implementation.
func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} }
// Query returns QueryResolver implementation.
func (r *Resolver) Query() QueryResolver { return &queryResolver{r} }
type mutationResolver struct{ *Resolver }
type queryResolver struct{ *Resolver }

View File

@@ -0,0 +1,28 @@
package main
import (
"log"
"net/http"
"os"
"pwntester/gqlgen-todos/graph"
"github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/playground"
)
const defaultPort = "8080"
func main() {
port := os.Getenv("PORT")
if port == "" {
port = defaultPort
}
srv := handler.NewDefaultServer(graph.NewExecutableSchema(graph.Config{Resolvers: &graph.Resolver{}}))
http.Handle("/", playground.Handler("GraphQL playground", "/query"))
http.Handle("/query", srv)
log.Printf("connect to http://localhost:%s/ for GraphQL playground", port)
log.Fatal(http.ListenAndServe(":"+port, nil))
}

View File

@@ -0,0 +1,5 @@
package tools
import (
_ "github.com/99designs/gqlgen"
)

View File

@@ -0,0 +1,3 @@
/**/node_modules
/codegen/tests/gen
/vendor

View File

@@ -0,0 +1,20 @@
root = true
[*]
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
indent_style = space
indent_size = 4
[*.{go,gotpl}]
indent_style = tab
[*.yml]
indent_size = 2
# These often end up with go code inside, so lets keep tabs
[*.{html,md}]
indent_size = 2
indent_style = tab

View File

@@ -0,0 +1,3 @@
/codegen/templates/data.go linguist-generated
/_examples/dataloader/*_gen.go linguist-generated
generated.go linguist-generated

View File

@@ -0,0 +1,18 @@
/vendor
/docs/public
/docs/.hugo_build.lock
/_examples/chat/node_modules
/integration/node_modules
/integration/schema-fetched.graphql
/_examples/chat/package-lock.json
/_examples/federation/package-lock.json
/_examples/federation/node_modules
/codegen/gen
/gen
/.vscode
.idea/
*.test
*.out
gqlgen
*.exe

View File

@@ -0,0 +1,35 @@
run:
tests: true
skip-dirs:
- bin
linters-settings:
errcheck:
ignore: fmt:.*,[rR]ead|[wW]rite|[cC]lose,io:Copy
linters:
disable-all: true
enable:
- bodyclose
- dupl
- errcheck
- gocritic
- gofmt
- goimports
- gosimple
- govet
- ineffassign
- misspell
- nakedret
- prealloc
- staticcheck
- typecheck
- unconvert
- unused
issues:
exclude-rules:
# Exclude some linters from running on tests files.
- path: _test\.go
linters:
- dupl

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,27 @@
# Contribution Guidelines
Want to contribute to gqlgen? Here are some guidelines for how we accept help.
## Getting in Touch
Our [discord](https://discord.gg/DYEq3EMs4U) server is the best place to ask questions or get advice on using gqlgen.
## Reporting Bugs and Issues
We use [GitHub Issues](https://github.com/99designs/gqlgen/issues) to track bugs, so please do a search before submitting to ensure your problem isn't already tracked.
### New Issues
Please provide the expected and observed behaviours in your issue. A minimal GraphQL schema or configuration file should be provided where appropriate.
## Proposing a Change
If you intend to implement a feature for gqlgen, or make a non-trivial change to the current implementation, we recommend [first filing an issue](https://github.com/99designs/gqlgen/issues/new) marked with the `proposal` tag, so that the engineering team can provide guidance and feedback on the direction of an implementation. This also help ensure that other people aren't also working on the same thing.
Bug fixes are welcome and should come with appropriate test coverage.
New features should be made against the `next` branch.
### License
By contributing to gqlgen, you agree that your contributions will be licensed under its MIT license.

View File

@@ -0,0 +1,19 @@
Copyright (c) 2020 gqlgen authors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,162 @@
![gqlgen](https://user-images.githubusercontent.com/980499/133180111-d064b38c-6eb9-444b-a60f-7005a6e68222.png)
# gqlgen [![Integration](https://github.com/99designs/gqlgen/actions/workflows/integration.yml/badge.svg)](https://github.com/99designs/gqlgen/actions) [![Coverage Status](https://coveralls.io/repos/github/99designs/gqlgen/badge.svg?branch=master)](https://coveralls.io/github/99designs/gqlgen?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/99designs/gqlgen)](https://goreportcard.com/report/github.com/99designs/gqlgen) [![Go Reference](https://pkg.go.dev/badge/github.com/99designs/gqlgen.svg)](https://pkg.go.dev/github.com/99designs/gqlgen) [![Read the Docs](https://badgen.net/badge/docs/available/green)](http://gqlgen.com/)
## What is gqlgen?
[gqlgen](https://github.com/99designs/gqlgen) is a Go library for building GraphQL servers without any fuss.<br/>
- **gqlgen is based on a Schema first approach** — You get to Define your API using the GraphQL [Schema Definition Language](http://graphql.org/learn/schema/).
- **gqlgen prioritizes Type safety** — You should never see `map[string]interface{}` here.
- **gqlgen enables Codegen** — We generate the boring bits, so you can focus on building your app quickly.
Still not convinced enough to use **gqlgen**? Compare **gqlgen** with other Go graphql [implementations](https://gqlgen.com/feature-comparison/)
## Quick start
1. [Initialise a new go module](https://golang.org/doc/tutorial/create-module)
mkdir example
cd example
go mod init example
2. Add `github.com/99designs/gqlgen` to your [project's tools.go](https://github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module)
printf '// +build tools\npackage tools\nimport (_ "github.com/99designs/gqlgen"\n _ "github.com/99designs/gqlgen/graphql/introspection")' | gofmt > tools.go
go mod tidy
3. Initialise gqlgen config and generate models
go run github.com/99designs/gqlgen init
4. Start the graphql server
go run server.go
More help to get started:
- [Getting started tutorial](https://gqlgen.com/getting-started/) - a comprehensive guide to help you get started
- [Real-world examples](https://github.com/99designs/gqlgen/tree/master/_examples) show how to create GraphQL applications
- [Reference docs](https://pkg.go.dev/github.com/99designs/gqlgen) for the APIs
## Reporting Issues
If you think you've found a bug, or something isn't behaving the way you think it should, please raise an [issue](https://github.com/99designs/gqlgen/issues) on GitHub.
## Contributing
We welcome contributions, Read our [Contribution Guidelines](https://github.com/99designs/gqlgen/blob/master/CONTRIBUTING.md) to learn more about contributing to **gqlgen**
## Frequently asked questions
### How do I prevent fetching child objects that might not be used?
When you have nested or recursive schema like this:
```graphql
type User {
id: ID!
name: String!
friends: [User!]!
}
```
You need to tell gqlgen that it should only fetch friends if the user requested it. There are two ways to do this;
- #### Using Custom Models
Write a custom model that omits the friends field:
```go
type User struct {
ID int
Name string
}
```
And reference the model in `gqlgen.yml`:
```yaml
# gqlgen.yml
models:
User:
model: github.com/you/pkg/model.User # go import path to the User struct above
```
- #### Using Explicit Resolvers
If you want to Keep using the generated model, mark the field as requiring a resolver explicitly in `gqlgen.yml` like this:
```yaml
# gqlgen.yml
models:
User:
fields:
friends:
resolver: true # force a resolver to be generated
```
After doing either of the above and running generate we will need to provide a resolver for friends:
```go
func (r *userResolver) Friends(ctx context.Context, obj *User) ([]*User, error) {
// select * from user where friendid = obj.ID
return friends, nil
}
```
You can also use inline config with directives to achieve the same result
```graphql
directive @goModel(model: String, models: [String!]) on OBJECT
| INPUT_OBJECT
| SCALAR
| ENUM
| INTERFACE
| UNION
directive @goField(forceResolver: Boolean, name: String, omittable: Boolean) on INPUT_FIELD_DEFINITION
| FIELD_DEFINITION
type User @goModel(model: "github.com/you/pkg/model.User") {
id: ID! @goField(name: "todoId")
friends: [User!]! @goField(forceResolver: true)
}
```
### Can I change the type of the ID from type String to Type Int?
Yes! You can by remapping it in config as seen below:
```yaml
models:
ID: # The GraphQL type ID is backed by
model:
- github.com/99designs/gqlgen/graphql.IntID # a go integer
- github.com/99designs/gqlgen/graphql.ID # or a go string
```
This means gqlgen will be able to automatically bind to strings or ints for models you have written yourself, but the
first model in this list is used as the default type and it will always be used when:
- Generating models based on schema
- As arguments in resolvers
There isn't any way around this, gqlgen has no way to know what you want in a given context.
### Why do my interfaces have getters? Can I disable these?
These were added in v0.17.14 to allow accessing common interface fields without casting to a concrete type.
However, certain fields, like Relay-style Connections, cannot be implemented with simple getters.
If you'd prefer to not have getters generated in your interfaces, you can add the following in your `gqlgen.yml`:
```yaml
# gqlgen.yml
omit_getters: true
```
## Other Resources
- [Christopher Biscardi @ Gophercon UK 2018](https://youtu.be/FdURVezcdcw)
- [Introducing gqlgen: a GraphQL Server Generator for Go](https://99designs.com.au/blog/engineering/gqlgen-a-graphql-server-generator-for-go/)
- [Dive into GraphQL by Iván Corrales Solera](https://medium.com/@ivan.corrales.solera/dive-into-graphql-9bfedf22e1a)
- [Sample Project built on gqlgen with Postgres by Oleg Shalygin](https://github.com/oshalygin/gqlgen-pg-todo-example)
- [Hackernews GraphQL Server with gqlgen by Shayegan Hooshyari](https://www.howtographql.com/graphql-go/0-introduction/)

View File

@@ -0,0 +1,15 @@
# When gqlgen gets released, the following things need to happen
Assuming the next version is $NEW_VERSION=v0.16.0 or something like that.
1. Run the https://github.com/99designs/gqlgen/blob/master/bin/release:
```
./bin/release $NEW_VERSION
```
2. git-chglog -o CHANGELOG.md
3. go generate ./...; cd _examples; go generate ./...; cd ..
4. git commit and push the CHANGELOG.md
5. Go to https://github.com/99designs/gqlgen/releases and draft new release, autogenerate the release notes, and Create a discussion for this release
6. Comment on the release discussion with any really important notes (breaking changes)
I used https://github.com/git-chglog/git-chglog to automate the changelog maintenance process for now. We could just as easily use go releaser to make the whole thing automated.

View File

@@ -0,0 +1,39 @@
How to write tests for gqlgen
===
Testing generated code is a little tricky, heres how its currently set up.
### Testing responses from a server
There is a server in `codegen/testserver` that is generated as part
of `go generate ./...`, and tests written against it.
There are also a bunch of tests in against the examples, feel free to take examples from there.
### Testing the errors generated by the binary
These tests are **really** slow, because they need to run the whole codegen step. Use them very sparingly. If you can, find a way to unit test it instead.
Take a look at `codegen/testserver/input_test.go` for an example.
### Testing introspection
Introspection is tested by diffing the output of `graphql get-schema` against an expected output.
Setting up the integration environment is a little tricky:
```bash
cd integration
go generate ./...
go run ./server/server.go
```
in another terminal
```bash
cd integration
npm install
./node_modules/.bin/graphql-codegen
```
will write the schema to `integration/schema-fetched.graphql`, compare that with `schema-expected.graphql`
CI will run this and fail the build if the two files don't match.

View File

@@ -0,0 +1,143 @@
package api
import (
"fmt"
"regexp"
"syscall"
"github.com/99designs/gqlgen/codegen"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/plugin"
"github.com/99designs/gqlgen/plugin/federation"
"github.com/99designs/gqlgen/plugin/modelgen"
"github.com/99designs/gqlgen/plugin/resolvergen"
)
func Generate(cfg *config.Config, option ...Option) error {
_ = syscall.Unlink(cfg.Exec.Filename)
if cfg.Model.IsDefined() {
_ = syscall.Unlink(cfg.Model.Filename)
}
plugins := []plugin.Plugin{}
if cfg.Model.IsDefined() {
plugins = append(plugins, modelgen.New())
}
plugins = append(plugins, resolvergen.New())
if cfg.Federation.IsDefined() {
if cfg.Federation.Version == 0 { // default to using the user's choice of version, but if unset, try to sort out which federation version to use
urlRegex := regexp.MustCompile(`(?s)@link.*\(.*url:.*?"(.*?)"[^)]+\)`) // regex to grab the url of a link directive, should it exist
// check the sources, and if one is marked as federation v2, we mark the entirety to be generated using that format
for _, v := range cfg.Sources {
cfg.Federation.Version = 1
urlString := urlRegex.FindStringSubmatch(v.Input)
if urlString != nil && urlString[1] == "https://specs.apollo.dev/federation/v2.0" {
cfg.Federation.Version = 2
break
}
}
}
plugins = append([]plugin.Plugin{federation.New(cfg.Federation.Version)}, plugins...)
}
for _, o := range option {
o(cfg, &plugins)
}
for _, p := range plugins {
if inj, ok := p.(plugin.EarlySourceInjector); ok {
if s := inj.InjectSourceEarly(); s != nil {
cfg.Sources = append(cfg.Sources, s)
}
}
}
if err := cfg.LoadSchema(); err != nil {
return fmt.Errorf("failed to load schema: %w", err)
}
for _, p := range plugins {
if inj, ok := p.(plugin.LateSourceInjector); ok {
if s := inj.InjectSourceLate(cfg.Schema); s != nil {
cfg.Sources = append(cfg.Sources, s)
}
}
}
// LoadSchema again now we have everything
if err := cfg.LoadSchema(); err != nil {
return fmt.Errorf("failed to load schema: %w", err)
}
if err := cfg.Init(); err != nil {
return fmt.Errorf("generating core failed: %w", err)
}
for _, p := range plugins {
if mut, ok := p.(plugin.ConfigMutator); ok {
err := mut.MutateConfig(cfg)
if err != nil {
return fmt.Errorf("%s: %w", p.Name(), err)
}
}
}
// Merge again now that the generated models have been injected into the typemap
data_plugins := make([]interface{}, len(plugins))
for index := range plugins {
data_plugins[index] = plugins[index]
}
data, err := codegen.BuildData(cfg, data_plugins...)
if err != nil {
return fmt.Errorf("merging type systems failed: %w", err)
}
if err = codegen.GenerateCode(data); err != nil {
return fmt.Errorf("generating core failed: %w", err)
}
if !cfg.SkipModTidy {
if err = cfg.Packages.ModTidy(); err != nil {
return fmt.Errorf("tidy failed: %w", err)
}
}
for _, p := range plugins {
if mut, ok := p.(plugin.CodeGenerator); ok {
err := mut.GenerateCode(data)
if err != nil {
return fmt.Errorf("%s: %w", p.Name(), err)
}
}
}
if err = codegen.GenerateCode(data); err != nil {
return fmt.Errorf("generating core failed: %w", err)
}
if !cfg.SkipValidation {
if err := validate(cfg); err != nil {
return fmt.Errorf("validation failed: %w", err)
}
}
return nil
}
func validate(cfg *config.Config) error {
roots := []string{cfg.Exec.ImportPath()}
if cfg.Model.IsDefined() {
roots = append(roots, cfg.Model.ImportPath())
}
if cfg.Resolver.IsDefined() {
roots = append(roots, cfg.Resolver.ImportPath())
}
cfg.Packages.LoadAll(roots...)
errs := cfg.Packages.Errors()
if len(errs) > 0 {
return errs
}
return nil
}

View File

@@ -0,0 +1,47 @@
package api
import (
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/plugin"
)
type Option func(cfg *config.Config, plugins *[]plugin.Plugin)
func NoPlugins() Option {
return func(cfg *config.Config, plugins *[]plugin.Plugin) {
*plugins = nil
}
}
func AddPlugin(p plugin.Plugin) Option {
return func(cfg *config.Config, plugins *[]plugin.Plugin) {
*plugins = append(*plugins, p)
}
}
// PrependPlugin prepends plugin any existing plugins
func PrependPlugin(p plugin.Plugin) Option {
return func(cfg *config.Config, plugins *[]plugin.Plugin) {
*plugins = append([]plugin.Plugin{p}, *plugins...)
}
}
// ReplacePlugin replaces any existing plugin with a matching plugin name
func ReplacePlugin(p plugin.Plugin) Option {
return func(cfg *config.Config, plugins *[]plugin.Plugin) {
if plugins != nil {
found := false
ps := *plugins
for i, o := range ps {
if p.Name() == o.Name() {
ps[i] = p
found = true
}
}
if !found {
ps = append(ps, p)
}
*plugins = ps
}
}
}

View File

@@ -0,0 +1,122 @@
package codegen
import (
"fmt"
"go/types"
"strings"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/v2/ast"
)
type ArgSet struct {
Args []*FieldArgument
FuncDecl string
}
type FieldArgument struct {
*ast.ArgumentDefinition
TypeReference *config.TypeReference
VarName string // The name of the var in go
Object *Object // A link back to the parent object
Default interface{} // The default value
Directives []*Directive
Value interface{} // value set in Data
}
// ImplDirectives get not Builtin and location ARGUMENT_DEFINITION directive
func (f *FieldArgument) ImplDirectives() []*Directive {
d := make([]*Directive, 0)
for i := range f.Directives {
if !f.Directives[i].Builtin && f.Directives[i].IsLocation(ast.LocationArgumentDefinition) {
d = append(d, f.Directives[i])
}
}
return d
}
func (f *FieldArgument) DirectiveObjName() string {
return "rawArgs"
}
func (f *FieldArgument) Stream() bool {
return f.Object != nil && f.Object.Stream
}
func (b *builder) buildArg(obj *Object, arg *ast.ArgumentDefinition) (*FieldArgument, error) {
tr, err := b.Binder.TypeReference(arg.Type, nil)
if err != nil {
return nil, err
}
argDirs, err := b.getDirectives(arg.Directives)
if err != nil {
return nil, err
}
newArg := FieldArgument{
ArgumentDefinition: arg,
TypeReference: tr,
Object: obj,
VarName: templates.ToGoPrivate(arg.Name),
Directives: argDirs,
}
if arg.DefaultValue != nil {
newArg.Default, err = arg.DefaultValue.Value(nil)
if err != nil {
return nil, fmt.Errorf("default value is not valid: %w", err)
}
}
return &newArg, nil
}
func (b *builder) bindArgs(field *Field, sig *types.Signature, params *types.Tuple) ([]*FieldArgument, error) {
n := params.Len()
newArgs := make([]*FieldArgument, 0, len(field.Args))
// Accept variadic methods (i.e. have optional parameters).
if params.Len() > len(field.Args) && sig.Variadic() {
n = len(field.Args)
}
nextArg:
for j := 0; j < n; j++ {
param := params.At(j)
for _, oldArg := range field.Args {
if strings.EqualFold(oldArg.Name, param.Name()) {
tr, err := b.Binder.TypeReference(oldArg.Type, param.Type())
if err != nil {
return nil, err
}
oldArg.TypeReference = tr
newArgs = append(newArgs, oldArg)
continue nextArg
}
}
// no matching arg found, abort
return nil, fmt.Errorf("arg %s not in schema", param.Name())
}
return newArgs, nil
}
func (d *Data) Args() map[string][]*FieldArgument {
ret := map[string][]*FieldArgument{}
for _, o := range d.Objects {
for _, f := range o.Fields {
if len(f.Args) > 0 {
ret[f.ArgsFunc()] = f.Args
}
}
}
for _, directive := range d.Directives() {
if len(directive.Args) > 0 {
ret[directive.ArgsFunc()] = directive.Args
}
}
return ret
}

View File

@@ -0,0 +1,36 @@
{{ range $name, $args := .Args }}
func (ec *executionContext) {{ $name }}(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) {
var err error
args := map[string]interface{}{}
{{- range $i, $arg := . }}
var arg{{$i}} {{ $arg.TypeReference.GO | ref}}
if tmp, ok := rawArgs[{{$arg.Name|quote}}]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField({{$arg.Name|quote}}))
{{- if $arg.ImplDirectives }}
directive0 := func(ctx context.Context) (interface{}, error) { return ec.{{ $arg.TypeReference.UnmarshalFunc }}(ctx, tmp) }
{{ template "implDirectives" $arg }}
tmp, err = directive{{$arg.ImplDirectives|len}}(ctx)
if err != nil {
return nil, graphql.ErrorOnPath(ctx, err)
}
if data, ok := tmp.({{ $arg.TypeReference.GO | ref }}) ; ok {
arg{{$i}} = data
{{- if $arg.TypeReference.IsNilable }}
} else if tmp == nil {
arg{{$i}} = nil
{{- end }}
} else {
return nil, graphql.ErrorOnPath(ctx, fmt.Errorf(`unexpected type %T from directive, should be {{ $arg.TypeReference.GO }}`, tmp))
}
{{- else }}
arg{{$i}}, err = ec.{{ $arg.TypeReference.UnmarshalFunc }}(ctx, tmp)
if err != nil {
return nil, err
}
{{- end }}
}
args[{{$arg.Name|quote}}] = arg{{$i}}
{{- end }}
return args, nil
}
{{ end }}

View File

@@ -0,0 +1,11 @@
package codegen
func (o *Object) UniqueFields() map[string][]*Field {
m := map[string][]*Field{}
for _, f := range o.Fields {
m[f.GoFieldName] = append(m[f.GoFieldName], f)
}
return m
}

View File

@@ -0,0 +1,580 @@
package config
import (
"errors"
"fmt"
"go/token"
"go/types"
"strings"
"golang.org/x/tools/go/packages"
"github.com/99designs/gqlgen/internal/code"
"github.com/vektah/gqlparser/v2/ast"
)
var ErrTypeNotFound = errors.New("unable to find type")
// Binder connects graphql types to golang types using static analysis
type Binder struct {
pkgs *code.Packages
schema *ast.Schema
cfg *Config
tctx *types.Context
References []*TypeReference
SawInvalid bool
objectCache map[string]map[string]types.Object
}
func (c *Config) NewBinder() *Binder {
return &Binder{
pkgs: c.Packages,
schema: c.Schema,
cfg: c,
}
}
func (b *Binder) TypePosition(typ types.Type) token.Position {
named, isNamed := typ.(*types.Named)
if !isNamed {
return token.Position{
Filename: "unknown",
}
}
return b.ObjectPosition(named.Obj())
}
func (b *Binder) ObjectPosition(typ types.Object) token.Position {
if typ == nil {
return token.Position{
Filename: "unknown",
}
}
pkg := b.pkgs.Load(typ.Pkg().Path())
return pkg.Fset.Position(typ.Pos())
}
func (b *Binder) FindTypeFromName(name string) (types.Type, error) {
pkgName, typeName := code.PkgAndType(name)
return b.FindType(pkgName, typeName)
}
func (b *Binder) FindType(pkgName string, typeName string) (types.Type, error) {
if pkgName == "" {
if typeName == "map[string]interface{}" {
return MapType, nil
}
if typeName == "interface{}" {
return InterfaceType, nil
}
}
obj, err := b.FindObject(pkgName, typeName)
if err != nil {
return nil, err
}
if fun, isFunc := obj.(*types.Func); isFunc {
return fun.Type().(*types.Signature).Params().At(0).Type(), nil
}
return obj.Type(), nil
}
func (b *Binder) InstantiateType(orig types.Type, targs []types.Type) (types.Type, error) {
if b.tctx == nil {
b.tctx = types.NewContext()
}
return types.Instantiate(b.tctx, orig, targs, false)
}
var (
MapType = types.NewMap(types.Typ[types.String], types.NewInterfaceType(nil, nil).Complete())
InterfaceType = types.NewInterfaceType(nil, nil)
)
func (b *Binder) DefaultUserObject(name string) (types.Type, error) {
models := b.cfg.Models[name].Model
if len(models) == 0 {
return nil, fmt.Errorf(name + " not found in typemap")
}
if models[0] == "map[string]interface{}" {
return MapType, nil
}
if models[0] == "interface{}" {
return InterfaceType, nil
}
pkgName, typeName := code.PkgAndType(models[0])
if pkgName == "" {
return nil, fmt.Errorf("missing package name for %s", name)
}
obj, err := b.FindObject(pkgName, typeName)
if err != nil {
return nil, err
}
return obj.Type(), nil
}
func (b *Binder) FindObject(pkgName string, typeName string) (types.Object, error) {
if pkgName == "" {
return nil, fmt.Errorf("package cannot be nil")
}
pkg := b.pkgs.LoadWithTypes(pkgName)
if pkg == nil {
err := b.pkgs.Errors()
if err != nil {
return nil, fmt.Errorf("package could not be loaded: %s.%s: %w", pkgName, typeName, err)
}
return nil, fmt.Errorf("required package was not loaded: %s.%s", pkgName, typeName)
}
if b.objectCache == nil {
b.objectCache = make(map[string]map[string]types.Object, b.pkgs.Count())
}
defsIndex, ok := b.objectCache[pkgName]
if !ok {
defsIndex = indexDefs(pkg)
b.objectCache[pkgName] = defsIndex
}
// function based marshalers take precedence
if val, ok := defsIndex["Marshal"+typeName]; ok {
return val, nil
}
if val, ok := defsIndex[typeName]; ok {
return val, nil
}
return nil, fmt.Errorf("%w: %s.%s", ErrTypeNotFound, pkgName, typeName)
}
func indexDefs(pkg *packages.Package) map[string]types.Object {
res := make(map[string]types.Object)
scope := pkg.Types.Scope()
for astNode, def := range pkg.TypesInfo.Defs {
// only look at defs in the top scope
if def == nil {
continue
}
parent := def.Parent()
if parent == nil || parent != scope {
continue
}
if _, ok := res[astNode.Name]; !ok {
// The above check may not be really needed, it is only here to have a consistent behavior with
// previous implementation of FindObject() function which only honored the first inclusion of a def.
// If this is still needed, we can consider something like sync.Map.LoadOrStore() to avoid two lookups.
res[astNode.Name] = def
}
}
return res
}
func (b *Binder) PointerTo(ref *TypeReference) *TypeReference {
newRef := *ref
newRef.GO = types.NewPointer(ref.GO)
b.References = append(b.References, &newRef)
return &newRef
}
// TypeReference is used by args and field types. The Definition can refer to both input and output types.
type TypeReference struct {
Definition *ast.Definition
GQL *ast.Type
GO types.Type // Type of the field being bound. Could be a pointer or a value type of Target.
Target types.Type // The actual type that we know how to bind to. May require pointer juggling when traversing to fields.
CastType types.Type // Before calling marshalling functions cast from/to this base type
Marshaler *types.Func // When using external marshalling functions this will point to the Marshal function
Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler
IsOmittable bool // Is the type wrapped with Omittable
IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety.
PointersInUmarshalInput bool // Inverse values and pointers in return.
}
func (ref *TypeReference) Elem() *TypeReference {
if p, isPtr := ref.GO.(*types.Pointer); isPtr {
newRef := *ref
newRef.GO = p.Elem()
return &newRef
}
if ref.IsSlice() {
newRef := *ref
newRef.GO = ref.GO.(*types.Slice).Elem()
newRef.GQL = ref.GQL.Elem
return &newRef
}
return nil
}
func (ref *TypeReference) IsPtr() bool {
_, isPtr := ref.GO.(*types.Pointer)
return isPtr
}
// fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful)
func (ref *TypeReference) IsPtrToPtr() bool {
if p, isPtr := ref.GO.(*types.Pointer); isPtr {
_, isPtr := p.Elem().(*types.Pointer)
return isPtr
}
return false
}
func (ref *TypeReference) IsNilable() bool {
return IsNilable(ref.GO)
}
func (ref *TypeReference) IsSlice() bool {
_, isSlice := ref.GO.(*types.Slice)
return ref.GQL.Elem != nil && isSlice
}
func (ref *TypeReference) IsPtrToSlice() bool {
if ref.IsPtr() {
_, isPointerToSlice := ref.GO.(*types.Pointer).Elem().(*types.Slice)
return isPointerToSlice
}
return false
}
func (ref *TypeReference) IsPtrToIntf() bool {
if ref.IsPtr() {
_, isPointerToInterface := ref.GO.(*types.Pointer).Elem().(*types.Interface)
return isPointerToInterface
}
return false
}
func (ref *TypeReference) IsNamed() bool {
_, isSlice := ref.GO.(*types.Named)
return isSlice
}
func (ref *TypeReference) IsStruct() bool {
_, isStruct := ref.GO.Underlying().(*types.Struct)
return isStruct
}
func (ref *TypeReference) IsScalar() bool {
return ref.Definition.Kind == ast.Scalar
}
func (ref *TypeReference) UniquenessKey() string {
nullability := "O"
if ref.GQL.NonNull {
nullability = "N"
}
elemNullability := ""
if ref.GQL.Elem != nil && ref.GQL.Elem.NonNull {
// Fix for #896
elemNullability = "ᚄ"
}
return nullability + ref.Definition.Name + "2" + TypeIdentifier(ref.GO) + elemNullability
}
func (ref *TypeReference) MarshalFunc() string {
if ref.Definition == nil {
panic(errors.New("Definition missing for " + ref.GQL.Name()))
}
if ref.Definition.Kind == ast.InputObject {
return ""
}
return "marshal" + ref.UniquenessKey()
}
func (ref *TypeReference) UnmarshalFunc() string {
if ref.Definition == nil {
panic(errors.New("Definition missing for " + ref.GQL.Name()))
}
if !ref.Definition.IsInputType() {
return ""
}
return "unmarshal" + ref.UniquenessKey()
}
func (ref *TypeReference) IsTargetNilable() bool {
return IsNilable(ref.Target)
}
func (b *Binder) PushRef(ret *TypeReference) {
b.References = append(b.References, ret)
}
func isMap(t types.Type) bool {
if t == nil {
return true
}
_, ok := t.(*types.Map)
return ok
}
func isIntf(t types.Type) bool {
if t == nil {
return true
}
_, ok := t.(*types.Interface)
return ok
}
func unwrapOmittable(t types.Type) (types.Type, bool) {
if t == nil {
return t, false
}
named, ok := t.(*types.Named)
if !ok {
return t, false
}
if named.Origin().String() != "github.com/99designs/gqlgen/graphql.Omittable[T any]" {
return t, false
}
return named.TypeArgs().At(0), true
}
func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret *TypeReference, err error) {
if innerType, ok := unwrapOmittable(bindTarget); ok {
if schemaType.NonNull {
return nil, fmt.Errorf("%s is wrapped with Omittable but non-null", schemaType.Name())
}
ref, err := b.TypeReference(schemaType, innerType)
if err != nil {
return nil, err
}
ref.IsOmittable = true
return ref, err
}
if !isValid(bindTarget) {
b.SawInvalid = true
return nil, fmt.Errorf("%s has an invalid type", schemaType.Name())
}
var pkgName, typeName string
def := b.schema.Types[schemaType.Name()]
defer func() {
if err == nil && ret != nil {
b.PushRef(ret)
}
}()
if len(b.cfg.Models[schemaType.Name()].Model) == 0 {
return nil, fmt.Errorf("%s was not found", schemaType.Name())
}
for _, model := range b.cfg.Models[schemaType.Name()].Model {
if model == "map[string]interface{}" {
if !isMap(bindTarget) {
continue
}
return &TypeReference{
Definition: def,
GQL: schemaType,
GO: MapType,
}, nil
}
if model == "interface{}" {
if !isIntf(bindTarget) {
continue
}
return &TypeReference{
Definition: def,
GQL: schemaType,
GO: InterfaceType,
}, nil
}
pkgName, typeName = code.PkgAndType(model)
if pkgName == "" {
return nil, fmt.Errorf("missing package name for %s", schemaType.Name())
}
ref := &TypeReference{
Definition: def,
GQL: schemaType,
}
obj, err := b.FindObject(pkgName, typeName)
if err != nil {
return nil, err
}
if fun, isFunc := obj.(*types.Func); isFunc {
ref.GO = fun.Type().(*types.Signature).Params().At(0).Type()
ref.IsContext = fun.Type().(*types.Signature).Results().At(0).Type().String() == "github.com/99designs/gqlgen/graphql.ContextMarshaler"
ref.Marshaler = fun
ref.Unmarshaler = types.NewFunc(0, fun.Pkg(), "Unmarshal"+typeName, nil)
} else if hasMethod(obj.Type(), "MarshalGQLContext") && hasMethod(obj.Type(), "UnmarshalGQLContext") {
ref.GO = obj.Type()
ref.IsContext = true
ref.IsMarshaler = true
} else if hasMethod(obj.Type(), "MarshalGQL") && hasMethod(obj.Type(), "UnmarshalGQL") {
ref.GO = obj.Type()
ref.IsMarshaler = true
} else if underlying := basicUnderlying(obj.Type()); def.IsLeafType() && underlying != nil && underlying.Kind() == types.String {
// TODO delete before v1. Backwards compatibility case for named types wrapping strings (see #595)
ref.GO = obj.Type()
ref.CastType = underlying
underlyingRef, err := b.TypeReference(&ast.Type{NamedType: "String"}, nil)
if err != nil {
return nil, err
}
ref.Marshaler = underlyingRef.Marshaler
ref.Unmarshaler = underlyingRef.Unmarshaler
} else {
ref.GO = obj.Type()
}
ref.Target = ref.GO
ref.GO = b.CopyModifiersFromAst(schemaType, ref.GO)
if bindTarget != nil {
if err = code.CompatibleTypes(ref.GO, bindTarget); err != nil {
continue
}
ref.GO = bindTarget
}
ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput
return ref, nil
}
return nil, fmt.Errorf("%s is incompatible with %s", schemaType.Name(), bindTarget.String())
}
func isValid(t types.Type) bool {
basic, isBasic := t.(*types.Basic)
if !isBasic {
return true
}
return basic.Kind() != types.Invalid
}
func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
if t.Elem != nil {
child := b.CopyModifiersFromAst(t.Elem, base)
if _, isStruct := child.Underlying().(*types.Struct); isStruct && !b.cfg.OmitSliceElementPointers {
child = types.NewPointer(child)
}
return types.NewSlice(child)
}
var isInterface bool
if named, ok := base.(*types.Named); ok {
_, isInterface = named.Underlying().(*types.Interface)
}
if !isInterface && !IsNilable(base) && !t.NonNull {
return types.NewPointer(base)
}
return base
}
func IsNilable(t types.Type) bool {
if namedType, isNamed := t.(*types.Named); isNamed {
return IsNilable(namedType.Underlying())
}
_, isPtr := t.(*types.Pointer)
_, isMap := t.(*types.Map)
_, isInterface := t.(*types.Interface)
_, isSlice := t.(*types.Slice)
_, isChan := t.(*types.Chan)
return isPtr || isMap || isInterface || isSlice || isChan
}
func hasMethod(it types.Type, name string) bool {
if ptr, isPtr := it.(*types.Pointer); isPtr {
it = ptr.Elem()
}
namedType, ok := it.(*types.Named)
if !ok {
return false
}
for i := 0; i < namedType.NumMethods(); i++ {
if namedType.Method(i).Name() == name {
return true
}
}
return false
}
func basicUnderlying(it types.Type) *types.Basic {
if ptr, isPtr := it.(*types.Pointer); isPtr {
it = ptr.Elem()
}
namedType, ok := it.(*types.Named)
if !ok {
return nil
}
if basic, ok := namedType.Underlying().(*types.Basic); ok {
return basic
}
return nil
}
var pkgReplacer = strings.NewReplacer(
"/", "ᚋ",
".", "ᚗ",
"-", "ᚑ",
"~", "א",
)
func TypeIdentifier(t types.Type) string {
res := ""
for {
switch it := t.(type) {
case *types.Pointer:
t.Underlying()
res += "ᚖ"
t = it.Elem()
case *types.Slice:
res += "ᚕ"
t = it.Elem()
case *types.Named:
res += pkgReplacer.Replace(it.Obj().Pkg().Path())
res += "ᚐ"
res += it.Obj().Name()
return res
case *types.Basic:
res += it.Name()
return res
case *types.Map:
res += "map"
return res
case *types.Interface:
res += "interface"
return res
default:
panic(fmt.Errorf("unexpected type %T", it))
}
}
}

View File

@@ -0,0 +1,704 @@
package config
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"github.com/99designs/gqlgen/internal/code"
"github.com/vektah/gqlparser/v2"
"github.com/vektah/gqlparser/v2/ast"
"gopkg.in/yaml.v3"
)
type Config struct {
SchemaFilename StringList `yaml:"schema,omitempty"`
Exec ExecConfig `yaml:"exec"`
Model PackageConfig `yaml:"model,omitempty"`
Federation PackageConfig `yaml:"federation,omitempty"`
Resolver ResolverConfig `yaml:"resolver,omitempty"`
AutoBind []string `yaml:"autobind"`
Models TypeMap `yaml:"models,omitempty"`
StructTag string `yaml:"struct_tag,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
GoInitialisms GoInitialismsConfig `yaml:"go_initialisms,omitempty"`
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
OmitGetters bool `yaml:"omit_getters,omitempty"`
OmitInterfaceChecks bool `yaml:"omit_interface_checks,omitempty"`
OmitComplexity bool `yaml:"omit_complexity,omitempty"`
OmitGQLGenFileNotice bool `yaml:"omit_gqlgen_file_notice,omitempty"`
OmitGQLGenVersionInFileNotice bool `yaml:"omit_gqlgen_version_in_file_notice,omitempty"`
StructFieldsAlwaysPointers bool `yaml:"struct_fields_always_pointers,omitempty"`
ReturnPointersInUmarshalInput bool `yaml:"return_pointers_in_unmarshalinput,omitempty"`
ResolversAlwaysReturnPointers bool `yaml:"resolvers_always_return_pointers,omitempty"`
NullableInputOmittable bool `yaml:"nullable_input_omittable,omitempty"`
EnableModelJsonOmitemptyTag *bool `yaml:"enable_model_json_omitempty_tag,omitempty"`
SkipValidation bool `yaml:"skip_validation,omitempty"`
SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"`
Sources []*ast.Source `yaml:"-"`
Packages *code.Packages `yaml:"-"`
Schema *ast.Schema `yaml:"-"`
// Deprecated: use Federation instead. Will be removed next release
Federated bool `yaml:"federated,omitempty"`
}
var cfgFilenames = []string{".gqlgen.yml", "gqlgen.yml", "gqlgen.yaml"}
// DefaultConfig creates a copy of the default config
func DefaultConfig() *Config {
return &Config{
SchemaFilename: StringList{"schema.graphql"},
Model: PackageConfig{Filename: "models_gen.go"},
Exec: ExecConfig{Filename: "generated.go"},
Directives: map[string]DirectiveConfig{},
Models: TypeMap{},
StructFieldsAlwaysPointers: true,
ReturnPointersInUmarshalInput: false,
ResolversAlwaysReturnPointers: true,
NullableInputOmittable: false,
}
}
// LoadDefaultConfig loads the default config so that it is ready to be used
func LoadDefaultConfig() (*Config, error) {
config := DefaultConfig()
for _, filename := range config.SchemaFilename {
filename = filepath.ToSlash(filename)
var err error
var schemaRaw []byte
schemaRaw, err = os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("unable to open schema: %w", err)
}
config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
}
return config, nil
}
// LoadConfigFromDefaultLocations looks for a config file in the current directory, and all parent directories
// walking up the tree. The closest config file will be returned.
func LoadConfigFromDefaultLocations() (*Config, error) {
cfgFile, err := findCfg()
if err != nil {
return nil, err
}
err = os.Chdir(filepath.Dir(cfgFile))
if err != nil {
return nil, fmt.Errorf("unable to enter config dir: %w", err)
}
return LoadConfig(cfgFile)
}
var path2regex = strings.NewReplacer(
`.`, `\.`,
`*`, `.+`,
`\`, `[\\/]`,
`/`, `[\\/]`,
)
// LoadConfig reads the gqlgen.yml config file
func LoadConfig(filename string) (*Config, error) {
b, err := os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("unable to read config: %w", err)
}
return ReadConfig(bytes.NewReader(b))
}
func ReadConfig(cfgFile io.Reader) (*Config, error) {
config := DefaultConfig()
dec := yaml.NewDecoder(cfgFile)
dec.KnownFields(true)
if err := dec.Decode(config); err != nil {
return nil, fmt.Errorf("unable to parse config: %w", err)
}
if err := CompleteConfig(config); err != nil {
return nil, err
}
return config, nil
}
// CompleteConfig fills in the schema and other values to a config loaded from
// YAML.
func CompleteConfig(config *Config) error {
defaultDirectives := map[string]DirectiveConfig{
"skip": {SkipRuntime: true},
"include": {SkipRuntime: true},
"deprecated": {SkipRuntime: true},
"specifiedBy": {SkipRuntime: true},
}
for key, value := range defaultDirectives {
if _, defined := config.Directives[key]; !defined {
config.Directives[key] = value
}
}
preGlobbing := config.SchemaFilename
config.SchemaFilename = StringList{}
for _, f := range preGlobbing {
var matches []string
// for ** we want to override default globbing patterns and walk all
// subdirectories to match schema files.
if strings.Contains(f, "**") {
pathParts := strings.SplitN(f, "**", 2)
rest := strings.TrimPrefix(strings.TrimPrefix(pathParts[1], `\`), `/`)
// turn the rest of the glob into a regex, anchored only at the end because ** allows
// for any number of dirs in between and walk will let us match against the full path name
globRe := regexp.MustCompile(path2regex.Replace(rest) + `$`)
if err := filepath.Walk(pathParts[0], func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if globRe.MatchString(strings.TrimPrefix(path, pathParts[0])) {
matches = append(matches, path)
}
return nil
}); err != nil {
return fmt.Errorf("failed to walk schema at root %s: %w", pathParts[0], err)
}
} else {
var err error
matches, err = filepath.Glob(f)
if err != nil {
return fmt.Errorf("failed to glob schema filename %s: %w", f, err)
}
}
for _, m := range matches {
if config.SchemaFilename.Has(m) {
continue
}
config.SchemaFilename = append(config.SchemaFilename, m)
}
}
for _, filename := range config.SchemaFilename {
filename = filepath.ToSlash(filename)
var err error
var schemaRaw []byte
schemaRaw, err = os.ReadFile(filename)
if err != nil {
return fmt.Errorf("unable to open schema: %w", err)
}
config.Sources = append(config.Sources, &ast.Source{Name: filename, Input: string(schemaRaw)})
}
config.GoInitialisms.setInitialisms()
return nil
}
func (c *Config) Init() error {
if c.Packages == nil {
c.Packages = &code.Packages{}
}
if c.Schema == nil {
if err := c.LoadSchema(); err != nil {
return err
}
}
err := c.injectTypesFromSchema()
if err != nil {
return err
}
err = c.autobind()
if err != nil {
return err
}
c.injectBuiltins()
// prefetch all packages in one big packages.Load call
c.Packages.LoadAll(c.packageList()...)
// check everything is valid on the way out
err = c.check()
if err != nil {
return err
}
return nil
}
func (c *Config) packageList() []string {
pkgs := []string{
"github.com/99designs/gqlgen/graphql",
"github.com/99designs/gqlgen/graphql/introspection",
}
pkgs = append(pkgs, c.Models.ReferencedPackages()...)
pkgs = append(pkgs, c.AutoBind...)
return pkgs
}
func (c *Config) ReloadAllPackages() {
c.Packages.ReloadAll(c.packageList()...)
}
func (c *Config) injectTypesFromSchema() error {
c.Directives["goModel"] = DirectiveConfig{
SkipRuntime: true,
}
c.Directives["goField"] = DirectiveConfig{
SkipRuntime: true,
}
c.Directives["goTag"] = DirectiveConfig{
SkipRuntime: true,
}
for _, schemaType := range c.Schema.Types {
if schemaType == c.Schema.Query || schemaType == c.Schema.Mutation || schemaType == c.Schema.Subscription {
continue
}
if bd := schemaType.Directives.ForName("goModel"); bd != nil {
if ma := bd.Arguments.ForName("model"); ma != nil {
if mv, err := ma.Value.Value(nil); err == nil {
c.Models.Add(schemaType.Name, mv.(string))
}
}
if ma := bd.Arguments.ForName("models"); ma != nil {
if mvs, err := ma.Value.Value(nil); err == nil {
for _, mv := range mvs.([]interface{}) {
c.Models.Add(schemaType.Name, mv.(string))
}
}
}
}
if schemaType.Kind == ast.Object || schemaType.Kind == ast.InputObject {
for _, field := range schemaType.Fields {
if fd := field.Directives.ForName("goField"); fd != nil {
forceResolver := c.Models[schemaType.Name].Fields[field.Name].Resolver
fieldName := c.Models[schemaType.Name].Fields[field.Name].FieldName
if ra := fd.Arguments.ForName("forceResolver"); ra != nil {
if fr, err := ra.Value.Value(nil); err == nil {
forceResolver = fr.(bool)
}
}
if na := fd.Arguments.ForName("name"); na != nil {
if fr, err := na.Value.Value(nil); err == nil {
fieldName = fr.(string)
}
}
if c.Models[schemaType.Name].Fields == nil {
c.Models[schemaType.Name] = TypeMapEntry{
Model: c.Models[schemaType.Name].Model,
ExtraFields: c.Models[schemaType.Name].ExtraFields,
Fields: map[string]TypeMapField{},
}
}
c.Models[schemaType.Name].Fields[field.Name] = TypeMapField{
FieldName: fieldName,
Resolver: forceResolver,
}
}
}
}
}
return nil
}
type TypeMapEntry struct {
Model StringList `yaml:"model"`
Fields map[string]TypeMapField `yaml:"fields,omitempty"`
// Key is the Go name of the field.
ExtraFields map[string]ModelExtraField `yaml:"extraFields,omitempty"`
}
type TypeMapField struct {
Resolver bool `yaml:"resolver"`
FieldName string `yaml:"fieldName"`
GeneratedMethod string `yaml:"-"`
}
type ModelExtraField struct {
// Type is the Go type of the field.
//
// It supports the builtin basic types (like string or int64), named types
// (qualified by the full package path), pointers to those types (prefixed
// with `*`), and slices of those types (prefixed with `[]`).
//
// For example, the following are valid types:
// string
// *github.com/author/package.Type
// []string
// []*github.com/author/package.Type
//
// Note that the type will be referenced from the generated/graphql, which
// means the package it lives in must not reference the generated/graphql
// package to avoid circular imports.
// restrictions.
Type string `yaml:"type"`
// OverrideTags is an optional override of the Go field tag.
OverrideTags string `yaml:"overrideTags"`
// Description is an optional the Go field doc-comment.
Description string `yaml:"description"`
}
type StringList []string
func (a *StringList) UnmarshalYAML(unmarshal func(interface{}) error) error {
var single string
err := unmarshal(&single)
if err == nil {
*a = []string{single}
return nil
}
var multi []string
err = unmarshal(&multi)
if err != nil {
return err
}
*a = multi
return nil
}
func (a StringList) Has(file string) bool {
for _, existing := range a {
if existing == file {
return true
}
}
return false
}
func (c *Config) check() error {
if c.Models == nil {
c.Models = TypeMap{}
}
type FilenamePackage struct {
Filename string
Package string
Declaree string
}
fileList := map[string][]FilenamePackage{}
if err := c.Models.Check(); err != nil {
return fmt.Errorf("config.models: %w", err)
}
if err := c.Exec.Check(); err != nil {
return fmt.Errorf("config.exec: %w", err)
}
fileList[c.Exec.ImportPath()] = append(fileList[c.Exec.ImportPath()], FilenamePackage{
Filename: c.Exec.Filename,
Package: c.Exec.Package,
Declaree: "exec",
})
if c.Model.IsDefined() {
if err := c.Model.Check(); err != nil {
return fmt.Errorf("config.model: %w", err)
}
fileList[c.Model.ImportPath()] = append(fileList[c.Model.ImportPath()], FilenamePackage{
Filename: c.Model.Filename,
Package: c.Model.Package,
Declaree: "model",
})
}
if c.Resolver.IsDefined() {
if err := c.Resolver.Check(); err != nil {
return fmt.Errorf("config.resolver: %w", err)
}
fileList[c.Resolver.ImportPath()] = append(fileList[c.Resolver.ImportPath()], FilenamePackage{
Filename: c.Resolver.Filename,
Package: c.Resolver.Package,
Declaree: "resolver",
})
}
if c.Federation.IsDefined() {
if err := c.Federation.Check(); err != nil {
return fmt.Errorf("config.federation: %w", err)
}
fileList[c.Federation.ImportPath()] = append(fileList[c.Federation.ImportPath()], FilenamePackage{
Filename: c.Federation.Filename,
Package: c.Federation.Package,
Declaree: "federation",
})
if c.Federation.ImportPath() != c.Exec.ImportPath() {
return fmt.Errorf("federation and exec must be in the same package")
}
}
if c.Federated {
return fmt.Errorf("federated has been removed, instead use\nfederation:\n filename: path/to/federated.go")
}
for importPath, pkg := range fileList {
for _, file1 := range pkg {
for _, file2 := range pkg {
if file1.Package != file2.Package {
return fmt.Errorf("%s and %s define the same import path (%s) with different package names (%s vs %s)",
file1.Declaree,
file2.Declaree,
importPath,
file1.Package,
file2.Package,
)
}
}
}
}
return nil
}
type TypeMap map[string]TypeMapEntry
func (tm TypeMap) Exists(typeName string) bool {
_, ok := tm[typeName]
return ok
}
func (tm TypeMap) UserDefined(typeName string) bool {
m, ok := tm[typeName]
return ok && len(m.Model) > 0
}
func (tm TypeMap) Check() error {
for typeName, entry := range tm {
for _, model := range entry.Model {
if strings.LastIndex(model, ".") < strings.LastIndex(model, "/") {
return fmt.Errorf("model %s: invalid type specifier \"%s\" - you need to specify a struct to map to", typeName, entry.Model)
}
}
}
return nil
}
func (tm TypeMap) ReferencedPackages() []string {
var pkgs []string
for _, typ := range tm {
for _, model := range typ.Model {
if model == "map[string]interface{}" || model == "interface{}" {
continue
}
pkg, _ := code.PkgAndType(model)
if pkg == "" || inStrSlice(pkgs, pkg) {
continue
}
pkgs = append(pkgs, code.QualifyPackagePath(pkg))
}
}
sort.Slice(pkgs, func(i, j int) bool {
return pkgs[i] > pkgs[j]
})
return pkgs
}
func (tm TypeMap) Add(name string, goType string) {
modelCfg := tm[name]
modelCfg.Model = append(modelCfg.Model, goType)
tm[name] = modelCfg
}
type DirectiveConfig struct {
SkipRuntime bool `yaml:"skip_runtime"`
}
func inStrSlice(haystack []string, needle string) bool {
for _, v := range haystack {
if needle == v {
return true
}
}
return false
}
// findCfg searches for the config file in this directory and all parents up the tree
// looking for the closest match
func findCfg() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("unable to get working dir to findCfg: %w", err)
}
cfg := findCfgInDir(dir)
for cfg == "" && dir != filepath.Dir(dir) {
dir = filepath.Dir(dir)
cfg = findCfgInDir(dir)
}
if cfg == "" {
return "", os.ErrNotExist
}
return cfg, nil
}
func findCfgInDir(dir string) string {
for _, cfgName := range cfgFilenames {
path := filepath.Join(dir, cfgName)
if _, err := os.Stat(path); err == nil {
return path
}
}
return ""
}
func (c *Config) autobind() error {
if len(c.AutoBind) == 0 {
return nil
}
ps := c.Packages.LoadAll(c.AutoBind...)
for _, t := range c.Schema.Types {
if c.Models.UserDefined(t.Name) {
continue
}
for i, p := range ps {
if p == nil || p.Module == nil {
return fmt.Errorf("unable to load %s - make sure you're using an import path to a package that exists", c.AutoBind[i])
}
if t := p.Types.Scope().Lookup(t.Name); t != nil {
c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
break
}
}
}
for i, t := range c.Models {
for j, m := range t.Model {
pkg, typename := code.PkgAndType(m)
// skip anything that looks like an import path
if strings.Contains(pkg, "/") {
continue
}
for _, p := range ps {
if p.Name != pkg {
continue
}
if t := p.Types.Scope().Lookup(typename); t != nil {
c.Models[i].Model[j] = t.Pkg().Path() + "." + t.Name()
break
}
}
}
}
return nil
}
func (c *Config) injectBuiltins() {
builtins := TypeMap{
"__Directive": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Directive"}},
"__DirectiveLocation": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
"__Type": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Type"}},
"__TypeKind": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
"__Field": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Field"}},
"__EnumValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.EnumValue"}},
"__InputValue": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.InputValue"}},
"__Schema": {Model: StringList{"github.com/99designs/gqlgen/graphql/introspection.Schema"}},
"Float": {Model: StringList{"github.com/99designs/gqlgen/graphql.FloatContext"}},
"String": {Model: StringList{"github.com/99designs/gqlgen/graphql.String"}},
"Boolean": {Model: StringList{"github.com/99designs/gqlgen/graphql.Boolean"}},
"Int": {Model: StringList{
"github.com/99designs/gqlgen/graphql.Int",
"github.com/99designs/gqlgen/graphql.Int32",
"github.com/99designs/gqlgen/graphql.Int64",
}},
"ID": {
Model: StringList{
"github.com/99designs/gqlgen/graphql.ID",
"github.com/99designs/gqlgen/graphql.IntID",
},
},
}
for typeName, entry := range builtins {
if !c.Models.Exists(typeName) {
c.Models[typeName] = entry
}
}
// These are additional types that are injected if defined in the schema as scalars.
extraBuiltins := TypeMap{
"Time": {Model: StringList{"github.com/99designs/gqlgen/graphql.Time"}},
"Map": {Model: StringList{"github.com/99designs/gqlgen/graphql.Map"}},
"Upload": {Model: StringList{"github.com/99designs/gqlgen/graphql.Upload"}},
"Any": {Model: StringList{"github.com/99designs/gqlgen/graphql.Any"}},
}
for typeName, entry := range extraBuiltins {
if t, ok := c.Schema.Types[typeName]; !c.Models.Exists(typeName) && ok && t.Kind == ast.Scalar {
c.Models[typeName] = entry
}
}
}
func (c *Config) LoadSchema() error {
if c.Packages != nil {
c.Packages = &code.Packages{}
}
if err := c.check(); err != nil {
return err
}
schema, err := gqlparser.LoadSchema(c.Sources...)
if err != nil {
return err
}
if schema.Query == nil {
schema.Query = &ast.Definition{
Kind: ast.Object,
Name: "Query",
}
schema.Types["Query"] = schema.Query
}
c.Schema = schema
return nil
}
func abs(path string) string {
absPath, err := filepath.Abs(path)
if err != nil {
panic(err)
}
return filepath.ToSlash(absPath)
}

View File

@@ -0,0 +1,97 @@
package config
import (
"fmt"
"go/types"
"path/filepath"
"strings"
"github.com/99designs/gqlgen/internal/code"
)
type ExecConfig struct {
Package string `yaml:"package,omitempty"`
Layout ExecLayout `yaml:"layout,omitempty"` // Default: single-file
// Only for single-file layout:
Filename string `yaml:"filename,omitempty"`
// Only for follow-schema layout:
FilenameTemplate string `yaml:"filename_template,omitempty"` // String template with {name} as placeholder for base name.
DirName string `yaml:"dir"`
}
type ExecLayout string
var (
// Write all generated code to a single file.
ExecLayoutSingleFile ExecLayout = "single-file"
// Write generated code to a directory, generating one Go source file for each GraphQL schema file.
ExecLayoutFollowSchema ExecLayout = "follow-schema"
)
func (r *ExecConfig) Check() error {
if r.Layout == "" {
r.Layout = ExecLayoutSingleFile
}
switch r.Layout {
case ExecLayoutSingleFile:
if r.Filename == "" {
return fmt.Errorf("filename must be specified when using single-file layout")
}
if !strings.HasSuffix(r.Filename, ".go") {
return fmt.Errorf("filename should be path to a go source file when using single-file layout")
}
r.Filename = abs(r.Filename)
case ExecLayoutFollowSchema:
if r.DirName == "" {
return fmt.Errorf("dir must be specified when using follow-schema layout")
}
r.DirName = abs(r.DirName)
default:
return fmt.Errorf("invalid layout %s", r.Layout)
}
if strings.ContainsAny(r.Package, "./\\") {
return fmt.Errorf("package should be the output package name only, do not include the output filename")
}
if r.Package == "" && r.Dir() != "" {
r.Package = code.NameForDir(r.Dir())
}
return nil
}
func (r *ExecConfig) ImportPath() string {
if r.Dir() == "" {
return ""
}
return code.ImportPathForDir(r.Dir())
}
func (r *ExecConfig) Dir() string {
switch r.Layout {
case ExecLayoutSingleFile:
if r.Filename == "" {
return ""
}
return filepath.Dir(r.Filename)
case ExecLayoutFollowSchema:
return abs(r.DirName)
default:
panic("invalid layout " + r.Layout)
}
}
func (r *ExecConfig) Pkg() *types.Package {
if r.Dir() == "" {
return nil
}
return types.NewPackage(r.ImportPath(), r.Package)
}
func (r *ExecConfig) IsDefined() bool {
return r.Filename != "" || r.DirName != ""
}

View File

@@ -0,0 +1,94 @@
package config
import "strings"
// commonInitialisms is a set of common initialisms.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = map[string]bool{
"ACL": true,
"API": true,
"ASCII": true,
"CPU": true,
"CSS": true,
"CSV": true,
"DNS": true,
"EOF": true,
"GUID": true,
"HTML": true,
"HTTP": true,
"HTTPS": true,
"ICMP": true,
"ID": true,
"IP": true,
"JSON": true,
"KVK": true,
"LHS": true,
"PDF": true,
"PGP": true,
"QPS": true,
"QR": true,
"RAM": true,
"RHS": true,
"RPC": true,
"SLA": true,
"SMTP": true,
"SQL": true,
"SSH": true,
"SVG": true,
"TCP": true,
"TLS": true,
"TTL": true,
"UDP": true,
"UI": true,
"UID": true,
"URI": true,
"URL": true,
"UTF8": true,
"UUID": true,
"VM": true,
"XML": true,
"XMPP": true,
"XSRF": true,
"XSS": true,
}
// GetInitialisms returns the initialisms to capitalize in Go names. If unchanged, default initialisms will be returned
var GetInitialisms = func() map[string]bool {
return commonInitialisms
}
// GoInitialismsConfig allows to modify the default behavior of naming Go methods, types and properties
type GoInitialismsConfig struct {
// If true, the Initialisms won't get appended to the default ones but replace them
ReplaceDefaults bool `yaml:"replace_defaults"`
// Custom initialisms to be added or to replace the default ones
Initialisms []string `yaml:"initialisms"`
}
// setInitialisms adjustes GetInitialisms based on its settings.
func (i GoInitialismsConfig) setInitialisms() {
toUse := i.determineGoInitialisms()
GetInitialisms = func() map[string]bool {
return toUse
}
}
// determineGoInitialisms returns the Go initialims to be used, based on its settings.
func (i GoInitialismsConfig) determineGoInitialisms() (initialismsToUse map[string]bool) {
if i.ReplaceDefaults {
initialismsToUse = make(map[string]bool, len(i.Initialisms))
for _, initialism := range i.Initialisms {
initialismsToUse[strings.ToUpper(initialism)] = true
}
} else {
initialismsToUse = make(map[string]bool, len(commonInitialisms)+len(i.Initialisms))
for initialism, value := range commonInitialisms {
initialismsToUse[strings.ToUpper(initialism)] = value
}
for _, initialism := range i.Initialisms {
initialismsToUse[strings.ToUpper(initialism)] = true
}
}
return initialismsToUse
}

View File

@@ -0,0 +1,63 @@
package config
import (
"fmt"
"go/types"
"path/filepath"
"strings"
"github.com/99designs/gqlgen/internal/code"
)
type PackageConfig struct {
Filename string `yaml:"filename,omitempty"`
Package string `yaml:"package,omitempty"`
Version int `yaml:"version,omitempty"`
}
func (c *PackageConfig) ImportPath() string {
if !c.IsDefined() {
return ""
}
return code.ImportPathForDir(c.Dir())
}
func (c *PackageConfig) Dir() string {
if !c.IsDefined() {
return ""
}
return filepath.Dir(c.Filename)
}
func (c *PackageConfig) Pkg() *types.Package {
if !c.IsDefined() {
return nil
}
return types.NewPackage(c.ImportPath(), c.Package)
}
func (c *PackageConfig) IsDefined() bool {
return c.Filename != ""
}
func (c *PackageConfig) Check() error {
if strings.ContainsAny(c.Package, "./\\") {
return fmt.Errorf("package should be the output package name only, do not include the output filename")
}
if c.Filename == "" {
return fmt.Errorf("filename must be specified")
}
if !strings.HasSuffix(c.Filename, ".go") {
return fmt.Errorf("filename should be path to a go source file")
}
c.Filename = abs(c.Filename)
// If Package is not set, first attempt to load the package at the output dir. If that fails
// fallback to just the base dir name of the output filename.
if c.Package == "" {
c.Package = code.NameForDir(c.Dir())
}
return nil
}

View File

@@ -0,0 +1,101 @@
package config
import (
"fmt"
"go/types"
"path/filepath"
"strings"
"github.com/99designs/gqlgen/internal/code"
)
type ResolverConfig struct {
Filename string `yaml:"filename,omitempty"`
FilenameTemplate string `yaml:"filename_template,omitempty"`
Package string `yaml:"package,omitempty"`
Type string `yaml:"type,omitempty"`
Layout ResolverLayout `yaml:"layout,omitempty"`
DirName string `yaml:"dir"`
OmitTemplateComment bool `yaml:"omit_template_comment,omitempty"`
}
type ResolverLayout string
var (
LayoutSingleFile ResolverLayout = "single-file"
LayoutFollowSchema ResolverLayout = "follow-schema"
)
func (r *ResolverConfig) Check() error {
if r.Layout == "" {
r.Layout = LayoutSingleFile
}
if r.Type == "" {
r.Type = "Resolver"
}
switch r.Layout {
case LayoutSingleFile:
if r.Filename == "" {
return fmt.Errorf("filename must be specified with layout=%s", r.Layout)
}
if !strings.HasSuffix(r.Filename, ".go") {
return fmt.Errorf("filename should be path to a go source file with layout=%s", r.Layout)
}
r.Filename = abs(r.Filename)
case LayoutFollowSchema:
if r.DirName == "" {
return fmt.Errorf("dirname must be specified with layout=%s", r.Layout)
}
r.DirName = abs(r.DirName)
if r.Filename == "" {
r.Filename = filepath.Join(r.DirName, "resolver.go")
} else {
r.Filename = abs(r.Filename)
}
default:
return fmt.Errorf("invalid layout %s. must be %s or %s", r.Layout, LayoutSingleFile, LayoutFollowSchema)
}
if strings.ContainsAny(r.Package, "./\\") {
return fmt.Errorf("package should be the output package name only, do not include the output filename")
}
if r.Package == "" && r.Dir() != "" {
r.Package = code.NameForDir(r.Dir())
}
return nil
}
func (r *ResolverConfig) ImportPath() string {
if r.Dir() == "" {
return ""
}
return code.ImportPathForDir(r.Dir())
}
func (r *ResolverConfig) Dir() string {
switch r.Layout {
case LayoutSingleFile:
if r.Filename == "" {
return ""
}
return filepath.Dir(r.Filename)
case LayoutFollowSchema:
return r.DirName
default:
panic("invalid layout " + r.Layout)
}
}
func (r *ResolverConfig) Pkg() *types.Package {
if r.Dir() == "" {
return nil
}
return types.NewPackage(r.ImportPath(), r.Package)
}
func (r *ResolverConfig) IsDefined() bool {
return r.Filename != "" || r.DirName != ""
}

View File

@@ -0,0 +1,235 @@
package codegen
import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/vektah/gqlparser/v2/ast"
"github.com/99designs/gqlgen/codegen/config"
)
// Data is a unified model of the code to be generated. Plugins may modify this structure to do things like implement
// resolvers or directives automatically (eg grpc, validation)
type Data struct {
Config *config.Config
Schema *ast.Schema
// If a schema is broken up into multiple Data instance, each representing part of the schema,
// AllDirectives should contain the directives for the entire schema. Directives() can
// then be used to get the directives that were defined in this Data instance's sources.
// If a single Data instance is used for the entire schema, AllDirectives and Directives()
// will be identical.
// AllDirectives should rarely be used directly.
AllDirectives DirectiveList
Objects Objects
Inputs Objects
Interfaces map[string]*Interface
ReferencedTypes map[string]*config.TypeReference
ComplexityRoots map[string]*Object
QueryRoot *Object
MutationRoot *Object
SubscriptionRoot *Object
AugmentedSources []AugmentedSource
Plugins []interface{}
}
func (d *Data) HasEmbeddableSources() bool {
hasEmbeddableSources := false
for _, s := range d.AugmentedSources {
if s.Embeddable {
hasEmbeddableSources = true
}
}
return hasEmbeddableSources
}
// AugmentedSource contains extra information about graphql schema files which is not known directly from the Config.Sources data
type AugmentedSource struct {
// path relative to Config.Exec.Filename
RelativePath string
Embeddable bool
BuiltIn bool
Source string
}
type builder struct {
Config *config.Config
Schema *ast.Schema
Binder *config.Binder
Directives map[string]*Directive
}
// Get only the directives which are defined in the config's sources.
func (d *Data) Directives() DirectiveList {
res := DirectiveList{}
for k, directive := range d.AllDirectives {
for _, s := range d.Config.Sources {
if directive.Position.Src.Name == s.Name {
res[k] = directive
break
}
}
}
return res
}
func BuildData(cfg *config.Config, plugins ...interface{}) (*Data, error) {
// We reload all packages to allow packages to be compared correctly.
cfg.ReloadAllPackages()
b := builder{
Config: cfg,
Schema: cfg.Schema,
}
b.Binder = b.Config.NewBinder()
var err error
b.Directives, err = b.buildDirectives()
if err != nil {
return nil, err
}
dataDirectives := make(map[string]*Directive)
for name, d := range b.Directives {
if !d.Builtin {
dataDirectives[name] = d
}
}
s := Data{
Config: cfg,
AllDirectives: dataDirectives,
Schema: b.Schema,
Interfaces: map[string]*Interface{},
Plugins: plugins,
}
for _, schemaType := range b.Schema.Types {
switch schemaType.Kind {
case ast.Object:
obj, err := b.buildObject(schemaType)
if err != nil {
return nil, fmt.Errorf("unable to build object definition: %w", err)
}
s.Objects = append(s.Objects, obj)
case ast.InputObject:
input, err := b.buildObject(schemaType)
if err != nil {
return nil, fmt.Errorf("unable to build input definition: %w", err)
}
s.Inputs = append(s.Inputs, input)
case ast.Union, ast.Interface:
s.Interfaces[schemaType.Name], err = b.buildInterface(schemaType)
if err != nil {
return nil, fmt.Errorf("unable to bind to interface: %w", err)
}
}
}
if s.Schema.Query != nil {
s.QueryRoot = s.Objects.ByName(s.Schema.Query.Name)
} else {
return nil, fmt.Errorf("query entry point missing")
}
if s.Schema.Mutation != nil {
s.MutationRoot = s.Objects.ByName(s.Schema.Mutation.Name)
}
if s.Schema.Subscription != nil {
s.SubscriptionRoot = s.Objects.ByName(s.Schema.Subscription.Name)
}
if err := b.injectIntrospectionRoots(&s); err != nil {
return nil, err
}
s.ReferencedTypes = b.buildTypes()
sort.Slice(s.Objects, func(i, j int) bool {
return s.Objects[i].Definition.Name < s.Objects[j].Definition.Name
})
sort.Slice(s.Inputs, func(i, j int) bool {
return s.Inputs[i].Definition.Name < s.Inputs[j].Definition.Name
})
if b.Binder.SawInvalid {
// if we have a syntax error, show it
err := cfg.Packages.Errors()
if len(err) > 0 {
return nil, err
}
// otherwise show a generic error message
return nil, fmt.Errorf("invalid types were encountered while traversing the go source code, this probably means the invalid code generated isnt correct. add try adding -v to debug")
}
aSources := []AugmentedSource{}
for _, s := range cfg.Sources {
wd, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("failed to get working directory: %w", err)
}
outputDir := cfg.Exec.Dir()
sourcePath := filepath.Join(wd, s.Name)
relative, err := filepath.Rel(outputDir, sourcePath)
if err != nil {
return nil, fmt.Errorf("failed to compute path of %s relative to %s: %w", sourcePath, outputDir, err)
}
relative = filepath.ToSlash(relative)
embeddable := true
if strings.HasPrefix(relative, "..") || s.BuiltIn {
embeddable = false
}
aSources = append(aSources, AugmentedSource{
RelativePath: relative,
Embeddable: embeddable,
BuiltIn: s.BuiltIn,
Source: s.Input,
})
}
s.AugmentedSources = aSources
return &s, nil
}
func (b *builder) injectIntrospectionRoots(s *Data) error {
obj := s.Objects.ByName(b.Schema.Query.Name)
if obj == nil {
return fmt.Errorf("root query type must be defined")
}
__type, err := b.buildField(obj, &ast.FieldDefinition{
Name: "__type",
Type: ast.NamedType("__Type", nil),
Arguments: []*ast.ArgumentDefinition{
{
Name: "name",
Type: ast.NonNullNamedType("String", nil),
},
},
})
if err != nil {
return err
}
__schema, err := b.buildField(obj, &ast.FieldDefinition{
Name: "__schema",
Type: ast.NamedType("__Schema", nil),
})
if err != nil {
return err
}
obj.Fields = append(obj.Fields, __type, __schema)
return nil
}

View File

@@ -0,0 +1,174 @@
package codegen
import (
"fmt"
"strconv"
"strings"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/v2/ast"
)
type DirectiveList map[string]*Directive
// LocationDirectives filter directives by location
func (dl DirectiveList) LocationDirectives(location string) DirectiveList {
return locationDirectives(dl, ast.DirectiveLocation(location))
}
type Directive struct {
*ast.DirectiveDefinition
Name string
Args []*FieldArgument
Builtin bool
}
// IsLocation check location directive
func (d *Directive) IsLocation(location ...ast.DirectiveLocation) bool {
for _, l := range d.Locations {
for _, a := range location {
if l == a {
return true
}
}
}
return false
}
func locationDirectives(directives DirectiveList, location ...ast.DirectiveLocation) map[string]*Directive {
mDirectives := make(map[string]*Directive)
for name, d := range directives {
if d.IsLocation(location...) {
mDirectives[name] = d
}
}
return mDirectives
}
func (b *builder) buildDirectives() (map[string]*Directive, error) {
directives := make(map[string]*Directive, len(b.Schema.Directives))
for name, dir := range b.Schema.Directives {
if _, ok := directives[name]; ok {
return nil, fmt.Errorf("directive with name %s already exists", name)
}
var args []*FieldArgument
for _, arg := range dir.Arguments {
tr, err := b.Binder.TypeReference(arg.Type, nil)
if err != nil {
return nil, err
}
newArg := &FieldArgument{
ArgumentDefinition: arg,
TypeReference: tr,
VarName: templates.ToGoPrivate(arg.Name),
}
if arg.DefaultValue != nil {
var err error
newArg.Default, err = arg.DefaultValue.Value(nil)
if err != nil {
return nil, fmt.Errorf("default value for directive argument %s(%s) is not valid: %w", dir.Name, arg.Name, err)
}
}
args = append(args, newArg)
}
directives[name] = &Directive{
DirectiveDefinition: dir,
Name: name,
Args: args,
Builtin: b.Config.Directives[name].SkipRuntime,
}
}
return directives, nil
}
func (b *builder) getDirectives(list ast.DirectiveList) ([]*Directive, error) {
dirs := make([]*Directive, len(list))
for i, d := range list {
argValues := make(map[string]interface{}, len(d.Arguments))
for _, da := range d.Arguments {
val, err := da.Value.Value(nil)
if err != nil {
return nil, err
}
argValues[da.Name] = val
}
def, ok := b.Directives[d.Name]
if !ok {
return nil, fmt.Errorf("directive %s not found", d.Name)
}
var args []*FieldArgument
for _, a := range def.Args {
value := a.Default
if argValue, ok := argValues[a.Name]; ok {
value = argValue
}
args = append(args, &FieldArgument{
ArgumentDefinition: a.ArgumentDefinition,
Value: value,
VarName: a.VarName,
TypeReference: a.TypeReference,
})
}
dirs[i] = &Directive{
Name: d.Name,
Args: args,
DirectiveDefinition: list[i].Definition,
Builtin: b.Config.Directives[d.Name].SkipRuntime,
}
}
return dirs, nil
}
func (d *Directive) ArgsFunc() string {
if len(d.Args) == 0 {
return ""
}
return "dir_" + d.Name + "_args"
}
func (d *Directive) CallArgs() string {
args := []string{"ctx", "obj", "n"}
for _, arg := range d.Args {
args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
}
return strings.Join(args, ", ")
}
func (d *Directive) ResolveArgs(obj string, next int) string {
args := []string{"ctx", obj, fmt.Sprintf("directive%d", next)}
for _, arg := range d.Args {
dArg := arg.VarName
if arg.Value == nil && arg.Default == nil {
dArg = "nil"
}
args = append(args, dArg)
}
return strings.Join(args, ", ")
}
func (d *Directive) Declaration() string {
res := ucFirst(d.Name) + " func(ctx context.Context, obj interface{}, next graphql.Resolver"
for _, arg := range d.Args {
res += fmt.Sprintf(", %s %s", templates.ToGoPrivate(arg.Name), templates.CurrentImports.LookupType(arg.TypeReference.GO))
}
res += ") (res interface{}, err error)"
return res
}

View File

@@ -0,0 +1,149 @@
{{ define "implDirectives" }}{{ $in := .DirectiveObjName }}
{{- range $i, $directive := .ImplDirectives -}}
directive{{add $i 1}} := func(ctx context.Context) (interface{}, error) {
{{- range $arg := $directive.Args }}
{{- if notNil "Value" $arg }}
{{ $arg.VarName }}, err := ec.{{ $arg.TypeReference.UnmarshalFunc }}(ctx, {{ $arg.Value | dump }})
if err != nil{
return nil, err
}
{{- else if notNil "Default" $arg }}
{{ $arg.VarName }}, err := ec.{{ $arg.TypeReference.UnmarshalFunc }}(ctx, {{ $arg.Default | dump }})
if err != nil{
return nil, err
}
{{- end }}
{{- end }}
if ec.directives.{{$directive.Name|ucFirst}} == nil {
return nil, errors.New("directive {{$directive.Name}} is not implemented")
}
return ec.directives.{{$directive.Name|ucFirst}}({{$directive.ResolveArgs $in $i }})
}
{{ end -}}
{{ end }}
{{define "queryDirectives"}}
for _, d := range obj.Directives {
switch d.Name {
{{- range $directive := . }}
case "{{$directive.Name}}":
{{- if $directive.Args }}
rawArgs := d.ArgumentMap(ec.Variables)
args, err := ec.{{ $directive.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
{{- end }}
n := next
next = func(ctx context.Context) (interface{}, error) {
if ec.directives.{{$directive.Name|ucFirst}} == nil {
return nil, errors.New("directive {{$directive.Name}} is not implemented")
}
return ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})
}
{{- end }}
}
}
tmp, err := next(ctx)
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if data, ok := tmp.(graphql.Marshaler); ok {
return data
}
ec.Errorf(ctx, `unexpected type %T from directive, should be graphql.Marshaler`, tmp)
return graphql.Null
{{end}}
{{ if .Directives.LocationDirectives "QUERY" }}
func (ec *executionContext) _queryMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) graphql.Marshaler {
{{ template "queryDirectives" .Directives.LocationDirectives "QUERY" }}
}
{{ end }}
{{ if .Directives.LocationDirectives "MUTATION" }}
func (ec *executionContext) _mutationMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) graphql.Marshaler {
{{ template "queryDirectives" .Directives.LocationDirectives "MUTATION" }}
}
{{ end }}
{{ if .Directives.LocationDirectives "SUBSCRIPTION" }}
func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) func(ctx context.Context) graphql.Marshaler {
for _, d := range obj.Directives {
switch d.Name {
{{- range $directive := .Directives.LocationDirectives "SUBSCRIPTION" }}
case "{{$directive.Name}}":
{{- if $directive.Args }}
rawArgs := d.ArgumentMap(ec.Variables)
args, err := ec.{{ $directive.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return func(ctx context.Context) graphql.Marshaler {
return graphql.Null
}
}
{{- end }}
n := next
next = func(ctx context.Context) (interface{}, error) {
if ec.directives.{{$directive.Name|ucFirst}} == nil {
return nil, errors.New("directive {{$directive.Name}} is not implemented")
}
return ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})
}
{{- end }}
}
}
tmp, err := next(ctx)
if err != nil {
ec.Error(ctx, err)
return func(ctx context.Context) graphql.Marshaler {
return graphql.Null
}
}
if data, ok := tmp.(func(ctx context.Context) graphql.Marshaler); ok {
return data
}
ec.Errorf(ctx, `unexpected type %T from directive, should be graphql.Marshaler`, tmp)
return func(ctx context.Context) graphql.Marshaler {
return graphql.Null
}
}
{{ end }}
{{ if .Directives.LocationDirectives "FIELD" }}
func (ec *executionContext) _fieldMiddleware(ctx context.Context, obj interface{}, next graphql.Resolver) interface{} {
{{- if .Directives.LocationDirectives "FIELD" }}
fc := graphql.GetFieldContext(ctx)
for _, d := range fc.Field.Directives {
switch d.Name {
{{- range $directive := .Directives.LocationDirectives "FIELD" }}
case "{{$directive.Name}}":
{{- if $directive.Args }}
rawArgs := d.ArgumentMap(ec.Variables)
args, err := ec.{{ $directive.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return nil
}
{{- end }}
n := next
next = func(ctx context.Context) (interface{}, error) {
if ec.directives.{{$directive.Name|ucFirst}} == nil {
return nil, errors.New("directive {{$directive.Name}} is not implemented")
}
return ec.directives.{{$directive.Name|ucFirst}}({{$directive.CallArgs}})
}
{{- end }}
}
}
{{- end }}
res, err := ec.ResolverMiddleware(ctx, next)
if err != nil {
ec.Error(ctx, err)
return nil
}
return res
}
{{ end }}

View File

@@ -0,0 +1,611 @@
package codegen
import (
"errors"
"fmt"
goast "go/ast"
"go/types"
"log"
"reflect"
"strconv"
"strings"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/v2/ast"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
type Field struct {
*ast.FieldDefinition
TypeReference *config.TypeReference
GoFieldType GoFieldType // The field type in go, if any
GoReceiverName string // The name of method & var receiver in go, if any
GoFieldName string // The name of the method or var in go, if any
IsResolver bool // Does this field need a resolver
Args []*FieldArgument // A list of arguments to be passed to this field
MethodHasContext bool // If this is bound to a go method, does the method also take a context
NoErr bool // If this is bound to a go method, does that method have an error as the second argument
VOkFunc bool // If this is bound to a go method, is it of shape (interface{}, bool)
Object *Object // A link back to the parent object
Default interface{} // The default value
Stream bool // does this field return a channel?
Directives []*Directive
}
func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
dirs, err := b.getDirectives(field.Directives)
if err != nil {
return nil, err
}
f := Field{
FieldDefinition: field,
Object: obj,
Directives: dirs,
GoFieldName: templates.ToGo(field.Name),
GoFieldType: GoFieldVariable,
GoReceiverName: "obj",
}
if field.DefaultValue != nil {
var err error
f.Default, err = field.DefaultValue.Value(nil)
if err != nil {
return nil, fmt.Errorf("default value %s is not valid: %w", field.Name, err)
}
}
for _, arg := range field.Arguments {
newArg, err := b.buildArg(obj, arg)
if err != nil {
return nil, err
}
f.Args = append(f.Args, newArg)
}
if err = b.bindField(obj, &f); err != nil {
f.IsResolver = true
if errors.Is(err, config.ErrTypeNotFound) {
return nil, err
}
log.Println(err.Error())
}
if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
f.TypeReference = b.Binder.PointerTo(f.TypeReference)
}
return &f, nil
}
func (b *builder) bindField(obj *Object, f *Field) (errret error) {
defer func() {
if f.TypeReference == nil {
tr, err := b.Binder.TypeReference(f.Type, nil)
if err != nil {
errret = err
}
f.TypeReference = tr
}
if f.TypeReference != nil {
dirs, err := b.getDirectives(f.TypeReference.Definition.Directives)
if err != nil {
errret = err
}
for _, dir := range obj.Directives {
if dir.IsLocation(ast.LocationInputObject) {
dirs = append(dirs, dir)
}
}
f.Directives = append(dirs, f.Directives...)
}
}()
f.Stream = obj.Stream
switch {
case f.Name == "__schema":
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "ec"
f.GoFieldName = "introspectSchema"
return nil
case f.Name == "__type":
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "ec"
f.GoFieldName = "introspectType"
return nil
case f.Name == "_entities":
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "ec"
f.GoFieldName = "__resolve_entities"
f.MethodHasContext = true
f.NoErr = true
return nil
case f.Name == "_service":
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "ec"
f.GoFieldName = "__resolve__service"
f.MethodHasContext = true
return nil
case obj.Root:
f.IsResolver = true
return nil
case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
f.IsResolver = true
return nil
case obj.Type == config.MapType:
f.GoFieldType = GoFieldMap
return nil
case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
}
target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
if err != nil {
return err
}
pos := b.Binder.ObjectPosition(target)
switch target := target.(type) {
case nil:
objPos := b.Binder.TypePosition(obj.Type)
return fmt.Errorf(
"%s:%d adding resolver method for %s.%s, nothing matched",
objPos.Filename,
objPos.Line,
obj.Name,
f.Name,
)
case *types.Func:
sig := target.Type().(*types.Signature)
if sig.Results().Len() == 1 {
f.NoErr = true
} else if s := sig.Results(); s.Len() == 2 && s.At(1).Type().String() == "bool" {
f.VOkFunc = true
} else if sig.Results().Len() != 2 {
return fmt.Errorf("method has wrong number of args")
}
params := sig.Params()
// If the first argument is the context, remove it from the comparison and set
// the MethodHasContext flag so that the context will be passed to this model's method
if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
f.MethodHasContext = true
vars := make([]*types.Var, params.Len()-1)
for i := 1; i < params.Len(); i++ {
vars[i-1] = params.At(i)
}
params = types.NewTuple(vars...)
}
// Try to match target function's arguments with GraphQL field arguments.
newArgs, err := b.bindArgs(f, sig, params)
if err != nil {
return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err)
}
// Try to match target function's return types with GraphQL field return type
result := sig.Results().At(0)
tr, err := b.Binder.TypeReference(f.Type, result.Type())
if err != nil {
return err
}
// success, args and return type match. Bind to method
f.GoFieldType = GoFieldMethod
f.GoReceiverName = "obj"
f.GoFieldName = target.Name()
f.Args = newArgs
f.TypeReference = tr
return nil
case *types.Var:
tr, err := b.Binder.TypeReference(f.Type, target.Type())
if err != nil {
return err
}
// success, bind to var
f.GoFieldType = GoFieldVariable
f.GoReceiverName = "obj"
f.GoFieldName = target.Name()
f.TypeReference = tr
return nil
default:
panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
}
}
// findBindTarget attempts to match the name to a field or method on a Type
// with the following priorites:
// 1. Any Fields with a struct tag (see config.StructTag). Errors if more than one match is found
// 2. Any method or field with a matching name. Errors if more than one match is found
// 3. Same logic again for embedded fields
func (b *builder) findBindTarget(t types.Type, name string) (types.Object, error) {
// NOTE: a struct tag will override both methods and fields
// Bind to struct tag
found, err := b.findBindStructTagTarget(t, name)
if found != nil || err != nil {
return found, err
}
// Search for a method to bind to
foundMethod, err := b.findBindMethodTarget(t, name)
if err != nil {
return nil, err
}
// Search for a field to bind to
foundField, err := b.findBindFieldTarget(t, name)
if err != nil {
return nil, err
}
switch {
case foundField == nil && foundMethod != nil:
// Bind to method
return foundMethod, nil
case foundField != nil && foundMethod == nil:
// Bind to field
return foundField, nil
case foundField != nil && foundMethod != nil:
// Error
return nil, fmt.Errorf("found more than one way to bind for %s", name)
}
// Search embeds
return b.findBindEmbedsTarget(t, name)
}
func (b *builder) findBindStructTagTarget(in types.Type, name string) (types.Object, error) {
if b.Config.StructTag == "" {
return nil, nil
}
switch t := in.(type) {
case *types.Named:
return b.findBindStructTagTarget(t.Underlying(), name)
case *types.Struct:
var found types.Object
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
if !field.Exported() || field.Embedded() {
continue
}
tags := reflect.StructTag(t.Tag(i))
if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
if found != nil {
return nil, fmt.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val)
}
found = field
}
}
return found, nil
}
return nil, nil
}
func (b *builder) findBindMethodTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
if _, ok := t.Underlying().(*types.Interface); ok {
return b.findBindMethodTarget(t.Underlying(), name)
}
return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
case *types.Interface:
// FIX-ME: Should use ExplicitMethod here? What's the difference?
return b.findBindMethoderTarget(t.Method, t.NumMethods(), name)
}
return nil, nil
}
func (b *builder) findBindMethoderTarget(methodFunc func(i int) *types.Func, methodCount int, name string) (types.Object, error) {
var found types.Object
for i := 0; i < methodCount; i++ {
method := methodFunc(i)
if !method.Exported() || !strings.EqualFold(method.Name(), name) {
continue
}
if found != nil {
return nil, fmt.Errorf("found more than one matching method to bind for %s", name)
}
found = method
}
return found, nil
}
func (b *builder) findBindFieldTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
return b.findBindFieldTarget(t.Underlying(), name)
case *types.Struct:
var found types.Object
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
if !field.Exported() || !equalFieldName(field.Name(), name) {
continue
}
if found != nil {
return nil, fmt.Errorf("found more than one matching field to bind for %s", name)
}
found = field
}
return found, nil
}
return nil, nil
}
func (b *builder) findBindEmbedsTarget(in types.Type, name string) (types.Object, error) {
switch t := in.(type) {
case *types.Named:
return b.findBindEmbedsTarget(t.Underlying(), name)
case *types.Struct:
return b.findBindStructEmbedsTarget(t, name)
case *types.Interface:
return b.findBindInterfaceEmbedsTarget(t, name)
}
return nil, nil
}
func (b *builder) findBindStructEmbedsTarget(strukt *types.Struct, name string) (types.Object, error) {
var found types.Object
for i := 0; i < strukt.NumFields(); i++ {
field := strukt.Field(i)
if !field.Embedded() {
continue
}
fieldType := field.Type()
if ptr, ok := fieldType.(*types.Pointer); ok {
fieldType = ptr.Elem()
}
f, err := b.findBindTarget(fieldType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, fmt.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
found = f
}
}
return found, nil
}
func (b *builder) findBindInterfaceEmbedsTarget(iface *types.Interface, name string) (types.Object, error) {
var found types.Object
for i := 0; i < iface.NumEmbeddeds(); i++ {
embeddedType := iface.EmbeddedType(i)
f, err := b.findBindTarget(embeddedType, name)
if err != nil {
return nil, err
}
if f != nil && found != nil {
return nil, fmt.Errorf("found more than one way to bind for %s", name)
}
if f != nil {
found = f
}
}
return found, nil
}
func (f *Field) HasDirectives() bool {
return len(f.ImplDirectives()) > 0
}
func (f *Field) DirectiveObjName() string {
if f.Object.Root {
return "nil"
}
return f.GoReceiverName
}
func (f *Field) ImplDirectives() []*Directive {
var d []*Directive
loc := ast.LocationFieldDefinition
if f.Object.IsInputType() {
loc = ast.LocationInputFieldDefinition
}
for i := range f.Directives {
if !f.Directives[i].Builtin &&
(f.Directives[i].IsLocation(loc, ast.LocationObject) || f.Directives[i].IsLocation(loc, ast.LocationInputObject)) {
d = append(d, f.Directives[i])
}
}
return d
}
func (f *Field) IsReserved() bool {
return strings.HasPrefix(f.Name, "__")
}
func (f *Field) IsMethod() bool {
return f.GoFieldType == GoFieldMethod
}
func (f *Field) IsVariable() bool {
return f.GoFieldType == GoFieldVariable
}
func (f *Field) IsMap() bool {
return f.GoFieldType == GoFieldMap
}
func (f *Field) IsConcurrent() bool {
if f.Object.DisableConcurrency {
return false
}
return f.MethodHasContext || f.IsResolver
}
func (f *Field) GoNameUnexported() string {
return templates.ToGoPrivate(f.Name)
}
func (f *Field) ShortInvocation() string {
caser := cases.Title(language.English, cases.NoLower)
if f.Object.Kind == ast.InputObject {
return fmt.Sprintf("%s().%s(ctx, &it, data)", caser.String(f.Object.Definition.Name), f.GoFieldName)
}
return fmt.Sprintf("%s().%s(%s)", caser.String(f.Object.Definition.Name), f.GoFieldName, f.CallArgs())
}
func (f *Field) ArgsFunc() string {
if len(f.Args) == 0 {
return ""
}
return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
}
func (f *Field) FieldContextFunc() string {
return "fieldContext_" + f.Object.Definition.Name + "_" + f.Name
}
func (f *Field) ChildFieldContextFunc(name string) string {
return "fieldContext_" + f.TypeReference.Definition.Name + "_" + name
}
func (f *Field) ResolverType() string {
if !f.IsResolver {
return ""
}
return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs())
}
func (f *Field) IsInputObject() bool {
return f.Object.Kind == ast.InputObject
}
func (f *Field) IsRoot() bool {
return f.Object.Root
}
func (f *Field) ShortResolverDeclaration() string {
return f.ShortResolverSignature(nil)
}
// ShortResolverSignature is identical to ShortResolverDeclaration,
// but respects previous naming (return) conventions, if any.
func (f *Field) ShortResolverSignature(ft *goast.FuncType) string {
if f.Object.Kind == ast.InputObject {
return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error",
templates.CurrentImports.LookupType(f.Object.Reference()),
templates.CurrentImports.LookupType(f.TypeReference.GO),
)
}
res := "(ctx context.Context"
if !f.Object.Root {
res += fmt.Sprintf(", obj %s", templates.CurrentImports.LookupType(f.Object.Reference()))
}
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
}
result := templates.CurrentImports.LookupType(f.TypeReference.GO)
if f.Object.Stream {
result = "<-chan " + result
}
// Named return.
var namedV, namedE string
if ft != nil {
if ft.Results != nil && len(ft.Results.List) > 0 && len(ft.Results.List[0].Names) > 0 {
namedV = ft.Results.List[0].Names[0].Name
}
if ft.Results != nil && len(ft.Results.List) > 1 && len(ft.Results.List[1].Names) > 0 {
namedE = ft.Results.List[1].Names[0].Name
}
}
res += fmt.Sprintf(") (%s %s, %s error)", namedV, result, namedE)
return res
}
func (f *Field) GoResultName() (string, bool) {
name := fmt.Sprintf("%v", f.TypeReference.GO)
splits := strings.Split(name, "/")
return splits[len(splits)-1], strings.HasPrefix(name, "[]")
}
func (f *Field) ComplexitySignature() string {
res := "func(childComplexity int"
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
}
res += ") int"
return res
}
func (f *Field) ComplexityArgs() string {
args := make([]string, len(f.Args))
for i, arg := range f.Args {
args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
}
return strings.Join(args, ", ")
}
func (f *Field) CallArgs() string {
args := make([]string, 0, len(f.Args)+2)
if f.IsResolver {
args = append(args, "rctx")
if !f.Object.Root {
args = append(args, "obj")
}
} else if f.MethodHasContext {
args = append(args, "ctx")
}
for _, arg := range f.Args {
tmp := "fc.Args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
if iface, ok := arg.TypeReference.GO.(*types.Interface); ok && iface.Empty() {
tmp = fmt.Sprintf(`
func () interface{} {
if fc.Args["%s"] == nil {
return nil
}
return fc.Args["%s"].(interface{})
}()`, arg.Name, arg.Name,
)
}
args = append(args, tmp)
}
return strings.Join(args, ", ")
}

View File

@@ -0,0 +1,158 @@
{{- range $object := .Objects }}{{- range $field := $object.Fields }}
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret {{ if $object.Stream }}func(ctx context.Context){{ end }}graphql.Marshaler) {
{{- $null := "graphql.Null" }}
{{- if $object.Stream }}
{{- $null = "nil" }}
{{- end }}
fc, err := ec.{{ $field.FieldContextFunc }}(ctx, field)
if err != nil {
return {{ $null }}
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func () {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = {{ $null }}
}
}()
{{- if $.AllDirectives.LocationDirectives "FIELD" }}
resTmp := ec._fieldMiddleware(ctx, {{if $object.Root}}nil{{else}}obj{{end}}, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
{{ else }}
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
})
if err != nil {
ec.Error(ctx, err)
return {{ $null }}
}
{{- end }}
if resTmp == nil {
{{- if $field.TypeReference.GQL.NonNull }}
if !graphql.HasFieldError(ctx, fc) {
ec.Errorf(ctx, "must not be null")
}
{{- end }}
return {{ $null }}
}
{{- if $object.Stream }}
return func(ctx context.Context) graphql.Marshaler {
select {
case res, ok := <-resTmp.(<-chan {{$field.TypeReference.GO | ref}}):
if !ok {
return nil
}
return graphql.WriterFunc(func(w io.Writer) {
w.Write([]byte{'{'})
graphql.MarshalString(field.Alias).MarshalGQL(w)
w.Write([]byte{':'})
ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res).MarshalGQL(w)
w.Write([]byte{'}'})
})
case <-ctx.Done():
return nil
}
}
{{- else }}
res := resTmp.({{$field.TypeReference.GO | ref}})
fc.Result = res
return ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res)
{{- end }}
}
func (ec *executionContext) {{ $field.FieldContextFunc }}(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: {{quote $field.Object.Name}},
Field: field,
IsMethod: {{or $field.IsMethod $field.IsResolver}},
IsResolver: {{ $field.IsResolver }},
Child: func (ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
{{- if not $field.TypeReference.Definition.Fields }}
return nil, errors.New("field of type {{ $field.TypeReference.Definition.Name }} does not have child fields")
{{- else if ne $field.TypeReference.Definition.Kind "OBJECT" }}
return nil, errors.New("FieldContext.Child cannot be called on type {{ $field.TypeReference.Definition.Kind }}")
{{- else }}
switch field.Name {
{{- range $f := $field.TypeReference.Definition.Fields }}
case "{{ $f.Name }}":
return ec.{{ $field.ChildFieldContextFunc $f.Name }}(ctx, field)
{{- end }}
}
return nil, fmt.Errorf("no field named %q was found under type {{ $field.TypeReference.Definition.Name }}", field.Name)
{{- end }}
},
}
{{- if $field.Args }}
defer func () {
if r := recover(); r != nil {
err = ec.Recover(ctx, r)
ec.Error(ctx, err)
}
}()
ctx = graphql.WithFieldContext(ctx, fc)
if fc.Args, err = ec.{{ $field.ArgsFunc }}(ctx, field.ArgumentMap(ec.Variables)); err != nil {
ec.Error(ctx, err)
return fc, err
}
{{- end }}
return fc, nil
}
{{- end }}{{- end}}
{{ define "field" }}
{{- if .HasDirectives -}}
directive0 := func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
{{ template "fieldDefinition" . }}
}
{{ template "implDirectives" . }}
tmp, err := directive{{.ImplDirectives|len}}(rctx)
if err != nil {
return nil, graphql.ErrorOnPath(ctx, err)
}
if tmp == nil {
return nil, nil
}
if data, ok := tmp.({{if .Stream}}<-chan {{end}}{{ .TypeReference.GO | ref }}) ; ok {
return data, nil
}
return nil, fmt.Errorf(`unexpected type %T from directive, should be {{if .Stream}}<-chan {{end}}{{ .TypeReference.GO }}`, tmp)
{{- else -}}
ctx = rctx // use context from middleware stack in children
{{ template "fieldDefinition" . }}
{{- end -}}
{{ end }}
{{ define "fieldDefinition" }}
{{- if .IsResolver -}}
return ec.resolvers.{{ .ShortInvocation }}
{{- else if .IsMap -}}
switch v := {{.GoReceiverName}}[{{.Name|quote}}].(type) {
case {{if .Stream}}<-chan {{end}}{{.TypeReference.GO | ref}}:
return v, nil
case {{if .Stream}}<-chan {{end}}{{.TypeReference.Elem.GO | ref}}:
return &v, nil
case nil:
return ({{.TypeReference.GO | ref}})(nil), nil
default:
return nil, fmt.Errorf("unexpected type %T for field %s", v, {{ .Name | quote}})
}
{{- else if .IsMethod -}}
{{- if .VOkFunc -}}
v, ok := {{.GoReceiverName}}.{{.GoFieldName}}({{ .CallArgs }})
if !ok {
return nil, nil
}
return v, nil
{{- else if .NoErr -}}
return {{.GoReceiverName}}.{{.GoFieldName}}({{ .CallArgs }}), nil
{{- else -}}
return {{.GoReceiverName}}.{{.GoFieldName}}({{ .CallArgs }})
{{- end -}}
{{- else if .IsVariable -}}
return {{.GoReceiverName}}.{{.GoFieldName}}, nil
{{- end }}
{{- end }}

View File

@@ -0,0 +1,220 @@
package codegen
import (
"embed"
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/v2/ast"
)
//go:embed *.gotpl
var codegenTemplates embed.FS
func GenerateCode(data *Data) error {
if !data.Config.Exec.IsDefined() {
return fmt.Errorf("missing exec config")
}
switch data.Config.Exec.Layout {
case config.ExecLayoutSingleFile:
return generateSingleFile(data)
case config.ExecLayoutFollowSchema:
return generatePerSchema(data)
}
return fmt.Errorf("unrecognized exec layout %s", data.Config.Exec.Layout)
}
func generateSingleFile(data *Data) error {
return templates.Render(templates.Options{
PackageName: data.Config.Exec.Package,
Filename: data.Config.Exec.Filename,
Data: data,
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
}
func generatePerSchema(data *Data) error {
err := generateRootFile(data)
if err != nil {
return err
}
builds := map[string]*Data{}
err = addObjects(data, &builds)
if err != nil {
return err
}
err = addInputs(data, &builds)
if err != nil {
return err
}
err = addInterfaces(data, &builds)
if err != nil {
return err
}
err = addReferencedTypes(data, &builds)
if err != nil {
return err
}
for filename, build := range builds {
if filename == "" {
continue
}
dir := data.Config.Exec.DirName
path := filepath.Join(dir, filename)
err = templates.Render(templates.Options{
PackageName: data.Config.Exec.Package,
Filename: path,
Data: build,
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
if err != nil {
return err
}
}
return nil
}
func filename(p *ast.Position, config *config.Config) string {
name := "common!"
if p != nil && p.Src != nil {
gqlname := filepath.Base(p.Src.Name)
ext := filepath.Ext(p.Src.Name)
name = strings.TrimSuffix(gqlname, ext)
}
filenameTempl := config.Exec.FilenameTemplate
if filenameTempl == "" {
filenameTempl = "{name}.generated.go"
}
return strings.ReplaceAll(filenameTempl, "{name}", name)
}
func addBuild(filename string, p *ast.Position, data *Data, builds *map[string]*Data) {
buildConfig := *data.Config
if p != nil {
buildConfig.Sources = []*ast.Source{p.Src}
}
(*builds)[filename] = &Data{
Config: &buildConfig,
QueryRoot: data.QueryRoot,
MutationRoot: data.MutationRoot,
SubscriptionRoot: data.SubscriptionRoot,
AllDirectives: data.AllDirectives,
}
}
// Root file contains top-level definitions that should not be duplicated across the generated
// files for each schema file.
func generateRootFile(data *Data) error {
dir := data.Config.Exec.DirName
path := filepath.Join(dir, "root_.generated.go")
_, thisFile, _, _ := runtime.Caller(0)
rootDir := filepath.Dir(thisFile)
templatePath := filepath.Join(rootDir, "root_.gotpl")
templateBytes, err := os.ReadFile(templatePath)
if err != nil {
return err
}
template := string(templateBytes)
return templates.Render(templates.Options{
PackageName: data.Config.Exec.Package,
Template: template,
Filename: path,
Data: data,
RegionTags: false,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
}
func addObjects(data *Data, builds *map[string]*Data) error {
for _, o := range data.Objects {
filename := filename(o.Position, data.Config)
if (*builds)[filename] == nil {
addBuild(filename, o.Position, data, builds)
}
(*builds)[filename].Objects = append((*builds)[filename].Objects, o)
}
return nil
}
func addInputs(data *Data, builds *map[string]*Data) error {
for _, in := range data.Inputs {
filename := filename(in.Position, data.Config)
if (*builds)[filename] == nil {
addBuild(filename, in.Position, data, builds)
}
(*builds)[filename].Inputs = append((*builds)[filename].Inputs, in)
}
return nil
}
func addInterfaces(data *Data, builds *map[string]*Data) error {
for k, inf := range data.Interfaces {
filename := filename(inf.Position, data.Config)
if (*builds)[filename] == nil {
addBuild(filename, inf.Position, data, builds)
}
build := (*builds)[filename]
if build.Interfaces == nil {
build.Interfaces = map[string]*Interface{}
}
if build.Interfaces[k] != nil {
return errors.New("conflicting interface keys")
}
build.Interfaces[k] = inf
}
return nil
}
func addReferencedTypes(data *Data, builds *map[string]*Data) error {
for k, rt := range data.ReferencedTypes {
filename := filename(rt.Definition.Position, data.Config)
if (*builds)[filename] == nil {
addBuild(filename, rt.Definition.Position, data, builds)
}
build := (*builds)[filename]
if build.ReferencedTypes == nil {
build.ReferencedTypes = map[string]*config.TypeReference{}
}
if build.ReferencedTypes[k] != nil {
return errors.New("conflicting referenced type keys")
}
build.ReferencedTypes[k] = rt
}
return nil
}

View File

@@ -0,0 +1,300 @@
{{ reserveImport "context" }}
{{ reserveImport "fmt" }}
{{ reserveImport "io" }}
{{ reserveImport "strconv" }}
{{ reserveImport "time" }}
{{ reserveImport "sync" }}
{{ reserveImport "sync/atomic" }}
{{ reserveImport "errors" }}
{{ reserveImport "bytes" }}
{{ reserveImport "embed" }}
{{ reserveImport "github.com/vektah/gqlparser/v2" "gqlparser" }}
{{ reserveImport "github.com/vektah/gqlparser/v2/ast" }}
{{ reserveImport "github.com/99designs/gqlgen/graphql" }}
{{ reserveImport "github.com/99designs/gqlgen/graphql/introspection" }}
{{ if eq .Config.Exec.Layout "single-file" }}
// NewExecutableSchema creates an ExecutableSchema from the ResolverRoot interface.
func NewExecutableSchema(cfg Config) graphql.ExecutableSchema {
return &executableSchema{
resolvers: cfg.Resolvers,
directives: cfg.Directives,
complexity: cfg.Complexity,
}
}
type Config struct {
Resolvers ResolverRoot
Directives DirectiveRoot
Complexity ComplexityRoot
}
type ResolverRoot interface {
{{- range $object := .Objects -}}
{{ if $object.HasResolvers -}}
{{ucFirst $object.Name}}() {{ucFirst $object.Name}}Resolver
{{ end }}
{{- end }}
{{- range $object := .Inputs -}}
{{ if $object.HasResolvers -}}
{{ucFirst $object.Name}}() {{ucFirst $object.Name}}Resolver
{{ end }}
{{- end }}
}
type DirectiveRoot struct {
{{ range $directive := .Directives }}
{{- $directive.Declaration }}
{{ end }}
}
type ComplexityRoot struct {
{{- if not .Config.OmitComplexity }}
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ ucFirst $object.Name }} struct {
{{ range $_, $fields := $object.UniqueFields }}
{{- $field := index $fields 0 -}}
{{ if not $field.IsReserved -}}
{{ $field.GoFieldName }} {{ $field.ComplexitySignature }}
{{ end }}
{{- end }}
}
{{- end }}
{{ end }}
{{- end }}
}
{{ end }}
{{ range $object := .Objects -}}
{{ if $object.HasResolvers }}
type {{ucFirst $object.Name}}Resolver interface {
{{ range $field := $object.Fields -}}
{{- if $field.IsResolver }}
{{- $field.GoFieldName}}{{ $field.ShortResolverDeclaration }}
{{- end }}
{{ end }}
}
{{- end }}
{{- end }}
{{ range $object := .Inputs -}}
{{ if $object.HasResolvers }}
type {{$object.Name}}Resolver interface {
{{ range $field := $object.Fields -}}
{{- if $field.IsResolver }}
{{- $field.GoFieldName}}{{ $field.ShortResolverDeclaration }}
{{- end }}
{{ end }}
}
{{- end }}
{{- end }}
{{ if eq .Config.Exec.Layout "single-file" }}
type executableSchema struct {
resolvers ResolverRoot
directives DirectiveRoot
complexity ComplexityRoot
}
func (e *executableSchema) Schema() *ast.Schema {
return parsedSchema
}
func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
ec := executionContext{nil, e, 0, 0, nil}
_ = ec
{{ if not .Config.OmitComplexity -}}
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
{{ range $_, $fields := $object.UniqueFields }}
{{- $len := len $fields }}
{{- range $i, $field := $fields }}
{{- $last := eq (add $i 1) $len }}
{{- if not $field.IsReserved }}
{{- if eq $i 0 }}case {{ end }}"{{$object.Name}}.{{$field.Name}}"{{ if not $last }},{{ else }}:
if e.complexity.{{ucFirst $object.Name}}.{{$field.GoFieldName}} == nil {
break
}
{{ if $field.Args }}
args, err := ec.{{ $field.ArgsFunc }}(context.TODO(),rawArgs)
if err != nil {
return 0, false
}
{{ end }}
return e.complexity.{{ucFirst $object.Name}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{ end }}), true
{{ end }}
{{- end }}
{{- end }}
{{ end }}
{{ end }}
{{ end }}
}
{{- end }}
return 0, false
}
func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
rc := graphql.GetOperationContext(ctx)
ec := executionContext{rc, e, 0, 0, make(chan graphql.DeferredResult)}
inputUnmarshalMap := graphql.BuildUnmarshalerMap(
{{- range $input := .Inputs -}}
{{ if not $input.HasUnmarshal }}
ec.unmarshalInput{{ $input.Name }},
{{- end }}
{{- end }}
)
first := true
switch rc.Operation.Operation {
{{- if .QueryRoot }} case ast.Query:
return func(ctx context.Context) *graphql.Response {
var response graphql.Response
var data graphql.Marshaler
if first {
first = false
ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap)
{{ if .Directives.LocationDirectives "QUERY" -}}
data = ec._queryMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
})
{{- else -}}
data = ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
} else {
if atomic.LoadInt32(&ec.pendingDeferred) > 0 {
result := <-ec.deferredResults
atomic.AddInt32(&ec.pendingDeferred, -1)
data = result.Result
response.Path = result.Path
response.Label = result.Label
response.Errors = result.Errors
} else {
return nil
}
}
var buf bytes.Buffer
data.MarshalGQL(&buf)
response.Data = buf.Bytes()
if atomic.LoadInt32(&ec.deferred) > 0 {
hasNext := atomic.LoadInt32(&ec.pendingDeferred) > 0
response.HasNext = &hasNext
}
return &response
}
{{ end }}
{{- if .MutationRoot }} case ast.Mutation:
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap)
{{ if .Directives.LocationDirectives "MUTATION" -}}
data := ec._mutationMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
})
{{- else -}}
data := ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
{{- if .SubscriptionRoot }} case ast.Subscription:
{{ if .Directives.LocationDirectives "SUBSCRIPTION" -}}
next := ec._subscriptionMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet),nil
})
{{- else -}}
next := ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
return func(ctx context.Context) *graphql.Response {
buf.Reset()
data := next(ctx)
if data == nil {
return nil
}
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
default:
return graphql.OneShot(graphql.ErrorResponse(ctx, "unsupported GraphQL operation"))
}
}
type executionContext struct {
*graphql.OperationContext
*executableSchema
deferred int32
pendingDeferred int32
deferredResults chan graphql.DeferredResult
}
func (ec *executionContext) processDeferredGroup(dg graphql.DeferredGroup) {
atomic.AddInt32(&ec.pendingDeferred, 1)
go func () {
ctx := graphql.WithFreshResponseContext(dg.Context)
dg.FieldSet.Dispatch(ctx)
ds := graphql.DeferredResult{
Path: dg.Path,
Label: dg.Label,
Result: dg.FieldSet,
Errors: graphql.GetErrors(ctx),
}
// null fields should bubble up
if dg.FieldSet.Invalids > 0 {
ds.Result = graphql.Null
}
ec.deferredResults <- ds
}()
}
func (ec *executionContext) introspectSchema() (*introspection.Schema, error) {
if ec.DisableIntrospection {
return nil, errors.New("introspection disabled")
}
return introspection.WrapSchema(parsedSchema), nil
}
func (ec *executionContext) introspectType(name string) (*introspection.Type, error) {
if ec.DisableIntrospection {
return nil, errors.New("introspection disabled")
}
return introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name]), nil
}
{{if .HasEmbeddableSources }}
//go:embed{{- range $source := .AugmentedSources }}{{if $source.Embeddable}} {{$source.RelativePath|quote}}{{end}}{{- end }}
var sourcesFS embed.FS
func sourceData(filename string) string {
data, err := sourcesFS.ReadFile(filename)
if err != nil {
panic(fmt.Sprintf("codegen problem: %s not available", filename))
}
return string(data)
}
{{- end }}
var sources = []*ast.Source{
{{- range $source := .AugmentedSources }}
{Name: {{$source.RelativePath|quote}}, Input: {{if (not $source.Embeddable)}}{{$source.Source|rawQuote}}{{else}}sourceData({{$source.RelativePath|quote}}){{end}}, BuiltIn: {{$source.BuiltIn}}},
{{- end }}
}
var parsedSchema = gqlparser.MustLoadSchema(sources...)
{{ end }}

View File

@@ -0,0 +1,94 @@
{{- range $input := .Inputs }}
{{- if not .HasUnmarshal }}
{{- $it := "it" }}
{{- if .PointersInUmarshalInput }}
{{- $it = "&it" }}
{{- end }}
func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{ if .PointersInUmarshalInput }}*{{ end }}{{.Type | ref}}, error) {
var it {{.Type | ref}}
asMap := map[string]interface{}{}
for k, v := range obj.(map[string]interface{}) {
asMap[k] = v
}
{{ range $field := .Fields}}
{{- if notNil "Default" $field }}
if _, present := asMap[{{$field.Name|quote}}] ; !present {
asMap[{{$field.Name|quote}}] = {{ $field.Default | dump }}
}
{{- end}}
{{- end }}
fieldsInOrder := [...]string{ {{ range .Fields }}{{ quote .Name }},{{ end }} }
for _, k := range fieldsInOrder {
v, ok := asMap[k]
if !ok {
continue
}
switch k {
{{- range $field := .Fields }}
case {{$field.Name|quote}}:
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField({{$field.Name|quote}}))
{{- if $field.ImplDirectives }}
directive0 := func(ctx context.Context) (interface{}, error) { return ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v) }
{{ template "implDirectives" $field }}
tmp, err := directive{{$field.ImplDirectives|len}}(ctx)
if err != nil {
return {{$it}}, graphql.ErrorOnPath(ctx, err)
}
if data, ok := tmp.({{ $field.TypeReference.GO | ref }}) ; ok {
{{- if $field.IsResolver }}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return {{$it}}, err
}
{{- else }}
{{- if $field.TypeReference.IsOmittable }}
it.{{$field.GoFieldName}} = graphql.OmittableOf(data)
{{- else }}
it.{{$field.GoFieldName}} = data
{{- end }}
{{- end }}
{{- if $field.TypeReference.IsNilable }}
{{- if not $field.IsResolver }}
} else if tmp == nil {
{{- if $field.TypeReference.IsOmittable }}
it.{{$field.GoFieldName}} = graphql.OmittableOf[{{ $field.TypeReference.GO | ref }}](nil)
{{- else }}
it.{{$field.GoFieldName}} = nil
{{- end }}
{{- end }}
{{- end }}
} else {
err := fmt.Errorf(`unexpected type %T from directive, should be {{ $field.TypeReference.GO }}`, tmp)
return {{$it}}, graphql.ErrorOnPath(ctx, err)
}
{{- else }}
{{- if $field.IsResolver }}
data, err := ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return {{$it}}, err
}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return {{$it}}, err
}
{{- else }}
data, err := ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return {{$it}}, err
}
{{- if $field.TypeReference.IsOmittable }}
it.{{$field.GoFieldName}} = graphql.OmittableOf(data)
{{- else }}
it.{{$field.GoFieldName}} = data
{{- end }}
{{- end }}
{{- end }}
{{- end }}
}
}
return {{$it}}, nil
}
{{- end }}
{{ end }}

View File

@@ -0,0 +1,87 @@
package codegen
import (
"fmt"
"go/types"
"github.com/vektah/gqlparser/v2/ast"
"github.com/99designs/gqlgen/codegen/config"
)
type Interface struct {
*ast.Definition
Type types.Type
Implementors []InterfaceImplementor
InTypemap bool
}
type InterfaceImplementor struct {
*ast.Definition
Type types.Type
TakeRef bool
}
func (b *builder) buildInterface(typ *ast.Definition) (*Interface, error) {
obj, err := b.Binder.DefaultUserObject(typ.Name)
if err != nil {
panic(err)
}
i := &Interface{
Definition: typ,
Type: obj,
InTypemap: b.Config.Models.UserDefined(typ.Name),
}
interfaceType, err := findGoInterface(i.Type)
if interfaceType == nil || err != nil {
return nil, fmt.Errorf("%s is not an interface", i.Type)
}
for _, implementor := range b.Schema.GetPossibleTypes(typ) {
obj, err := b.Binder.DefaultUserObject(implementor.Name)
if err != nil {
return nil, fmt.Errorf("%s has no backing go type", implementor.Name)
}
implementorType, err := findGoNamedType(obj)
if err != nil {
return nil, fmt.Errorf("can not find backing go type %s: %w", obj.String(), err)
} else if implementorType == nil {
return nil, fmt.Errorf("can not find backing go type %s", obj.String())
}
anyValid := false
// first check if the value receiver can be nil, eg can we type switch on case Thing:
if types.Implements(implementorType, interfaceType) {
i.Implementors = append(i.Implementors, InterfaceImplementor{
Definition: implementor,
Type: obj,
TakeRef: !types.IsInterface(obj),
})
anyValid = true
}
// then check if the pointer receiver can be nil, eg can we type switch on case *Thing:
if types.Implements(types.NewPointer(implementorType), interfaceType) {
i.Implementors = append(i.Implementors, InterfaceImplementor{
Definition: implementor,
Type: types.NewPointer(obj),
})
anyValid = true
}
if !anyValid {
return nil, fmt.Errorf("%s does not satisfy the interface %s", implementorType.String(), i.Type.String())
}
}
return i, nil
}
func (i *InterfaceImplementor) CanBeNil() bool {
return config.IsNilable(i.Type)
}

View File

@@ -0,0 +1,21 @@
{{- range $interface := .Interfaces }}
func (ec *executionContext) _{{$interface.Name}}(ctx context.Context, sel ast.SelectionSet, obj {{$interface.Type | ref}}) graphql.Marshaler {
switch obj := (obj).(type) {
case nil:
return graphql.Null
{{- range $implementor := $interface.Implementors }}
case {{$implementor.Type | ref}}:
{{- if $implementor.CanBeNil }}
if obj == nil {
return graphql.Null
}
{{- end }}
return ec._{{$implementor.Name}}(ctx, sel, {{ if $implementor.TakeRef }}&{{ end }}obj)
{{- end }}
default:
panic(fmt.Errorf("unexpected type %T", obj))
}
}
{{- end }}

View File

@@ -0,0 +1,183 @@
package codegen
import (
"fmt"
"go/types"
"strconv"
"strings"
"unicode"
"github.com/99designs/gqlgen/codegen/config"
"github.com/vektah/gqlparser/v2/ast"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
type GoFieldType int
const (
GoFieldUndefined GoFieldType = iota
GoFieldMethod
GoFieldVariable
GoFieldMap
)
type Object struct {
*ast.Definition
Type types.Type
ResolverInterface types.Type
Root bool
Fields []*Field
Implements []*ast.Definition
DisableConcurrency bool
Stream bool
Directives []*Directive
PointersInUmarshalInput bool
}
func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
dirs, err := b.getDirectives(typ.Directives)
if err != nil {
return nil, fmt.Errorf("%s: %w", typ.Name, err)
}
caser := cases.Title(language.English, cases.NoLower)
obj := &Object{
Definition: typ,
Root: b.Schema.Query == typ || b.Schema.Mutation == typ || b.Schema.Subscription == typ,
DisableConcurrency: typ == b.Schema.Mutation,
Stream: typ == b.Schema.Subscription,
Directives: dirs,
PointersInUmarshalInput: b.Config.ReturnPointersInUmarshalInput,
ResolverInterface: types.NewNamed(
types.NewTypeName(0, b.Config.Exec.Pkg(), caser.String(typ.Name)+"Resolver", nil),
nil,
nil,
),
}
if !obj.Root {
goObject, err := b.Binder.DefaultUserObject(typ.Name)
if err != nil {
return nil, err
}
obj.Type = goObject
}
for _, intf := range b.Schema.GetImplements(typ) {
obj.Implements = append(obj.Implements, b.Schema.Types[intf.Name])
}
for _, field := range typ.Fields {
if strings.HasPrefix(field.Name, "__") {
continue
}
var f *Field
f, err = b.buildField(obj, field)
if err != nil {
return nil, err
}
obj.Fields = append(obj.Fields, f)
}
return obj, nil
}
func (o *Object) Reference() types.Type {
if config.IsNilable(o.Type) {
return o.Type
}
return types.NewPointer(o.Type)
}
type Objects []*Object
func (o *Object) Implementors() string {
satisfiedBy := strconv.Quote(o.Name)
for _, s := range o.Implements {
satisfiedBy += ", " + strconv.Quote(s.Name)
}
return "[]string{" + satisfiedBy + "}"
}
func (o *Object) HasResolvers() bool {
for _, f := range o.Fields {
if f.IsResolver {
return true
}
}
return false
}
func (o *Object) HasUnmarshal() bool {
if o.Type == config.MapType {
return true
}
for i := 0; i < o.Type.(*types.Named).NumMethods(); i++ {
if o.Type.(*types.Named).Method(i).Name() == "UnmarshalGQL" {
return true
}
}
return false
}
func (o *Object) HasDirectives() bool {
if len(o.Directives) > 0 {
return true
}
for _, f := range o.Fields {
if f.HasDirectives() {
return true
}
}
return false
}
func (o *Object) IsConcurrent() bool {
for _, f := range o.Fields {
if f.IsConcurrent() {
return true
}
}
return false
}
func (o *Object) IsReserved() bool {
return strings.HasPrefix(o.Definition.Name, "__")
}
func (o *Object) Description() string {
return o.Definition.Description
}
func (o *Object) HasField(name string) bool {
for _, f := range o.Fields {
if f.Name == name {
return true
}
}
return false
}
func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.Definition.Name, name) {
return os[i]
}
}
return nil
}
func ucFirst(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = unicode.ToUpper(r[0])
return string(r)
}

View File

@@ -0,0 +1,147 @@
{{- range $object := .Objects }}
var {{ $object.Name|lcFirst}}Implementors = {{$object.Implementors}}
{{- if .Stream }}
func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.SelectionSet) func(ctx context.Context) graphql.Marshaler {
fields := graphql.CollectFields(ec.OperationContext, sel, {{$object.Name|lcFirst}}Implementors)
ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{
Object: {{$object.Name|quote}},
})
if len(fields) != 1 {
ec.Errorf(ctx, "must subscribe to exactly one stream")
return nil
}
switch fields[0].Name {
{{- range $field := $object.Fields }}
case "{{$field.Name}}":
return ec._{{$object.Name}}_{{$field.Name}}(ctx, fields[0])
{{- end }}
default:
panic("unknown field " + strconv.Quote(fields[0].Name))
}
}
{{- else }}
func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.SelectionSet{{ if not $object.Root }},obj {{$object.Reference | ref }}{{ end }}) graphql.Marshaler {
fields := graphql.CollectFields(ec.OperationContext, sel, {{$object.Name|lcFirst}}Implementors)
{{- if $object.Root }}
ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{
Object: {{$object.Name|quote}},
})
{{end}}
out := graphql.NewFieldSet(fields)
deferred := make(map[string]*graphql.FieldSet)
for i, field := range fields {
{{- if $object.Root }}
innerCtx := graphql.WithRootFieldContext(ctx, &graphql.RootFieldContext{
Object: field.Name,
Field: field,
})
{{end}}
switch field.Name {
case "__typename":
out.Values[i] = graphql.MarshalString({{$object.Name|quote}})
{{- range $field := $object.Fields }}
case "{{$field.Name}}":
{{- if $field.IsConcurrent }}
field := field
innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
}
}()
res = ec._{{$object.Name}}_{{$field.Name}}(ctx, field{{if not $object.Root}}, obj{{end}})
{{- if $field.TypeReference.GQL.NonNull }}
if res == graphql.Null {
{{- if $object.IsConcurrent }}
atomic.AddUint32(&fs.Invalids, 1)
{{- else }}
fs.Invalids++
{{- end }}
}
{{- end }}
return res
}
{{if $object.Root}}
rrm := func(ctx context.Context) graphql.Marshaler {
return ec.OperationContext.RootResolverMiddleware(ctx,
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
}
{{end}}
{{if not $object.Root}}
if field.Deferrable != nil {
dfs, ok := deferred[field.Deferrable.Label]
di := 0
if ok {
dfs.AddField(field)
di = len(dfs.Values) - 1
} else {
dfs = graphql.NewFieldSet([]graphql.CollectedField{field})
deferred[field.Deferrable.Label] = dfs
}
dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler {
return innerFunc(ctx, dfs)
})
// don't run the out.Concurrently() call below
out.Values[i] = graphql.Null
continue
}
{{end}}
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler {
{{- if $object.Root -}}
return rrm(innerCtx)
{{- else -}}
return innerFunc(ctx, out)
{{- end -}}
})
{{- else }}
{{- if $object.Root -}}
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) {
return ec._{{$object.Name}}_{{$field.Name}}(ctx, field)
})
{{- else -}}
out.Values[i] = ec._{{$object.Name}}_{{$field.Name}}(ctx, field, obj)
{{- end -}}
{{- if $field.TypeReference.GQL.NonNull }}
if out.Values[i] == graphql.Null {
{{- if $object.IsConcurrent }}
atomic.AddUint32(&out.Invalids, 1)
{{- else }}
out.Invalids++
{{- end }}
}
{{- end }}
{{- end }}
{{- end }}
default:
panic("unknown field " + strconv.Quote(field.Name))
}
}
out.Dispatch(ctx)
if out.Invalids > 0 { return graphql.Null }
atomic.AddInt32(&ec.deferred, int32(len(deferred)))
for label, dfs := range deferred {
ec.processDeferredGroup(graphql.DeferredGroup{
Label: label,
Path: graphql.GetPath(ctx),
FieldSet: dfs,
Context: ctx,
})
}
return out
}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,273 @@
{{ reserveImport "context" }}
{{ reserveImport "fmt" }}
{{ reserveImport "io" }}
{{ reserveImport "strconv" }}
{{ reserveImport "time" }}
{{ reserveImport "sync" }}
{{ reserveImport "sync/atomic" }}
{{ reserveImport "errors" }}
{{ reserveImport "bytes" }}
{{ reserveImport "embed" }}
{{ reserveImport "github.com/vektah/gqlparser/v2" "gqlparser" }}
{{ reserveImport "github.com/vektah/gqlparser/v2/ast" }}
{{ reserveImport "github.com/99designs/gqlgen/graphql" }}
{{ reserveImport "github.com/99designs/gqlgen/graphql/introspection" }}
// NewExecutableSchema creates an ExecutableSchema from the ResolverRoot interface.
func NewExecutableSchema(cfg Config) graphql.ExecutableSchema {
return &executableSchema{
resolvers: cfg.Resolvers,
directives: cfg.Directives,
complexity: cfg.Complexity,
}
}
type Config struct {
Resolvers ResolverRoot
Directives DirectiveRoot
Complexity ComplexityRoot
}
type ResolverRoot interface {
{{- range $object := .Objects -}}
{{ if $object.HasResolvers -}}
{{ucFirst $object.Name}}() {{ucFirst $object.Name}}Resolver
{{ end }}
{{- end }}
{{- range $object := .Inputs -}}
{{ if $object.HasResolvers -}}
{{ucFirst $object.Name}}() {{ucFirst $object.Name}}Resolver
{{ end }}
{{- end }}
}
type DirectiveRoot struct {
{{ range $directive := .Directives }}
{{- $directive.Declaration }}
{{ end }}
}
type ComplexityRoot struct {
{{- if not .Config.OmitComplexity }}
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ ucFirst $object.Name }} struct {
{{ range $_, $fields := $object.UniqueFields }}
{{- $field := index $fields 0 -}}
{{ if not $field.IsReserved -}}
{{ $field.GoFieldName }} {{ $field.ComplexitySignature }}
{{ end }}
{{- end }}
}
{{- end }}
{{ end }}
{{- end }}
}
type executableSchema struct {
resolvers ResolverRoot
directives DirectiveRoot
complexity ComplexityRoot
}
func (e *executableSchema) Schema() *ast.Schema {
return parsedSchema
}
func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
ec := executionContext{nil, e, 0, 0, nil}
_ = ec
{{- if not .Config.OmitComplexity }}
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
{{ range $_, $fields := $object.UniqueFields }}
{{- $len := len $fields }}
{{- range $i, $field := $fields }}
{{- $last := eq (add $i 1) $len }}
{{- if not $field.IsReserved }}
{{- if eq $i 0 }}case {{ end }}"{{$object.Name}}.{{$field.Name}}"{{ if not $last }},{{ else }}:
if e.complexity.{{ucFirst $object.Name }}.{{$field.GoFieldName}} == nil {
break
}
{{ if $field.Args }}
args, err := ec.{{ $field.ArgsFunc }}(context.TODO(),rawArgs)
if err != nil {
return 0, false
}
{{ end }}
return e.complexity.{{ucFirst $object.Name}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{ end }}), true
{{ end }}
{{- end }}
{{- end }}
{{ end }}
{{ end }}
{{ end }}
}
{{- end }}
return 0, false
}
func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
rc := graphql.GetOperationContext(ctx)
ec := executionContext{rc, e, 0, 0, make(chan graphql.DeferredResult)}
inputUnmarshalMap := graphql.BuildUnmarshalerMap(
{{- range $input := .Inputs -}}
{{ if not $input.HasUnmarshal }}
ec.unmarshalInput{{ $input.Name }},
{{- end }}
{{- end }}
)
first := true
switch rc.Operation.Operation {
{{- if .QueryRoot }} case ast.Query:
return func(ctx context.Context) *graphql.Response {
var response graphql.Response
var data graphql.Marshaler
if first {
first = false
ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap)
{{ if .Directives.LocationDirectives "QUERY" -}}
data = ec._queryMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
})
{{- else -}}
data = ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
} else {
if atomic.LoadInt32(&ec.pendingDeferred) > 0 {
result := <-ec.deferredResults
atomic.AddInt32(&ec.pendingDeferred, -1)
data = result.Result
response.Path = result.Path
response.Label = result.Label
response.Errors = result.Errors
} else {
return nil
}
}
var buf bytes.Buffer
data.MarshalGQL(&buf)
response.Data = buf.Bytes()
if atomic.LoadInt32(&ec.deferred) > 0 {
hasNext := atomic.LoadInt32(&ec.pendingDeferred) > 0
response.HasNext = &hasNext
}
return &response
}
{{ end }}
{{- if .MutationRoot }} case ast.Mutation:
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap)
{{ if .Directives.LocationDirectives "MUTATION" -}}
data := ec._mutationMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
})
{{- else -}}
data := ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
{{- if .SubscriptionRoot }} case ast.Subscription:
{{ if .Directives.LocationDirectives "SUBSCRIPTION" -}}
next := ec._subscriptionMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet),nil
})
{{- else -}}
next := ec._{{.SubscriptionRoot.Name}}(ctx, rc.Operation.SelectionSet)
{{- end }}
var buf bytes.Buffer
return func(ctx context.Context) *graphql.Response {
buf.Reset()
data := next(ctx)
if data == nil {
return nil
}
data.MarshalGQL(&buf)
return &graphql.Response{
Data: buf.Bytes(),
}
}
{{ end }}
default:
return graphql.OneShot(graphql.ErrorResponse(ctx, "unsupported GraphQL operation"))
}
}
type executionContext struct {
*graphql.OperationContext
*executableSchema
deferred int32
pendingDeferred int32
deferredResults chan graphql.DeferredResult
}
func (ec *executionContext) processDeferredGroup(dg graphql.DeferredGroup) {
atomic.AddInt32(&ec.pendingDeferred, 1)
go func () {
ctx := graphql.WithFreshResponseContext(dg.Context)
dg.FieldSet.Dispatch(ctx)
ds := graphql.DeferredResult{
Path: dg.Path,
Label: dg.Label,
Result: dg.FieldSet,
Errors: graphql.GetErrors(ctx),
}
// null fields should bubble up
if dg.FieldSet.Invalids > 0 {
ds.Result = graphql.Null
}
ec.deferredResults <- ds
}()
}
func (ec *executionContext) introspectSchema() (*introspection.Schema, error) {
if ec.DisableIntrospection {
return nil, errors.New("introspection disabled")
}
return introspection.WrapSchema(parsedSchema), nil
}
func (ec *executionContext) introspectType(name string) (*introspection.Type, error) {
if ec.DisableIntrospection {
return nil, errors.New("introspection disabled")
}
return introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name]), nil
}
{{if .HasEmbeddableSources }}
//go:embed{{- range $source := .AugmentedSources }}{{if $source.Embeddable}} {{$source.RelativePath|quote}}{{end}}{{- end }}
var sourcesFS embed.FS
func sourceData(filename string) string {
data, err := sourcesFS.ReadFile(filename)
if err != nil {
panic(fmt.Sprintf("codegen problem: %s not available", filename))
}
return string(data)
}
{{- end}}
var sources = []*ast.Source{
{{- range $source := .AugmentedSources }}
{Name: {{$source.RelativePath|quote}}, Input: {{if (not $source.Embeddable)}}{{$source.Source|rawQuote}}{{else}}sourceData({{$source.RelativePath|quote}}){{end}}, BuiltIn: {{$source.BuiltIn}}},
{{- end }}
}
var parsedSchema = gqlparser.MustLoadSchema(sources...)

View File

@@ -0,0 +1,139 @@
package templates
import (
"fmt"
"go/types"
"strconv"
"strings"
"github.com/99designs/gqlgen/internal/code"
)
type Import struct {
Name string
Path string
Alias string
}
type Imports struct {
imports []*Import
destDir string
packages *code.Packages
}
func (i *Import) String() string {
if strings.HasSuffix(i.Path, i.Alias) {
return strconv.Quote(i.Path)
}
return i.Alias + " " + strconv.Quote(i.Path)
}
func (s *Imports) String() string {
res := ""
for i, imp := range s.imports {
if i != 0 {
res += "\n"
}
res += imp.String()
}
return res
}
func (s *Imports) Reserve(path string, aliases ...string) (string, error) {
if path == "" {
panic("empty ambient import")
}
// if we are referencing our own package we don't need an import
if code.ImportPathForDir(s.destDir) == path {
return "", nil
}
name := s.packages.NameForPackage(path)
var alias string
if len(aliases) != 1 {
alias = name
} else {
alias = aliases[0]
}
if existing := s.findByPath(path); existing != nil {
if existing.Alias == alias {
return "", nil
}
return "", fmt.Errorf("ambient import already exists")
}
if alias := s.findByAlias(alias); alias != nil {
return "", fmt.Errorf("ambient import collides on an alias")
}
s.imports = append(s.imports, &Import{
Name: name,
Path: path,
Alias: alias,
})
return "", nil
}
func (s *Imports) Lookup(path string) string {
if path == "" {
return ""
}
path = code.NormalizeVendor(path)
// if we are referencing our own package we don't need an import
if code.ImportPathForDir(s.destDir) == path {
return ""
}
if existing := s.findByPath(path); existing != nil {
return existing.Alias
}
imp := &Import{
Name: s.packages.NameForPackage(path),
Path: path,
}
s.imports = append(s.imports, imp)
alias := imp.Name
i := 1
for s.findByAlias(alias) != nil {
alias = imp.Name + strconv.Itoa(i)
i++
if i > 1000 {
panic(fmt.Errorf("too many collisions, last attempt was %s", alias))
}
}
imp.Alias = alias
return imp.Alias
}
func (s *Imports) LookupType(t types.Type) string {
return types.TypeString(t, func(i *types.Package) string {
return s.Lookup(i.Path())
})
}
func (s Imports) findByPath(importPath string) *Import {
for _, imp := range s.imports {
if imp.Path == importPath {
return imp
}
}
return nil
}
func (s Imports) findByAlias(alias string) *Import {
for _, imp := range s.imports {
if imp.Alias == alias {
return imp
}
}
return nil
}

View File

@@ -0,0 +1,670 @@
package templates
import (
"bytes"
"fmt"
"go/types"
"io/fs"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"text/template"
"unicode"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/internal/code"
"github.com/99designs/gqlgen/internal/imports"
)
// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
// this is done with a global because subtemplates currently get called in functions. Lets aim to remove this eventually.
var CurrentImports *Imports
// Options specify various parameters to rendering a template.
type Options struct {
// PackageName is a helper that specifies the package header declaration.
// In other words, when you write the template you don't need to specify `package X`
// at the top of the file. By providing PackageName in the Options, the Render
// function will do that for you.
PackageName string
// Template is a string of the entire template that
// will be parsed and rendered. If it's empty,
// the plugin processor will look for .gotpl files
// in the same directory of where you wrote the plugin.
Template string
// Use the go:embed API to collect all the template files you want to pass into Render
// this is an alternative to passing the Template option
TemplateFS fs.FS
// Filename is the name of the file that will be
// written to the system disk once the template is rendered.
Filename string
RegionTags bool
GeneratedHeader bool
// PackageDoc is documentation written above the package line
PackageDoc string
// FileNotice is notice written below the package line
FileNotice string
// Data will be passed to the template execution.
Data interface{}
Funcs template.FuncMap
// Packages cache, you can find me on config.Config
Packages *code.Packages
}
var (
modelNamesMu sync.Mutex
modelNames = make(map[string]string, 0)
goNameRe = regexp.MustCompile("[^a-zA-Z0-9_]")
)
// Render renders a gql plugin template from the given Options. Render is an
// abstraction of the text/template package that makes it easier to write gqlgen
// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
// files inside the directory where you wrote the plugin.
func Render(cfg Options) error {
if CurrentImports != nil {
panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
}
CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)}
funcs := Funcs()
for n, f := range cfg.Funcs {
funcs[n] = f
}
t := template.New("").Funcs(funcs)
t, err := parseTemplates(cfg, t)
if err != nil {
return err
}
roots := make([]string, 0, len(t.Templates()))
for _, template := range t.Templates() {
// templates that end with _.gotpl are special files we don't want to include
if strings.HasSuffix(template.Name(), "_.gotpl") ||
// filter out templates added with {{ template xxx }} syntax inside the template file
!strings.HasSuffix(template.Name(), ".gotpl") {
continue
}
roots = append(roots, template.Name())
}
// then execute all the important looking ones in order, adding them to the same file
sort.Slice(roots, func(i, j int) bool {
// important files go first
if strings.HasSuffix(roots[i], "!.gotpl") {
return true
}
if strings.HasSuffix(roots[j], "!.gotpl") {
return false
}
return roots[i] < roots[j]
})
var buf bytes.Buffer
for _, root := range roots {
if cfg.RegionTags {
buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n")
}
err := t.Lookup(root).Execute(&buf, cfg.Data)
if err != nil {
return fmt.Errorf("%s: %w", root, err)
}
if cfg.RegionTags {
buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")
}
}
var result bytes.Buffer
if cfg.GeneratedHeader {
result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
}
if cfg.PackageDoc != "" {
result.WriteString(cfg.PackageDoc + "\n")
}
result.WriteString("package ")
result.WriteString(cfg.PackageName)
result.WriteString("\n\n")
if cfg.FileNotice != "" {
result.WriteString(cfg.FileNotice)
result.WriteString("\n\n")
}
result.WriteString("import (\n")
result.WriteString(CurrentImports.String())
result.WriteString(")\n")
_, err = buf.WriteTo(&result)
if err != nil {
return err
}
CurrentImports = nil
err = write(cfg.Filename, result.Bytes(), cfg.Packages)
if err != nil {
return err
}
cfg.Packages.Evict(code.ImportPathForDir(filepath.Dir(cfg.Filename)))
return nil
}
func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) {
if cfg.Template != "" {
var err error
t, err = t.New("template.gotpl").Parse(cfg.Template)
if err != nil {
return nil, fmt.Errorf("error with provided template: %w", err)
}
return t, nil
}
var fileSystem fs.FS
if cfg.TemplateFS != nil {
fileSystem = cfg.TemplateFS
} else {
// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
rootDir := filepath.Dir(callerFile)
fileSystem = os.DirFS(rootDir)
}
t, err := t.ParseFS(fileSystem, "*.gotpl")
if err != nil {
return nil, fmt.Errorf("locating templates: %w", err)
}
return t, nil
}
func center(width int, pad string, s string) string {
if len(s)+2 > width {
return s
}
lpad := (width - len(s)) / 2
rpad := width - (lpad + len(s))
return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad)
}
func Funcs() template.FuncMap {
return template.FuncMap{
"ucFirst": UcFirst,
"lcFirst": LcFirst,
"quote": strconv.Quote,
"rawQuote": rawQuote,
"dump": Dump,
"ref": ref,
"ts": config.TypeIdentifier,
"call": Call,
"prefixLines": prefixLines,
"notNil": notNil,
"reserveImport": CurrentImports.Reserve,
"lookupImport": CurrentImports.Lookup,
"go": ToGo,
"goPrivate": ToGoPrivate,
"goModelName": ToGoModelName,
"goPrivateModelName": ToGoPrivateModelName,
"add": func(a, b int) int {
return a + b
},
"render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
return render(resolveName(filename, 0), tpldata)
},
}
}
func UcFirst(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = unicode.ToUpper(r[0])
return string(r)
}
func LcFirst(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = unicode.ToLower(r[0])
return string(r)
}
func isDelimiter(c rune) bool {
return c == '-' || c == '_' || unicode.IsSpace(c)
}
func ref(p types.Type) string {
return CurrentImports.LookupType(p)
}
func Call(p *types.Func) string {
pkg := CurrentImports.Lookup(p.Pkg().Path())
if pkg != "" {
pkg += "."
}
if p.Type() != nil {
// make sure the returned type is listed in our imports.
ref(p.Type().(*types.Signature).Results().At(0).Type())
}
return pkg + p.Name()
}
func resetModelNames() {
modelNamesMu.Lock()
defer modelNamesMu.Unlock()
modelNames = make(map[string]string, 0)
}
func buildGoModelNameKey(parts []string) string {
const sep = ":"
return strings.Join(parts, sep)
}
func goModelName(primaryToGoFunc func(string) string, parts []string) string {
modelNamesMu.Lock()
defer modelNamesMu.Unlock()
var (
goNameKey string
partLen int
nameExists = func(n string) bool {
for _, v := range modelNames {
if n == v {
return true
}
}
return false
}
applyToGoFunc = func(parts []string) string {
var out string
switch len(parts) {
case 0:
return ""
case 1:
return primaryToGoFunc(parts[0])
default:
out = primaryToGoFunc(parts[0])
}
for _, p := range parts[1:] {
out = fmt.Sprintf("%s%s", out, ToGo(p))
}
return out
}
applyValidGoName = func(parts []string) string {
var out string
for _, p := range parts {
out = fmt.Sprintf("%s%s", out, replaceInvalidCharacters(p))
}
return out
}
)
// build key for this entity
goNameKey = buildGoModelNameKey(parts)
// determine if we've seen this entity before, and reuse if so
if goName, ok := modelNames[goNameKey]; ok {
return goName
}
// attempt first pass
if goName := applyToGoFunc(parts); !nameExists(goName) {
modelNames[goNameKey] = goName
return goName
}
// determine number of parts
partLen = len(parts)
// if there is only 1 part, append incrementing number until no conflict
if partLen == 1 {
base := applyToGoFunc(parts)
for i := 0; ; i++ {
tmp := fmt.Sprintf("%s%d", base, i)
if !nameExists(tmp) {
modelNames[goNameKey] = tmp
return tmp
}
}
}
// best effort "pretty" name
for i := partLen - 1; i >= 1; i-- {
tmp := fmt.Sprintf("%s%s", applyToGoFunc(parts[0:i]), applyValidGoName(parts[i:]))
if !nameExists(tmp) {
modelNames[goNameKey] = tmp
return tmp
}
}
// finally, fallback to just adding an incrementing number
base := applyToGoFunc(parts)
for i := 0; ; i++ {
tmp := fmt.Sprintf("%s%d", base, i)
if !nameExists(tmp) {
modelNames[goNameKey] = tmp
return tmp
}
}
}
func ToGoModelName(parts ...string) string {
return goModelName(ToGo, parts)
}
func ToGoPrivateModelName(parts ...string) string {
return goModelName(ToGoPrivate, parts)
}
func replaceInvalidCharacters(in string) string {
return goNameRe.ReplaceAllLiteralString(in, "_")
}
func wordWalkerFunc(private bool, nameRunes *[]rune) func(*wordInfo) {
return func(info *wordInfo) {
word := info.Word
switch {
case private && info.WordOffset == 0:
if strings.ToUpper(word) == word || strings.ToLower(word) == word {
// ID → id, CAMEL → camel
word = strings.ToLower(info.Word)
} else {
// ITicket → iTicket
word = LcFirst(info.Word)
}
case info.MatchCommonInitial:
word = strings.ToUpper(word)
case !info.HasCommonInitial && (strings.ToUpper(word) == word || strings.ToLower(word) == word):
// FOO or foo → Foo
// FOo → FOo
word = UcFirst(strings.ToLower(word))
}
*nameRunes = append(*nameRunes, []rune(word)...)
}
}
func ToGo(name string) string {
if name == "_" {
return "_"
}
runes := make([]rune, 0, len(name))
wordWalker(name, wordWalkerFunc(false, &runes))
return string(runes)
}
func ToGoPrivate(name string) string {
if name == "_" {
return "_"
}
runes := make([]rune, 0, len(name))
wordWalker(name, wordWalkerFunc(true, &runes))
return sanitizeKeywords(string(runes))
}
type wordInfo struct {
WordOffset int
Word string
MatchCommonInitial bool
HasCommonInitial bool
}
// This function is based on the following code.
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func wordWalker(str string, f func(*wordInfo)) {
runes := []rune(strings.TrimFunc(str, isDelimiter))
w, i, wo := 0, 0, 0 // index of start of word, scan, word offset
hasCommonInitial := false
for i+1 <= len(runes) {
eow := false // whether we hit the end of a word
switch {
case i+1 == len(runes):
eow = true
case isDelimiter(runes[i+1]):
// underscore; shift the remainder forward over any run of underscores
eow = true
n := 1
for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {
n++
}
// Leave at most one underscore if the underscore is between two digits
if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
n--
}
copy(runes[i+1:], runes[i+n+1:])
runes = runes[:len(runes)-n]
case unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]):
// lower->non-lower
eow = true
}
i++
initialisms := config.GetInitialisms()
// [w,i) is a word.
word := string(runes[w:i])
if !eow && initialisms[word] && !unicode.IsLower(runes[i]) {
// through
// split IDFoo → ID, Foo
// but URLs → URLs
} else if !eow {
if initialisms[word] {
hasCommonInitial = true
}
continue
}
matchCommonInitial := false
upperWord := strings.ToUpper(word)
if initialisms[upperWord] {
// If the uppercase word (string(runes[w:i]) is "ID" or "IP"
// AND
// the word is the first two characters of the str
// AND
// that is not the end of the word
// AND
// the length of the string is greater than 3
// AND
// the third rune is an uppercase one
// THEN
// do NOT count this as an initialism.
switch upperWord {
case "ID", "IP":
if word == str[:2] && !eow && len(str) > 3 && unicode.IsUpper(runes[3]) {
continue
}
}
hasCommonInitial = true
matchCommonInitial = true
}
f(&wordInfo{
WordOffset: wo,
Word: word,
MatchCommonInitial: matchCommonInitial,
HasCommonInitial: hasCommonInitial,
})
hasCommonInitial = false
w = i
wo++
}
}
var keywords = []string{
"break",
"default",
"func",
"interface",
"select",
"case",
"defer",
"go",
"map",
"struct",
"chan",
"else",
"goto",
"package",
"switch",
"const",
"fallthrough",
"if",
"range",
"type",
"continue",
"for",
"import",
"return",
"var",
"_",
}
// sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions
func sanitizeKeywords(name string) string {
for _, k := range keywords {
if name == k {
return name + "Arg"
}
}
return name
}
func rawQuote(s string) string {
return "`" + strings.ReplaceAll(s, "`", "`+\"`\"+`") + "`"
}
func notNil(field string, data interface{}) bool {
v := reflect.ValueOf(data)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return false
}
val := v.FieldByName(field)
return val.IsValid() && !val.IsNil()
}
func Dump(val interface{}) string {
switch val := val.(type) {
case int:
return strconv.Itoa(val)
case int64:
return fmt.Sprintf("%d", val)
case float64:
return fmt.Sprintf("%f", val)
case string:
return strconv.Quote(val)
case bool:
return strconv.FormatBool(val)
case nil:
return "nil"
case []interface{}:
var parts []string
for _, part := range val {
parts = append(parts, Dump(part))
}
return "[]interface{}{" + strings.Join(parts, ",") + "}"
case map[string]interface{}:
buf := bytes.Buffer{}
buf.WriteString("map[string]interface{}{")
var keys []string
for key := range val {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
data := val[key]
buf.WriteString(strconv.Quote(key))
buf.WriteString(":")
buf.WriteString(Dump(data))
buf.WriteString(",")
}
buf.WriteString("}")
return buf.String()
default:
panic(fmt.Errorf("unsupported type %T", val))
}
}
func prefixLines(prefix, s string) string {
return prefix + strings.ReplaceAll(s, "\n", "\n"+prefix)
}
func resolveName(name string, skip int) string {
if name[0] == '.' {
// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(skip + 1)
return filepath.Join(filepath.Dir(callerFile), name[1:])
}
// load path relative to this directory
_, callerFile, _, _ := runtime.Caller(0)
return filepath.Join(filepath.Dir(callerFile), name)
}
func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
t := template.New("").Funcs(Funcs())
b, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
t, err = t.New(filepath.Base(filename)).Parse(string(b))
if err != nil {
panic(err)
}
buf := &bytes.Buffer{}
return buf, t.Execute(buf, tpldata)
}
func write(filename string, b []byte, packages *code.Packages) error {
err := os.MkdirAll(filepath.Dir(filename), 0o755)
if err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
formatted, err := imports.Prune(filename, b, packages)
if err != nil {
fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
formatted = b
}
err = os.WriteFile(filename, formatted, 0o644)
if err != nil {
return fmt.Errorf("failed to write %s: %w", filename, err)
}
return nil
}

View File

@@ -0,0 +1 @@
this is my test package

View File

@@ -0,0 +1 @@
this will not be included

View File

@@ -0,0 +1,32 @@
package codegen
import (
"fmt"
"github.com/99designs/gqlgen/codegen/config"
)
func (b *builder) buildTypes() map[string]*config.TypeReference {
ret := map[string]*config.TypeReference{}
for _, ref := range b.Binder.References {
processType(ret, ref)
}
return ret
}
func processType(ret map[string]*config.TypeReference, ref *config.TypeReference) {
key := ref.UniquenessKey()
if existing, found := ret[key]; found {
// Simplistic check of content which is obviously different.
existingGQL := fmt.Sprintf("%v", existing.GQL)
newGQL := fmt.Sprintf("%v", ref.GQL)
if existingGQL != newGQL {
panic(fmt.Sprintf("non-unique key \"%s\", trying to replace %s with %s", key, existingGQL, newGQL))
}
}
ret[key] = ref
if ref.IsSlice() || ref.IsPtrToSlice() || ref.IsPtrToPtr() || ref.IsPtrToIntf() {
processType(ret, ref.Elem())
}
}

View File

@@ -0,0 +1,194 @@
{{- range $type := .ReferencedTypes }}
{{ with $type.UnmarshalFunc }}
func (ec *executionContext) {{ . }}(ctx context.Context, v interface{}) ({{ $type.GO | ref }}, error) {
{{- if and $type.IsNilable (not $type.GQL.NonNull) (not $type.IsPtrToPtr) }}
if v == nil { return nil, nil }
{{- end }}
{{- if or $type.IsPtrToSlice $type.IsPtrToIntf }}
res, err := ec.{{ $type.Elem.UnmarshalFunc }}(ctx, v)
return &res, graphql.ErrorOnPath(ctx, err)
{{- else if $type.IsSlice }}
var vSlice []interface{}
if v != nil {
vSlice = graphql.CoerceList(v)
}
var err error
res := make([]{{$type.GO.Elem | ref}}, len(vSlice))
for i := range vSlice {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i))
res[i], err = ec.{{ $type.Elem.UnmarshalFunc }}(ctx, vSlice[i])
if err != nil {
return nil, err
}
}
return res, nil
{{- else if and $type.IsPtrToPtr (not $type.Unmarshaler) (not $type.IsMarshaler) }}
var pres {{ $type.Elem.GO | ref }}
if v != nil {
res, err := ec.{{ $type.Elem.UnmarshalFunc }}(ctx, v)
if err != nil {
return nil, graphql.ErrorOnPath(ctx, err)
}
pres = res
}
return &pres, nil
{{- else }}
{{- if $type.Unmarshaler }}
{{- if $type.CastType }}
{{- if $type.IsContext }}
tmp, err := {{ $type.Unmarshaler | call }}(ctx, v)
{{- else }}
tmp, err := {{ $type.Unmarshaler | call }}(v)
{{- end }}
{{- if and $type.IsNilable $type.Elem }}
res := {{ $type.Elem.GO | ref }}(tmp)
{{- else}}
res := {{ $type.GO | ref }}(tmp)
{{- end }}
{{- else}}
{{- if $type.IsContext }}
res, err := {{ $type.Unmarshaler | call }}(ctx, v)
{{- else }}
res, err := {{ $type.Unmarshaler | call }}(v)
{{- end }}
{{- end }}
{{- if and $type.IsTargetNilable (not $type.IsNilable) }}
return *res, graphql.ErrorOnPath(ctx, err)
{{- else if and (not $type.IsTargetNilable) $type.IsNilable }}
return &res, graphql.ErrorOnPath(ctx, err)
{{- else}}
return res, graphql.ErrorOnPath(ctx, err)
{{- end }}
{{- else if eq ($type.GO | ref) "map[string]interface{}" }}
return v.(map[string]interface{}), nil
{{- else if $type.IsMarshaler }}
{{- if and $type.IsNilable $type.Elem }}
var res = new({{ $type.Elem.GO | ref }})
{{- else}}
var res {{ $type.GO | ref }}
{{- end }}
{{- if $type.IsContext }}
err := res.UnmarshalGQLContext(ctx, v)
{{- else }}
err := res.UnmarshalGQL(v)
{{- end }}
return res, graphql.ErrorOnPath(ctx, err)
{{- else }}
res, err := ec.unmarshalInput{{ $type.GQL.Name }}(ctx, v)
{{- if and $type.IsNilable (not $type.PointersInUmarshalInput) }}
return &res, graphql.ErrorOnPath(ctx, err)
{{- else if and (not $type.IsNilable) $type.PointersInUmarshalInput }}
return *res, graphql.ErrorOnPath(ctx, err)
{{- else }}
return res, graphql.ErrorOnPath(ctx, err)
{{- end }}
{{- end }}
{{- end }}
}
{{- end }}
{{ with $type.MarshalFunc }}
func (ec *executionContext) {{ . }}(ctx context.Context, sel ast.SelectionSet, v {{ $type.GO | ref }}) graphql.Marshaler {
{{- if or $type.IsPtrToSlice $type.IsPtrToIntf }}
return ec.{{ $type.Elem.MarshalFunc }}(ctx, sel, *v)
{{- else if $type.IsSlice }}
{{- if not $type.GQL.NonNull }}
if v == nil {
return graphql.Null
}
{{- end }}
ret := make(graphql.Array, len(v))
{{- if not $type.IsScalar }}
var wg sync.WaitGroup
isLen1 := len(v) == 1
if !isLen1 {
wg.Add(len(v))
}
{{- end }}
for i := range v {
{{- if not $type.IsScalar }}
i := i
fc := &graphql.FieldContext{
Index: &i,
Result: &v[i],
}
ctx := graphql.WithFieldContext(ctx, fc)
f := func(i int) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = nil
}
}()
if !isLen1 {
defer wg.Done()
}
ret[i] = ec.{{ $type.Elem.MarshalFunc }}(ctx, sel, v[i])
}
if isLen1 {
f(i)
} else {
go f(i)
}
{{ else }}
ret[i] = ec.{{ $type.Elem.MarshalFunc }}(ctx, sel, v[i])
{{- end }}
}
{{ if not $type.IsScalar }} wg.Wait() {{ end }}
{{ if $type.Elem.GQL.NonNull }}
for _, e := range ret {
if e == graphql.Null {
return graphql.Null
}
}
{{ end }}
return ret
{{- else if and $type.IsPtrToPtr (not $type.Unmarshaler) (not $type.IsMarshaler) }}
if v == nil {
return graphql.Null
}
return ec.{{ $type.Elem.MarshalFunc }}(ctx, sel, *v)
{{- else }}
{{- if $type.IsNilable }}
if v == nil {
{{- if $type.GQL.NonNull }}
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "the requested element is null which the schema does not allow")
}
{{- end }}
return graphql.Null
}
{{- end }}
{{- if $type.IsMarshaler }}
{{- if $type.IsContext }}
return graphql.WrapContextMarshaler(ctx, v)
{{- else }}
return v
{{- end }}
{{- else if $type.Marshaler }}
{{- $v := "v" }}
{{- if and $type.IsTargetNilable (not $type.IsNilable) }}
{{- $v = "&v" }}
{{- else if and (not $type.IsTargetNilable) $type.IsNilable }}
{{- $v = "*v" }}
{{- end }}
res := {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}({{ $v }}){{else}}{{ $v }}{{- end }})
{{- if $type.GQL.NonNull }}
if res == graphql.Null {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "the requested element is null which the schema does not allow")
}
}
{{- end }}
{{- if $type.IsContext }}
return graphql.WrapContextMarshaler(ctx, res)
{{- else }}
return res
{{- end }}
{{- else }}
return ec._{{$type.Definition.Name}}(ctx, sel, {{ if not $type.IsNilable}}&{{end}} v)
{{- end }}
{{- end }}
}
{{- end }}
{{- end }}

View File

@@ -0,0 +1,47 @@
package codegen
import (
"fmt"
"go/types"
"strings"
)
func findGoNamedType(def types.Type) (*types.Named, error) {
if def == nil {
return nil, nil
}
namedType, ok := def.(*types.Named)
if !ok {
return nil, fmt.Errorf("expected %s to be a named type, instead found %T\n", def.String(), def)
}
return namedType, nil
}
func findGoInterface(def types.Type) (*types.Interface, error) {
if def == nil {
return nil, nil
}
namedType, err := findGoNamedType(def)
if err != nil {
return nil, err
}
if namedType == nil {
return nil, nil
}
underlying, ok := namedType.Underlying().(*types.Interface)
if !ok {
return nil, fmt.Errorf("expected %s to be a named interface, instead found %s", def.String(), namedType.String())
}
return underlying, nil
}
func equalFieldName(source, target string) bool {
source = strings.ReplaceAll(source, "_", "")
source = strings.ReplaceAll(source, ",omitempty", "")
target = strings.ReplaceAll(target, "_", "")
return strings.EqualFold(source, target)
}

View File

@@ -0,0 +1,109 @@
package complexity
import (
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/ast"
)
func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
walker := complexityWalker{
es: es,
schema: es.Schema(),
vars: vars,
}
return walker.selectionSetComplexity(op.SelectionSet)
}
type complexityWalker struct {
es graphql.ExecutableSchema
schema *ast.Schema
vars map[string]interface{}
}
func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
var complexity int
for _, selection := range selectionSet {
switch s := selection.(type) {
case *ast.Field:
fieldDefinition := cw.schema.Types[s.Definition.Type.Name()]
if fieldDefinition.Name == "__Schema" {
continue
}
var childComplexity int
switch fieldDefinition.Kind {
case ast.Object, ast.Interface, ast.Union:
childComplexity = cw.selectionSetComplexity(s.SelectionSet)
}
args := s.ArgumentMap(cw.vars)
var fieldComplexity int
if s.ObjectDefinition.Kind == ast.Interface {
fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args)
} else {
fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args)
}
complexity = safeAdd(complexity, fieldComplexity)
case *ast.FragmentSpread:
complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet))
case *ast.InlineFragment:
complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet))
}
}
return complexity
}
func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
// Interfaces don't have their own separate field costs, so they have to assume the worst case.
// We iterate over all implementors and choose the most expensive one.
maxComplexity := 0
implementors := cw.schema.GetPossibleTypes(def)
for _, t := range implementors {
fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args)
if fieldComplexity > maxComplexity {
maxComplexity = fieldComplexity
}
}
return maxComplexity
}
func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity {
return customComplexity
}
// default complexity calculation
return safeAdd(1, childComplexity)
}
const maxInt = int(^uint(0) >> 1)
// safeAdd is a saturating add of a and b that ignores negative operands.
// If a + b would overflow through normal Go addition,
// it returns the maximum integer value instead.
//
// Adding complexities with this function prevents attackers from intentionally
// overflowing the complexity calculation to allow overly-complex queries.
//
// It also helps mitigate the impact of custom complexities that accidentally
// return negative values.
func safeAdd(a, b int) int {
// Ignore negative operands.
if a < 0 {
if b < 0 {
return 1
}
return b
} else if b < 0 {
return a
}
c := a + b
if c < a {
// Set c to maximum integer instead of overflowing.
c = maxInt
}
return c
}

View File

@@ -0,0 +1,19 @@
package graphql
import (
"encoding/json"
"io"
)
func MarshalAny(v interface{}) Marshaler {
return WriterFunc(func(w io.Writer) {
err := json.NewEncoder(w).Encode(v)
if err != nil {
panic(err)
}
})
}
func UnmarshalAny(v interface{}) (interface{}, error) {
return v, nil
}

View File

@@ -0,0 +1,27 @@
package graphql
import (
"fmt"
"io"
"strings"
)
func MarshalBoolean(b bool) Marshaler {
if b {
return WriterFunc(func(w io.Writer) { w.Write(trueLit) })
}
return WriterFunc(func(w io.Writer) { w.Write(falseLit) })
}
func UnmarshalBoolean(v interface{}) (bool, error) {
switch v := v.(type) {
case string:
return strings.ToLower(v) == "true", nil
case int:
return v != 0, nil
case bool:
return v, nil
default:
return false, fmt.Errorf("%T is not a bool", v)
}
}

View File

@@ -0,0 +1,29 @@
package graphql
import "context"
// Cache is a shared store for APQ and query AST caching
type Cache interface {
// Get looks up a key's value from the cache.
Get(ctx context.Context, key string) (value interface{}, ok bool)
// Add adds a value to the cache.
Add(ctx context.Context, key string, value interface{})
}
// MapCache is the simplest implementation of a cache, because it can not evict it should only be used in tests
type MapCache map[string]interface{}
// Get looks up a key's value from the cache.
func (m MapCache) Get(_ context.Context, key string) (value interface{}, ok bool) {
v, ok := m[key]
return v, ok
}
// Add adds a value to the cache.
func (m MapCache) Add(_ context.Context, key string, value interface{}) { m[key] = value }
type NoCache struct{}
func (n NoCache) Get(_ context.Context, _ string) (value interface{}, ok bool) { return nil, false }
func (n NoCache) Add(_ context.Context, _ string, _ interface{}) {}

View File

@@ -0,0 +1,56 @@
package graphql
import (
"encoding/json"
)
// CoerceList applies coercion from a single value to a list.
func CoerceList(v interface{}) []interface{} {
var vSlice []interface{}
if v != nil {
switch v := v.(type) {
case []interface{}:
// already a slice no coercion required
vSlice = v
case []string:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []json.Number:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []bool:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []map[string]interface{}:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []float64:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []float32:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []int:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []int32:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
case []int64:
if len(v) > 0 {
vSlice = []interface{}{v[0]}
}
default:
vSlice = []interface{}{v}
}
}
return vSlice
}

View File

@@ -0,0 +1,113 @@
package graphql
import (
"context"
"time"
"github.com/vektah/gqlparser/v2/ast"
)
type key string
const resolverCtx key = "resolver_context"
// Deprecated: Use FieldContext instead
type ResolverContext = FieldContext
type FieldContext struct {
Parent *FieldContext
// The name of the type this field belongs to
Object string
// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
Args map[string]interface{}
// The raw field
Field CollectedField
// The index of array in path.
Index *int
// The result object of resolver
Result interface{}
// IsMethod indicates if the resolver is a method
IsMethod bool
// IsResolver indicates if the field has a user-specified resolver
IsResolver bool
// Child allows getting a child FieldContext by its field collection description.
// Note that, the returned child FieldContext represents the context as it was
// before the execution of the field resolver. For example:
//
// srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (interface{}, error) {
// fc := graphql.GetFieldContext(ctx)
// op := graphql.GetOperationContext(ctx)
// collected := graphql.CollectFields(opCtx, fc.Field.Selections, []string{"User"})
//
// child, err := fc.Child(ctx, collected[0])
// if err != nil {
// return nil, err
// }
// fmt.Println("child context %q with args: %v", child.Field.Name, child.Args)
//
// return next(ctx)
// })
//
Child func(context.Context, CollectedField) (*FieldContext, error)
}
type FieldStats struct {
// When field execution started
Started time.Time
// When argument marshaling finished
ArgumentsCompleted time.Time
// When the field completed running all middleware. Not available inside field middleware!
Completed time.Time
}
func (r *FieldContext) Path() ast.Path {
var path ast.Path
for it := r; it != nil; it = it.Parent {
if it.Index != nil {
path = append(path, ast.PathIndex(*it.Index))
} else if it.Field.Field != nil {
path = append(path, ast.PathName(it.Field.Alias))
}
}
// because we are walking up the chain, all the elements are backwards, do an inplace flip.
for i := len(path)/2 - 1; i >= 0; i-- {
opp := len(path) - 1 - i
path[i], path[opp] = path[opp], path[i]
}
return path
}
// Deprecated: Use GetFieldContext instead
func GetResolverContext(ctx context.Context) *ResolverContext {
return GetFieldContext(ctx)
}
func GetFieldContext(ctx context.Context) *FieldContext {
if val, ok := ctx.Value(resolverCtx).(*FieldContext); ok {
return val
}
return nil
}
func WithFieldContext(ctx context.Context, rc *FieldContext) context.Context {
rc.Parent = GetFieldContext(ctx)
return context.WithValue(ctx, resolverCtx, rc)
}
func equalPath(a ast.Path, b ast.Path) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if a[i] != b[i] {
return false
}
}
return true
}

View File

@@ -0,0 +1,125 @@
package graphql
import (
"context"
"errors"
"net/http"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// Deprecated: Please update all references to OperationContext instead
type RequestContext = OperationContext
type OperationContext struct {
RawQuery string
Variables map[string]interface{}
OperationName string
Doc *ast.QueryDocument
Headers http.Header
Operation *ast.OperationDefinition
DisableIntrospection bool
RecoverFunc RecoverFunc
ResolverMiddleware FieldMiddleware
RootResolverMiddleware RootFieldMiddleware
Stats Stats
}
func (c *OperationContext) Validate(ctx context.Context) error {
if c.Doc == nil {
return errors.New("field 'Doc'is required")
}
if c.RawQuery == "" {
return errors.New("field 'RawQuery' is required")
}
if c.Variables == nil {
c.Variables = make(map[string]interface{})
}
if c.ResolverMiddleware == nil {
return errors.New("field 'ResolverMiddleware' is required")
}
if c.RootResolverMiddleware == nil {
return errors.New("field 'RootResolverMiddleware' is required")
}
if c.RecoverFunc == nil {
c.RecoverFunc = DefaultRecover
}
return nil
}
const operationCtx key = "operation_context"
// Deprecated: Please update all references to GetOperationContext instead
func GetRequestContext(ctx context.Context) *RequestContext {
return GetOperationContext(ctx)
}
func GetOperationContext(ctx context.Context) *OperationContext {
if val, ok := ctx.Value(operationCtx).(*OperationContext); ok && val != nil {
return val
}
panic("missing operation context")
}
func WithOperationContext(ctx context.Context, rc *OperationContext) context.Context {
return context.WithValue(ctx, operationCtx, rc)
}
// HasOperationContext checks if the given context is part of an ongoing operation
//
// Some errors can happen outside of an operation, eg json unmarshal errors.
func HasOperationContext(ctx context.Context) bool {
_, ok := ctx.Value(operationCtx).(*OperationContext)
return ok
}
// This is just a convenient wrapper method for CollectFields
func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
resctx := GetFieldContext(ctx)
return CollectFields(GetOperationContext(ctx), resctx.Field.Selections, satisfies)
}
// CollectAllFields returns a slice of all GraphQL field names that were selected for the current resolver context.
// The slice will contain the unique set of all field names requested regardless of fragment type conditions.
func CollectAllFields(ctx context.Context) []string {
resctx := GetFieldContext(ctx)
collected := CollectFields(GetOperationContext(ctx), resctx.Field.Selections, nil)
uniq := make([]string, 0, len(collected))
Next:
for _, f := range collected {
for _, name := range uniq {
if name == f.Name {
continue Next
}
}
uniq = append(uniq, f.Name)
}
return uniq
}
// Errorf sends an error string to the client, passing it through the formatter.
// Deprecated: use graphql.AddErrorf(ctx, err) instead
func (c *OperationContext) Errorf(ctx context.Context, format string, args ...interface{}) {
AddErrorf(ctx, format, args...)
}
// Error add error or multiple errors (if underlaying type is gqlerror.List) into the stack.
// Then it will be sends to the client, passing it through the formatter.
func (c *OperationContext) Error(ctx context.Context, err error) {
if errList, ok := err.(gqlerror.List); ok {
for _, e := range errList {
AddError(ctx, e)
}
return
}
AddError(ctx, err)
}
func (c *OperationContext) Recover(ctx context.Context, err interface{}) error {
return ErrorOnPath(ctx, c.RecoverFunc(ctx, err))
}

View File

@@ -0,0 +1,77 @@
package graphql
import (
"context"
"github.com/vektah/gqlparser/v2/ast"
)
const fieldInputCtx key = "path_context"
type PathContext struct {
ParentField *FieldContext
Parent *PathContext
Field *string
Index *int
}
func (fic *PathContext) Path() ast.Path {
var path ast.Path
for it := fic; it != nil; it = it.Parent {
if it.Index != nil {
path = append(path, ast.PathIndex(*it.Index))
} else if it.Field != nil {
path = append(path, ast.PathName(*it.Field))
}
}
// because we are walking up the chain, all the elements are backwards, do an inplace flip.
for i := len(path)/2 - 1; i >= 0; i-- {
opp := len(path) - 1 - i
path[i], path[opp] = path[opp], path[i]
}
if fic.ParentField != nil {
fieldPath := fic.ParentField.Path()
return append(fieldPath, path...)
}
return path
}
func NewPathWithField(field string) *PathContext {
return &PathContext{Field: &field}
}
func NewPathWithIndex(index int) *PathContext {
return &PathContext{Index: &index}
}
func WithPathContext(ctx context.Context, fic *PathContext) context.Context {
if fieldContext := GetFieldContext(ctx); fieldContext != nil {
fic.ParentField = fieldContext
}
if fieldInputContext := GetPathContext(ctx); fieldInputContext != nil {
fic.Parent = fieldInputContext
}
return context.WithValue(ctx, fieldInputCtx, fic)
}
func GetPathContext(ctx context.Context) *PathContext {
if val, ok := ctx.Value(fieldInputCtx).(*PathContext); ok {
return val
}
return nil
}
func GetPath(ctx context.Context) ast.Path {
if pc := GetPathContext(ctx); pc != nil {
return pc.Path()
}
if fc := GetFieldContext(ctx); fc != nil {
return fc.Path()
}
return nil
}

View File

@@ -0,0 +1,161 @@
package graphql
import (
"context"
"fmt"
"sync"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type responseContext struct {
errorPresenter ErrorPresenterFunc
recover RecoverFunc
errors gqlerror.List
errorsMu sync.Mutex
extensions map[string]interface{}
extensionsMu sync.Mutex
}
const resultCtx key = "result_context"
func getResponseContext(ctx context.Context) *responseContext {
val, ok := ctx.Value(resultCtx).(*responseContext)
if !ok {
panic("missing response context")
}
return val
}
func WithResponseContext(ctx context.Context, presenterFunc ErrorPresenterFunc, recoverFunc RecoverFunc) context.Context {
return context.WithValue(ctx, resultCtx, &responseContext{
errorPresenter: presenterFunc,
recover: recoverFunc,
})
}
func WithFreshResponseContext(ctx context.Context) context.Context {
e := getResponseContext(ctx)
return context.WithValue(ctx, resultCtx, &responseContext{
errorPresenter: e.errorPresenter,
recover: e.recover,
})
}
// AddErrorf writes a formatted error to the client, first passing it through the error presenter.
func AddErrorf(ctx context.Context, format string, args ...interface{}) {
AddError(ctx, fmt.Errorf(format, args...))
}
// AddError sends an error to the client, first passing it through the error presenter.
func AddError(ctx context.Context, err error) {
c := getResponseContext(ctx)
presentedError := c.errorPresenter(ctx, ErrorOnPath(ctx, err))
c.errorsMu.Lock()
defer c.errorsMu.Unlock()
c.errors = append(c.errors, presentedError)
}
func Recover(ctx context.Context, err interface{}) (userMessage error) {
c := getResponseContext(ctx)
return ErrorOnPath(ctx, c.recover(ctx, err))
}
// HasFieldError returns true if the given field has already errored
func HasFieldError(ctx context.Context, rctx *FieldContext) bool {
c := getResponseContext(ctx)
c.errorsMu.Lock()
defer c.errorsMu.Unlock()
if len(c.errors) == 0 {
return false
}
path := rctx.Path()
for _, err := range c.errors {
if equalPath(err.Path, path) {
return true
}
}
return false
}
// GetFieldErrors returns a list of errors that occurred in the given field
func GetFieldErrors(ctx context.Context, rctx *FieldContext) gqlerror.List {
c := getResponseContext(ctx)
c.errorsMu.Lock()
defer c.errorsMu.Unlock()
if len(c.errors) == 0 {
return nil
}
path := rctx.Path()
var errs gqlerror.List
for _, err := range c.errors {
if equalPath(err.Path, path) {
errs = append(errs, err)
}
}
return errs
}
func GetErrors(ctx context.Context) gqlerror.List {
resCtx := getResponseContext(ctx)
resCtx.errorsMu.Lock()
defer resCtx.errorsMu.Unlock()
if len(resCtx.errors) == 0 {
return nil
}
errs := resCtx.errors
cpy := make(gqlerror.List, len(errs))
for i := range errs {
errCpy := *errs[i]
cpy[i] = &errCpy
}
return cpy
}
// RegisterExtension allows you to add a new extension into the graphql response
func RegisterExtension(ctx context.Context, key string, value interface{}) {
c := getResponseContext(ctx)
c.extensionsMu.Lock()
defer c.extensionsMu.Unlock()
if c.extensions == nil {
c.extensions = make(map[string]interface{})
}
if _, ok := c.extensions[key]; ok {
panic(fmt.Errorf("extension already registered for key %s", key))
}
c.extensions[key] = value
}
// GetExtensions returns any extensions registered in the current result context
func GetExtensions(ctx context.Context) map[string]interface{} {
ext := getResponseContext(ctx).extensions
if ext == nil {
return map[string]interface{}{}
}
return ext
}
func GetExtension(ctx context.Context, name string) interface{} {
ext := getResponseContext(ctx).extensions
if ext == nil {
return nil
}
return ext[name]
}

View File

@@ -0,0 +1,25 @@
package graphql
import (
"context"
)
const rootResolverCtx key = "root_resolver_context"
type RootFieldContext struct {
// The name of the type this field belongs to
Object string
// The raw field
Field CollectedField
}
func GetRootFieldContext(ctx context.Context) *RootFieldContext {
if val, ok := ctx.Value(rootResolverCtx).(*RootFieldContext); ok {
return val
}
return nil
}
func WithRootFieldContext(ctx context.Context, rc *RootFieldContext) context.Context {
return context.WithValue(ctx, rootResolverCtx, rc)
}

View File

@@ -0,0 +1,26 @@
package graphql
import (
"context"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type Deferrable struct {
Label string
}
type DeferredGroup struct {
Path ast.Path
Label string
FieldSet *FieldSet
Context context.Context
}
type DeferredResult struct {
Path ast.Path
Label string
Result Marshaler
Errors gqlerror.List
}

View File

@@ -0,0 +1,60 @@
package errcode
import (
"github.com/vektah/gqlparser/v2/gqlerror"
)
const (
ValidationFailed = "GRAPHQL_VALIDATION_FAILED"
ParseFailed = "GRAPHQL_PARSE_FAILED"
)
type ErrorKind int
const (
// issues with graphql (validation, parsing). 422s in http, GQL_ERROR in websocket
KindProtocol ErrorKind = iota
// user errors, 200s in http, GQL_DATA in websocket
KindUser
)
var codeType = map[string]ErrorKind{
ValidationFailed: KindProtocol,
ParseFailed: KindProtocol,
}
// RegisterErrorType should be called by extensions that want to customize the http status codes for
// errors they return
func RegisterErrorType(code string, kind ErrorKind) {
codeType[code] = kind
}
// Set the error code on a given graphql error extension
func Set(err error, value string) {
if err == nil {
return
}
gqlErr, ok := err.(*gqlerror.Error)
if !ok {
return
}
if gqlErr.Extensions == nil {
gqlErr.Extensions = map[string]interface{}{}
}
gqlErr.Extensions["code"] = value
}
// get the kind of the first non User error, defaults to User if no errors have a custom extension
func GetErrorKind(errs gqlerror.List) ErrorKind {
for _, err := range errs {
if code, ok := err.Extensions["code"].(string); ok {
if kind, ok := codeType[code]; ok && kind != KindUser {
return kind
}
}
}
return KindUser
}

View File

@@ -0,0 +1,33 @@
package graphql
import (
"context"
"errors"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type ErrorPresenterFunc func(ctx context.Context, err error) *gqlerror.Error
func DefaultErrorPresenter(ctx context.Context, err error) *gqlerror.Error {
var gqlErr *gqlerror.Error
if errors.As(err, &gqlErr) {
return gqlErr
}
return gqlerror.WrapPath(GetPath(ctx), err)
}
func ErrorOnPath(ctx context.Context, err error) error {
if err == nil {
return nil
}
var gqlErr *gqlerror.Error
if errors.As(err, &gqlErr) {
if gqlErr.Path == nil {
gqlErr.Path = GetPath(ctx)
}
// Return the original error to avoid losing any attached annotation
return err
}
return gqlerror.WrapPath(GetPath(ctx), err)
}

View File

@@ -0,0 +1,211 @@
//go:generate go run github.com/matryer/moq -out executable_schema_mock.go . ExecutableSchema
package graphql
import (
"context"
"fmt"
"github.com/vektah/gqlparser/v2/ast"
)
type ExecutableSchema interface {
Schema() *ast.Schema
Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
Exec(ctx context.Context) ResponseHandler
}
// CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types
// passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment
// type conditions.
func CollectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string) []CollectedField {
return collectFields(reqCtx, selSet, satisfies, map[string]bool{})
}
func collectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField {
groupedFields := make([]CollectedField, 0, len(selSet))
for _, sel := range selSet {
switch sel := sel.(type) {
case *ast.Field:
if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
continue
}
f := getOrCreateAndAppendField(&groupedFields, sel.Name, sel.Alias, sel.ObjectDefinition, func() CollectedField {
return CollectedField{Field: sel}
})
f.Selections = append(f.Selections, sel.SelectionSet...)
case *ast.InlineFragment:
if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
continue
}
if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) {
continue
}
shouldDefer, label := deferrable(sel.Directives, reqCtx.Variables)
for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) {
f := getOrCreateAndAppendField(
&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition,
func() CollectedField { return childField })
f.Selections = append(f.Selections, childField.Selections...)
if shouldDefer {
f.Deferrable = &Deferrable{
Label: label,
}
}
}
case *ast.FragmentSpread:
if !shouldIncludeNode(sel.Directives, reqCtx.Variables) {
continue
}
fragmentName := sel.Name
if _, seen := visited[fragmentName]; seen {
continue
}
visited[fragmentName] = true
fragment := reqCtx.Doc.Fragments.ForName(fragmentName)
if fragment == nil {
// should never happen, validator has already run
panic(fmt.Errorf("missing fragment %s", fragmentName))
}
if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) {
continue
}
shouldDefer, label := deferrable(sel.Directives, reqCtx.Variables)
for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) {
f := getOrCreateAndAppendField(&groupedFields,
childField.Name, childField.Alias, childField.ObjectDefinition,
func() CollectedField { return childField })
f.Selections = append(f.Selections, childField.Selections...)
if shouldDefer {
f.Deferrable = &Deferrable{Label: label}
}
}
default:
panic(fmt.Errorf("unsupported %T", sel))
}
}
return groupedFields
}
type CollectedField struct {
*ast.Field
Selections ast.SelectionSet
Deferrable *Deferrable
}
func instanceOf(val string, satisfies []string) bool {
for _, s := range satisfies {
if val == s {
return true
}
}
return false
}
func getOrCreateAndAppendField(c *[]CollectedField, name string, alias string, objectDefinition *ast.Definition, creator func() CollectedField) *CollectedField {
for i, cf := range *c {
if cf.Name == name && cf.Alias == alias {
if cf.ObjectDefinition == objectDefinition {
return &(*c)[i]
}
if cf.ObjectDefinition == nil || objectDefinition == nil {
continue
}
if cf.ObjectDefinition.Name == objectDefinition.Name {
return &(*c)[i]
}
for _, ifc := range objectDefinition.Interfaces {
if ifc == cf.ObjectDefinition.Name {
return &(*c)[i]
}
}
for _, ifc := range cf.ObjectDefinition.Interfaces {
if ifc == objectDefinition.Name {
return &(*c)[i]
}
}
}
}
f := creator()
*c = append(*c, f)
return &(*c)[len(*c)-1]
}
func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool {
if len(directives) == 0 {
return true
}
skip, include := false, true
if d := directives.ForName("skip"); d != nil {
skip = resolveIfArgument(d, variables)
}
if d := directives.ForName("include"); d != nil {
include = resolveIfArgument(d, variables)
}
return !skip && include
}
func deferrable(directives ast.DirectiveList, variables map[string]interface{}) (shouldDefer bool, label string) {
d := directives.ForName("defer")
if d == nil {
return false, ""
}
shouldDefer = true
for _, arg := range d.Arguments {
switch arg.Name {
case "if":
if value, err := arg.Value.Value(variables); err == nil {
shouldDefer, _ = value.(bool)
}
case "label":
if value, err := arg.Value.Value(variables); err == nil {
label, _ = value.(string)
}
default:
panic(fmt.Sprintf("defer: argument '%s' not supported", arg.Name))
}
}
return shouldDefer, label
}
func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool {
arg := d.Arguments.ForName("if")
if arg == nil {
panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name))
}
value, err := arg.Value.Value(variables)
if err != nil {
panic(err)
}
ret, ok := value.(bool)
if !ok {
panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name))
}
return ret
}

View File

@@ -0,0 +1,175 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package graphql
import (
"context"
"github.com/vektah/gqlparser/v2/ast"
"sync"
)
// Ensure, that ExecutableSchemaMock does implement ExecutableSchema.
// If this is not the case, regenerate this file with moq.
var _ ExecutableSchema = &ExecutableSchemaMock{}
// ExecutableSchemaMock is a mock implementation of ExecutableSchema.
//
// func TestSomethingThatUsesExecutableSchema(t *testing.T) {
//
// // make and configure a mocked ExecutableSchema
// mockedExecutableSchema := &ExecutableSchemaMock{
// ComplexityFunc: func(typeName string, fieldName string, childComplexity int, args map[string]interface{}) (int, bool) {
// panic("mock out the Complexity method")
// },
// ExecFunc: func(ctx context.Context) ResponseHandler {
// panic("mock out the Exec method")
// },
// SchemaFunc: func() *ast.Schema {
// panic("mock out the Schema method")
// },
// }
//
// // use mockedExecutableSchema in code that requires ExecutableSchema
// // and then make assertions.
//
// }
type ExecutableSchemaMock struct {
// ComplexityFunc mocks the Complexity method.
ComplexityFunc func(typeName string, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
// ExecFunc mocks the Exec method.
ExecFunc func(ctx context.Context) ResponseHandler
// SchemaFunc mocks the Schema method.
SchemaFunc func() *ast.Schema
// calls tracks calls to the methods.
calls struct {
// Complexity holds details about calls to the Complexity method.
Complexity []struct {
// TypeName is the typeName argument value.
TypeName string
// FieldName is the fieldName argument value.
FieldName string
// ChildComplexity is the childComplexity argument value.
ChildComplexity int
// Args is the args argument value.
Args map[string]interface{}
}
// Exec holds details about calls to the Exec method.
Exec []struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// Schema holds details about calls to the Schema method.
Schema []struct {
}
}
lockComplexity sync.RWMutex
lockExec sync.RWMutex
lockSchema sync.RWMutex
}
// Complexity calls ComplexityFunc.
func (mock *ExecutableSchemaMock) Complexity(typeName string, fieldName string, childComplexity int, args map[string]interface{}) (int, bool) {
if mock.ComplexityFunc == nil {
panic("ExecutableSchemaMock.ComplexityFunc: method is nil but ExecutableSchema.Complexity was just called")
}
callInfo := struct {
TypeName string
FieldName string
ChildComplexity int
Args map[string]interface{}
}{
TypeName: typeName,
FieldName: fieldName,
ChildComplexity: childComplexity,
Args: args,
}
mock.lockComplexity.Lock()
mock.calls.Complexity = append(mock.calls.Complexity, callInfo)
mock.lockComplexity.Unlock()
return mock.ComplexityFunc(typeName, fieldName, childComplexity, args)
}
// ComplexityCalls gets all the calls that were made to Complexity.
// Check the length with:
//
// len(mockedExecutableSchema.ComplexityCalls())
func (mock *ExecutableSchemaMock) ComplexityCalls() []struct {
TypeName string
FieldName string
ChildComplexity int
Args map[string]interface{}
} {
var calls []struct {
TypeName string
FieldName string
ChildComplexity int
Args map[string]interface{}
}
mock.lockComplexity.RLock()
calls = mock.calls.Complexity
mock.lockComplexity.RUnlock()
return calls
}
// Exec calls ExecFunc.
func (mock *ExecutableSchemaMock) Exec(ctx context.Context) ResponseHandler {
if mock.ExecFunc == nil {
panic("ExecutableSchemaMock.ExecFunc: method is nil but ExecutableSchema.Exec was just called")
}
callInfo := struct {
Ctx context.Context
}{
Ctx: ctx,
}
mock.lockExec.Lock()
mock.calls.Exec = append(mock.calls.Exec, callInfo)
mock.lockExec.Unlock()
return mock.ExecFunc(ctx)
}
// ExecCalls gets all the calls that were made to Exec.
// Check the length with:
//
// len(mockedExecutableSchema.ExecCalls())
func (mock *ExecutableSchemaMock) ExecCalls() []struct {
Ctx context.Context
} {
var calls []struct {
Ctx context.Context
}
mock.lockExec.RLock()
calls = mock.calls.Exec
mock.lockExec.RUnlock()
return calls
}
// Schema calls SchemaFunc.
func (mock *ExecutableSchemaMock) Schema() *ast.Schema {
if mock.SchemaFunc == nil {
panic("ExecutableSchemaMock.SchemaFunc: method is nil but ExecutableSchema.Schema was just called")
}
callInfo := struct {
}{}
mock.lockSchema.Lock()
mock.calls.Schema = append(mock.calls.Schema, callInfo)
mock.lockSchema.Unlock()
return mock.SchemaFunc()
}
// SchemaCalls gets all the calls that were made to Schema.
// Check the length with:
//
// len(mockedExecutableSchema.SchemaCalls())
func (mock *ExecutableSchemaMock) SchemaCalls() []struct {
} {
var calls []struct {
}
mock.lockSchema.RLock()
calls = mock.calls.Schema
mock.lockSchema.RUnlock()
return calls
}

View File

@@ -0,0 +1,221 @@
package executor
import (
"context"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/vektah/gqlparser/v2/parser"
"github.com/vektah/gqlparser/v2/validator"
)
// Executor executes graphql queries against a schema.
type Executor struct {
es graphql.ExecutableSchema
extensions []graphql.HandlerExtension
ext extensions
errorPresenter graphql.ErrorPresenterFunc
recoverFunc graphql.RecoverFunc
queryCache graphql.Cache
}
var _ graphql.GraphExecutor = &Executor{}
// New creates a new Executor with the given schema, and a default error and
// recovery callbacks, and no query cache or extensions.
func New(es graphql.ExecutableSchema) *Executor {
e := &Executor{
es: es,
errorPresenter: graphql.DefaultErrorPresenter,
recoverFunc: graphql.DefaultRecover,
queryCache: graphql.NoCache{},
ext: processExtensions(nil),
}
return e
}
func (e *Executor) CreateOperationContext(
ctx context.Context,
params *graphql.RawParams,
) (*graphql.OperationContext, gqlerror.List) {
rc := &graphql.OperationContext{
DisableIntrospection: true,
RecoverFunc: e.recoverFunc,
ResolverMiddleware: e.ext.fieldMiddleware,
RootResolverMiddleware: e.ext.rootFieldMiddleware,
Stats: graphql.Stats{
Read: params.ReadTime,
OperationStart: graphql.GetStartTime(ctx),
},
}
ctx = graphql.WithOperationContext(ctx, rc)
for _, p := range e.ext.operationParameterMutators {
if err := p.MutateOperationParameters(ctx, params); err != nil {
return rc, gqlerror.List{err}
}
}
rc.RawQuery = params.Query
rc.OperationName = params.OperationName
rc.Headers = params.Headers
var listErr gqlerror.List
rc.Doc, listErr = e.parseQuery(ctx, &rc.Stats, params.Query)
if len(listErr) != 0 {
return rc, listErr
}
rc.Operation = rc.Doc.Operations.ForName(params.OperationName)
if rc.Operation == nil {
err := gqlerror.Errorf("operation %s not found", params.OperationName)
errcode.Set(err, errcode.ValidationFailed)
return rc, gqlerror.List{err}
}
var err error
rc.Variables, err = validator.VariableValues(e.es.Schema(), rc.Operation, params.Variables)
if err != nil {
gqlErr, ok := err.(*gqlerror.Error)
if ok {
errcode.Set(gqlErr, errcode.ValidationFailed)
return rc, gqlerror.List{gqlErr}
}
}
rc.Stats.Validation.End = graphql.Now()
for _, p := range e.ext.operationContextMutators {
if err := p.MutateOperationContext(ctx, rc); err != nil {
return rc, gqlerror.List{err}
}
}
return rc, nil
}
func (e *Executor) DispatchOperation(
ctx context.Context,
rc *graphql.OperationContext,
) (graphql.ResponseHandler, context.Context) {
ctx = graphql.WithOperationContext(ctx, rc)
var innerCtx context.Context
res := e.ext.operationMiddleware(ctx, func(ctx context.Context) graphql.ResponseHandler {
innerCtx = ctx
tmpResponseContext := graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
responses := e.es.Exec(tmpResponseContext)
if errs := graphql.GetErrors(tmpResponseContext); errs != nil {
return graphql.OneShot(&graphql.Response{Errors: errs})
}
return func(ctx context.Context) *graphql.Response {
ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := responses(ctx)
if resp == nil {
return nil
}
resp.Errors = append(resp.Errors, graphql.GetErrors(ctx)...)
resp.Extensions = graphql.GetExtensions(ctx)
return resp
})
if resp == nil {
return nil
}
return resp
}
})
return res, innerCtx
}
func (e *Executor) DispatchError(ctx context.Context, list gqlerror.List) *graphql.Response {
ctx = graphql.WithResponseContext(ctx, e.errorPresenter, e.recoverFunc)
for _, gErr := range list {
graphql.AddError(ctx, gErr)
}
resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := &graphql.Response{
Errors: graphql.GetErrors(ctx),
}
resp.Extensions = graphql.GetExtensions(ctx)
return resp
})
return resp
}
func (e *Executor) PresentRecoveredError(ctx context.Context, err interface{}) error {
return e.errorPresenter(ctx, e.recoverFunc(ctx, err))
}
func (e *Executor) SetQueryCache(cache graphql.Cache) {
e.queryCache = cache
}
func (e *Executor) SetErrorPresenter(f graphql.ErrorPresenterFunc) {
e.errorPresenter = f
}
func (e *Executor) SetRecoverFunc(f graphql.RecoverFunc) {
e.recoverFunc = f
}
// parseQuery decodes the incoming query and validates it, pulling from cache if present.
//
// NOTE: This should NOT look at variables, they will change per request. It should only parse and
// validate
// the raw query string.
func (e *Executor) parseQuery(
ctx context.Context,
stats *graphql.Stats,
query string,
) (*ast.QueryDocument, gqlerror.List) {
stats.Parsing.Start = graphql.Now()
if doc, ok := e.queryCache.Get(ctx, query); ok {
now := graphql.Now()
stats.Parsing.End = now
stats.Validation.Start = now
return doc.(*ast.QueryDocument), nil
}
doc, err := parser.ParseQuery(&ast.Source{Input: query})
if err != nil {
gqlErr, ok := err.(*gqlerror.Error)
if ok {
errcode.Set(gqlErr, errcode.ParseFailed)
return nil, gqlerror.List{gqlErr}
}
}
stats.Parsing.End = graphql.Now()
stats.Validation.Start = graphql.Now()
if len(doc.Operations) == 0 {
err = gqlerror.Errorf("no operation provided")
gqlErr, _ := err.(*gqlerror.Error)
errcode.Set(err, errcode.ValidationFailed)
return nil, gqlerror.List{gqlErr}
}
listErr := validator.Validate(e.es.Schema(), doc)
if len(listErr) != 0 {
for _, e := range listErr {
errcode.Set(e, errcode.ValidationFailed)
}
return nil, listErr
}
e.queryCache.Add(ctx, query, doc)
return doc, nil
}

View File

@@ -0,0 +1,195 @@
package executor
import (
"context"
"fmt"
"github.com/99designs/gqlgen/graphql"
)
// Use adds the given extension to this Executor.
func (e *Executor) Use(extension graphql.HandlerExtension) {
if err := extension.Validate(e.es); err != nil {
panic(err)
}
switch extension.(type) {
case graphql.OperationParameterMutator,
graphql.OperationContextMutator,
graphql.OperationInterceptor,
graphql.RootFieldInterceptor,
graphql.FieldInterceptor,
graphql.ResponseInterceptor:
e.extensions = append(e.extensions, extension)
e.ext = processExtensions(e.extensions)
default:
panic(fmt.Errorf("cannot Use %T as a gqlgen handler extension because it does not implement any extension hooks", extension))
}
}
// AroundFields is a convenience method for creating an extension that only implements field middleware
func (e *Executor) AroundFields(f graphql.FieldMiddleware) {
e.Use(aroundFieldFunc(f))
}
// AroundRootFields is a convenience method for creating an extension that only implements root field middleware
func (e *Executor) AroundRootFields(f graphql.RootFieldMiddleware) {
e.Use(aroundRootFieldFunc(f))
}
// AroundOperations is a convenience method for creating an extension that only implements operation middleware
func (e *Executor) AroundOperations(f graphql.OperationMiddleware) {
e.Use(aroundOpFunc(f))
}
// AroundResponses is a convenience method for creating an extension that only implements response middleware
func (e *Executor) AroundResponses(f graphql.ResponseMiddleware) {
e.Use(aroundRespFunc(f))
}
type extensions struct {
operationMiddleware graphql.OperationMiddleware
responseMiddleware graphql.ResponseMiddleware
rootFieldMiddleware graphql.RootFieldMiddleware
fieldMiddleware graphql.FieldMiddleware
operationParameterMutators []graphql.OperationParameterMutator
operationContextMutators []graphql.OperationContextMutator
}
func processExtensions(exts []graphql.HandlerExtension) extensions {
e := extensions{
operationMiddleware: func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return next(ctx)
},
responseMiddleware: func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return next(ctx)
},
rootFieldMiddleware: func(ctx context.Context, next graphql.RootResolver) graphql.Marshaler {
return next(ctx)
},
fieldMiddleware: func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return next(ctx)
},
}
// this loop goes backwards so the first extension is the outer most middleware and runs first.
for i := len(exts) - 1; i >= 0; i-- {
p := exts[i]
if p, ok := p.(graphql.OperationInterceptor); ok {
previous := e.operationMiddleware
e.operationMiddleware = func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return p.InterceptOperation(ctx, func(ctx context.Context) graphql.ResponseHandler {
return previous(ctx, next)
})
}
}
if p, ok := p.(graphql.ResponseInterceptor); ok {
previous := e.responseMiddleware
e.responseMiddleware = func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return p.InterceptResponse(ctx, func(ctx context.Context) *graphql.Response {
return previous(ctx, next)
})
}
}
if p, ok := p.(graphql.RootFieldInterceptor); ok {
previous := e.rootFieldMiddleware
e.rootFieldMiddleware = func(ctx context.Context, next graphql.RootResolver) graphql.Marshaler {
return p.InterceptRootField(ctx, func(ctx context.Context) graphql.Marshaler {
return previous(ctx, next)
})
}
}
if p, ok := p.(graphql.FieldInterceptor); ok {
previous := e.fieldMiddleware
e.fieldMiddleware = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return p.InterceptField(ctx, func(ctx context.Context) (res interface{}, err error) {
return previous(ctx, next)
})
}
}
}
for _, p := range exts {
if p, ok := p.(graphql.OperationParameterMutator); ok {
e.operationParameterMutators = append(e.operationParameterMutators, p)
}
if p, ok := p.(graphql.OperationContextMutator); ok {
e.operationContextMutators = append(e.operationContextMutators, p)
}
}
return e
}
type aroundOpFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler
func (r aroundOpFunc) ExtensionName() string {
return "InlineOperationFunc"
}
func (r aroundOpFunc) Validate(schema graphql.ExecutableSchema) error {
if r == nil {
return fmt.Errorf("OperationFunc can not be nil")
}
return nil
}
func (r aroundOpFunc) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return r(ctx, next)
}
type aroundRespFunc func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response
func (r aroundRespFunc) ExtensionName() string {
return "InlineResponseFunc"
}
func (r aroundRespFunc) Validate(schema graphql.ExecutableSchema) error {
if r == nil {
return fmt.Errorf("ResponseFunc can not be nil")
}
return nil
}
func (r aroundRespFunc) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return r(ctx, next)
}
type aroundFieldFunc func(ctx context.Context, next graphql.Resolver) (res interface{}, err error)
func (f aroundFieldFunc) ExtensionName() string {
return "InlineFieldFunc"
}
func (f aroundFieldFunc) Validate(schema graphql.ExecutableSchema) error {
if f == nil {
return fmt.Errorf("FieldFunc can not be nil")
}
return nil
}
func (f aroundFieldFunc) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return f(ctx, next)
}
type aroundRootFieldFunc func(ctx context.Context, next graphql.RootResolver) graphql.Marshaler
func (f aroundRootFieldFunc) ExtensionName() string {
return "InlineRootFieldFunc"
}
func (f aroundRootFieldFunc) Validate(schema graphql.ExecutableSchema) error {
if f == nil {
return fmt.Errorf("RootFieldFunc can not be nil")
}
return nil
}
func (f aroundRootFieldFunc) InterceptRootField(ctx context.Context, next graphql.RootResolver) graphql.Marshaler {
return f(ctx, next)
}

View File

@@ -0,0 +1,70 @@
package graphql
import (
"context"
"io"
"sync"
)
type FieldSet struct {
fields []CollectedField
Values []Marshaler
Invalids uint32
delayed []delayedResult
}
type delayedResult struct {
i int
f func(context.Context) Marshaler
}
func NewFieldSet(fields []CollectedField) *FieldSet {
return &FieldSet{
fields: fields,
Values: make([]Marshaler, len(fields)),
}
}
func (m *FieldSet) AddField(field CollectedField) {
m.fields = append(m.fields, field)
m.Values = append(m.Values, nil)
}
func (m *FieldSet) Concurrently(i int, f func(context.Context) Marshaler) {
m.delayed = append(m.delayed, delayedResult{i: i, f: f})
}
func (m *FieldSet) Dispatch(ctx context.Context) {
if len(m.delayed) == 1 {
// only one concurrent task, no need to spawn a goroutine or deal create waitgroups
d := m.delayed[0]
m.Values[d.i] = d.f(ctx)
} else if len(m.delayed) > 1 {
// more than one concurrent task, use the main goroutine to do one, only spawn goroutines for the others
var wg sync.WaitGroup
for _, d := range m.delayed[1:] {
wg.Add(1)
go func(d delayedResult) {
m.Values[d.i] = d.f(ctx)
wg.Done()
}(d)
}
m.Values[m.delayed[0].i] = m.delayed[0].f(ctx)
wg.Wait()
}
}
func (m *FieldSet) MarshalGQL(writer io.Writer) {
writer.Write(openBrace)
for i, field := range m.fields {
if i != 0 {
writer.Write(comma)
}
writeQuotedString(writer, field.Alias)
writer.Write(colon)
m.Values[i].MarshalGQL(writer)
}
writer.Write(closeBrace)
}

View File

@@ -0,0 +1,47 @@
package graphql
import (
"context"
"encoding/json"
"fmt"
"io"
"math"
"strconv"
)
func MarshalFloat(f float64) Marshaler {
return WriterFunc(func(w io.Writer) {
io.WriteString(w, fmt.Sprintf("%g", f))
})
}
func UnmarshalFloat(v interface{}) (float64, error) {
switch v := v.(type) {
case string:
return strconv.ParseFloat(v, 64)
case int:
return float64(v), nil
case int64:
return float64(v), nil
case float64:
return v, nil
case json.Number:
return strconv.ParseFloat(string(v), 64)
default:
return 0, fmt.Errorf("%T is not an float", v)
}
}
func MarshalFloatContext(f float64) ContextMarshaler {
return ContextWriterFunc(func(ctx context.Context, w io.Writer) error {
if math.IsInf(f, 0) || math.IsNaN(f) {
return fmt.Errorf("cannot marshal infinite no NaN float values")
}
io.WriteString(w, fmt.Sprintf("%g", f))
return nil
})
}
func UnmarshalFloatContext(ctx context.Context, v interface{}) (float64, error) {
return UnmarshalFloat(v)
}

View File

@@ -0,0 +1,131 @@
package graphql
import (
"context"
"net/http"
"strconv"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type (
OperationMiddleware func(ctx context.Context, next OperationHandler) ResponseHandler
OperationHandler func(ctx context.Context) ResponseHandler
ResponseHandler func(ctx context.Context) *Response
ResponseMiddleware func(ctx context.Context, next ResponseHandler) *Response
Resolver func(ctx context.Context) (res interface{}, err error)
FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
RootResolver func(ctx context.Context) Marshaler
RootFieldMiddleware func(ctx context.Context, next RootResolver) Marshaler
RawParams struct {
Query string `json:"query"`
OperationName string `json:"operationName"`
Variables map[string]interface{} `json:"variables"`
Extensions map[string]interface{} `json:"extensions"`
Headers http.Header `json:"headers"`
ReadTime TraceTiming `json:"-"`
}
GraphExecutor interface {
CreateOperationContext(ctx context.Context, params *RawParams) (*OperationContext, gqlerror.List)
DispatchOperation(ctx context.Context, rc *OperationContext) (ResponseHandler, context.Context)
DispatchError(ctx context.Context, list gqlerror.List) *Response
}
// HandlerExtension adds functionality to the http handler. See the list of possible hook points below
// Its important to understand the lifecycle of a graphql request and the terminology we use in gqlgen
// before working with these
//
// +--- REQUEST POST /graphql --------------------------------------------+
// | +- OPERATION query OpName { viewer { name } } -----------------------+ |
// | | RESPONSE { "data": { "viewer": { "name": "bob" } } } | |
// | +- OPERATION subscription OpName2 { chat { message } } --------------+ |
// | | RESPONSE { "data": { "chat": { "message": "hello" } } } | |
// | | RESPONSE { "data": { "chat": { "message": "byee" } } } | |
// | +--------------------------------------------------------------------+ |
// +------------------------------------------------------------------------+
HandlerExtension interface {
// ExtensionName should be a CamelCase string version of the extension which may be shown in stats and logging.
ExtensionName() string
// Validate is called when adding an extension to the server, it allows validation against the servers schema.
Validate(schema ExecutableSchema) error
}
// OperationParameterMutator is called before creating a request context. allows manipulating the raw query
// on the way in.
OperationParameterMutator interface {
MutateOperationParameters(ctx context.Context, request *RawParams) *gqlerror.Error
}
// OperationContextMutator is called after creating the request context, but before executing the root resolver.
OperationContextMutator interface {
MutateOperationContext(ctx context.Context, rc *OperationContext) *gqlerror.Error
}
// OperationInterceptor is called for each incoming query, for basic requests the writer will be invoked once,
// for subscriptions it will be invoked multiple times.
OperationInterceptor interface {
InterceptOperation(ctx context.Context, next OperationHandler) ResponseHandler
}
// ResponseInterceptor is called around each graphql operation response. This can be called many times for a single
// operation the case of subscriptions.
ResponseInterceptor interface {
InterceptResponse(ctx context.Context, next ResponseHandler) *Response
}
RootFieldInterceptor interface {
InterceptRootField(ctx context.Context, next RootResolver) Marshaler
}
// FieldInterceptor called around each field
FieldInterceptor interface {
InterceptField(ctx context.Context, next Resolver) (res interface{}, err error)
}
// Transport provides support for different wire level encodings of graphql requests, eg Form, Get, Post, Websocket
Transport interface {
Supports(r *http.Request) bool
Do(w http.ResponseWriter, r *http.Request, exec GraphExecutor)
}
)
type Status int
func (p *RawParams) AddUpload(upload Upload, key, path string) *gqlerror.Error {
if !strings.HasPrefix(path, "variables.") {
return gqlerror.Errorf("invalid operations paths for key %s", key)
}
var ptr interface{} = p.Variables
parts := strings.Split(path, ".")
// skip the first part (variables) because we started there
for i, p := range parts[1:] {
last := i == len(parts)-2
if ptr == nil {
return gqlerror.Errorf("path is missing \"variables.\" prefix, key: %s, path: %s", key, path)
}
if index, parseNbrErr := strconv.Atoi(p); parseNbrErr == nil {
if last {
ptr.([]interface{})[index] = upload
} else {
ptr = ptr.([]interface{})[index]
}
} else {
if last {
ptr.(map[string]interface{})[p] = upload
} else {
ptr = ptr.(map[string]interface{})[p]
}
}
}
return nil
}

View File

@@ -0,0 +1,114 @@
package extension
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
"github.com/mitchellh/mapstructure"
)
const (
errPersistedQueryNotFound = "PersistedQueryNotFound"
errPersistedQueryNotFoundCode = "PERSISTED_QUERY_NOT_FOUND"
)
// AutomaticPersistedQuery saves client upload by optimistically sending only the hashes of queries, if the server
// does not yet know what the query is for the hash it will respond telling the client to send the query along with the
// hash in the next request.
// see https://github.com/apollographql/apollo-link-persisted-queries
type AutomaticPersistedQuery struct {
Cache graphql.Cache
}
type ApqStats struct {
// The hash of the incoming query
Hash string
// SentQuery is true if the incoming request sent the full query
SentQuery bool
}
const apqExtension = "APQ"
var _ interface {
graphql.OperationParameterMutator
graphql.HandlerExtension
} = AutomaticPersistedQuery{}
func (a AutomaticPersistedQuery) ExtensionName() string {
return "AutomaticPersistedQuery"
}
func (a AutomaticPersistedQuery) Validate(schema graphql.ExecutableSchema) error {
if a.Cache == nil {
return fmt.Errorf("AutomaticPersistedQuery.Cache can not be nil")
}
return nil
}
func (a AutomaticPersistedQuery) MutateOperationParameters(ctx context.Context, rawParams *graphql.RawParams) *gqlerror.Error {
if rawParams.Extensions["persistedQuery"] == nil {
return nil
}
var extension struct {
Sha256 string `mapstructure:"sha256Hash"`
Version int64 `mapstructure:"version"`
}
if err := mapstructure.Decode(rawParams.Extensions["persistedQuery"], &extension); err != nil {
return gqlerror.Errorf("invalid APQ extension data")
}
if extension.Version != 1 {
return gqlerror.Errorf("unsupported APQ version")
}
fullQuery := false
if rawParams.Query == "" {
// client sent optimistic query hash without query string, get it from the cache
query, ok := a.Cache.Get(ctx, extension.Sha256)
if !ok {
err := gqlerror.Errorf(errPersistedQueryNotFound)
errcode.Set(err, errPersistedQueryNotFoundCode)
return err
}
rawParams.Query = query.(string)
} else {
// client sent optimistic query hash with query string, verify and store it
if computeQueryHash(rawParams.Query) != extension.Sha256 {
return gqlerror.Errorf("provided APQ hash does not match query")
}
a.Cache.Add(ctx, extension.Sha256, rawParams.Query)
fullQuery = true
}
graphql.GetOperationContext(ctx).Stats.SetExtension(apqExtension, &ApqStats{
Hash: extension.Sha256,
SentQuery: fullQuery,
})
return nil
}
func GetApqStats(ctx context.Context) *ApqStats {
rc := graphql.GetOperationContext(ctx)
if rc == nil {
return nil
}
s, _ := rc.Stats.GetExtension(apqExtension).(*ApqStats)
return s
}
func computeQueryHash(query string) string {
b := sha256.Sum256([]byte(query))
return hex.EncodeToString(b[:])
}

View File

@@ -0,0 +1,88 @@
package extension
import (
"context"
"fmt"
"github.com/99designs/gqlgen/complexity"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/vektah/gqlparser/v2/gqlerror"
)
const errComplexityLimit = "COMPLEXITY_LIMIT_EXCEEDED"
// ComplexityLimit allows you to define a limit on query complexity
//
// If a query is submitted that exceeds the limit, a 422 status code will be returned.
type ComplexityLimit struct {
Func func(ctx context.Context, rc *graphql.OperationContext) int
es graphql.ExecutableSchema
}
var _ interface {
graphql.OperationContextMutator
graphql.HandlerExtension
} = &ComplexityLimit{}
const complexityExtension = "ComplexityLimit"
type ComplexityStats struct {
// The calculated complexity for this request
Complexity int
// The complexity limit for this request returned by the extension func
ComplexityLimit int
}
// FixedComplexityLimit sets a complexity limit that does not change
func FixedComplexityLimit(limit int) *ComplexityLimit {
return &ComplexityLimit{
Func: func(ctx context.Context, rc *graphql.OperationContext) int {
return limit
},
}
}
func (c ComplexityLimit) ExtensionName() string {
return complexityExtension
}
func (c *ComplexityLimit) Validate(schema graphql.ExecutableSchema) error {
if c.Func == nil {
return fmt.Errorf("ComplexityLimit func can not be nil")
}
c.es = schema
return nil
}
func (c ComplexityLimit) MutateOperationContext(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error {
op := rc.Doc.Operations.ForName(rc.OperationName)
complexityCalcs := complexity.Calculate(c.es, op, rc.Variables)
limit := c.Func(ctx, rc)
rc.Stats.SetExtension(complexityExtension, &ComplexityStats{
Complexity: complexityCalcs,
ComplexityLimit: limit,
})
if complexityCalcs > limit {
err := gqlerror.Errorf("operation has complexity %d, which exceeds the limit of %d", complexityCalcs, limit)
errcode.Set(err, errComplexityLimit)
return err
}
return nil
}
func GetComplexityStats(ctx context.Context) *ComplexityStats {
rc := graphql.GetOperationContext(ctx)
if rc == nil {
return nil
}
s, _ := rc.Stats.GetExtension(complexityExtension).(*ComplexityStats)
return s
}

View File

@@ -0,0 +1,29 @@
package extension
import (
"context"
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// EnableIntrospection enables clients to reflect all of the types available on the graph.
type Introspection struct{}
var _ interface {
graphql.OperationContextMutator
graphql.HandlerExtension
} = Introspection{}
func (c Introspection) ExtensionName() string {
return "Introspection"
}
func (c Introspection) Validate(schema graphql.ExecutableSchema) error {
return nil
}
func (c Introspection) MutateOperationContext(ctx context.Context, rc *graphql.OperationContext) *gqlerror.Error {
rc.DisableIntrospection = false
return nil
}

View File

@@ -0,0 +1,32 @@
package lru
import (
"context"
"github.com/99designs/gqlgen/graphql"
lru "github.com/hashicorp/golang-lru/v2"
)
type LRU struct {
lru *lru.Cache[string, any]
}
var _ graphql.Cache = &LRU{}
func New(size int) *LRU {
cache, err := lru.New[string, any](size)
if err != nil {
// An error is only returned for non-positive cache size
// and we already checked for that.
panic("unexpected error creating cache: " + err.Error())
}
return &LRU{cache}
}
func (l LRU) Get(ctx context.Context, key string) (value interface{}, ok bool) {
return l.lru.Get(key)
}
func (l LRU) Add(ctx context.Context, key string, value interface{}) {
l.lru.Add(key, value)
}

View File

@@ -0,0 +1,186 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/executor"
"github.com/99designs/gqlgen/graphql/handler/extension"
"github.com/99designs/gqlgen/graphql/handler/lru"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type (
Server struct {
transports []graphql.Transport
exec *executor.Executor
}
)
func New(es graphql.ExecutableSchema) *Server {
return &Server{
exec: executor.New(es),
}
}
func NewDefaultServer(es graphql.ExecutableSchema) *Server {
srv := New(es)
srv.AddTransport(transport.Websocket{
KeepAlivePingInterval: 10 * time.Second,
})
srv.AddTransport(transport.Options{})
srv.AddTransport(transport.GET{})
srv.AddTransport(transport.POST{})
srv.AddTransport(transport.MultipartForm{})
srv.SetQueryCache(lru.New(1000))
srv.Use(extension.Introspection{})
srv.Use(extension.AutomaticPersistedQuery{
Cache: lru.New(100),
})
return srv
}
func (s *Server) AddTransport(transport graphql.Transport) {
s.transports = append(s.transports, transport)
}
func (s *Server) SetErrorPresenter(f graphql.ErrorPresenterFunc) {
s.exec.SetErrorPresenter(f)
}
func (s *Server) SetRecoverFunc(f graphql.RecoverFunc) {
s.exec.SetRecoverFunc(f)
}
func (s *Server) SetQueryCache(cache graphql.Cache) {
s.exec.SetQueryCache(cache)
}
func (s *Server) Use(extension graphql.HandlerExtension) {
s.exec.Use(extension)
}
// AroundFields is a convenience method for creating an extension that only implements field middleware
func (s *Server) AroundFields(f graphql.FieldMiddleware) {
s.exec.AroundFields(f)
}
// AroundRootFields is a convenience method for creating an extension that only implements field middleware
func (s *Server) AroundRootFields(f graphql.RootFieldMiddleware) {
s.exec.AroundRootFields(f)
}
// AroundOperations is a convenience method for creating an extension that only implements operation middleware
func (s *Server) AroundOperations(f graphql.OperationMiddleware) {
s.exec.AroundOperations(f)
}
// AroundResponses is a convenience method for creating an extension that only implements response middleware
func (s *Server) AroundResponses(f graphql.ResponseMiddleware) {
s.exec.AroundResponses(f)
}
func (s *Server) getTransport(r *http.Request) graphql.Transport {
for _, t := range s.transports {
if t.Supports(r) {
return t
}
}
return nil
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
err := s.exec.PresentRecoveredError(r.Context(), err)
gqlErr, _ := err.(*gqlerror.Error)
resp := &graphql.Response{Errors: []*gqlerror.Error{gqlErr}}
b, _ := json.Marshal(resp)
w.WriteHeader(http.StatusUnprocessableEntity)
w.Write(b)
}
}()
r = r.WithContext(graphql.StartOperationTrace(r.Context()))
transport := s.getTransport(r)
if transport == nil {
sendErrorf(w, http.StatusBadRequest, "transport not supported")
return
}
transport.Do(w, r, s.exec)
}
func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
w.WriteHeader(code)
b, err := json.Marshal(&graphql.Response{Errors: errors})
if err != nil {
panic(err)
}
w.Write(b)
}
func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
}
type OperationFunc func(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler
func (r OperationFunc) ExtensionName() string {
return "InlineOperationFunc"
}
func (r OperationFunc) Validate(schema graphql.ExecutableSchema) error {
if r == nil {
return fmt.Errorf("OperationFunc can not be nil")
}
return nil
}
func (r OperationFunc) InterceptOperation(ctx context.Context, next graphql.OperationHandler) graphql.ResponseHandler {
return r(ctx, next)
}
type ResponseFunc func(ctx context.Context, next graphql.ResponseHandler) *graphql.Response
func (r ResponseFunc) ExtensionName() string {
return "InlineResponseFunc"
}
func (r ResponseFunc) Validate(schema graphql.ExecutableSchema) error {
if r == nil {
return fmt.Errorf("ResponseFunc can not be nil")
}
return nil
}
func (r ResponseFunc) InterceptResponse(ctx context.Context, next graphql.ResponseHandler) *graphql.Response {
return r(ctx, next)
}
type FieldFunc func(ctx context.Context, next graphql.Resolver) (res interface{}, err error)
func (f FieldFunc) ExtensionName() string {
return "InlineFieldFunc"
}
func (f FieldFunc) Validate(schema graphql.ExecutableSchema) error {
if f == nil {
return fmt.Errorf("FieldFunc can not be nil")
}
return nil
}
func (f FieldFunc) InterceptField(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
return f(ctx, next)
}

View File

@@ -0,0 +1,26 @@
package transport
import (
"encoding/json"
"fmt"
"net/http"
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// SendError sends a best effort error to a raw response writer. It assumes the client can understand the standard
// json error response
func SendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) {
w.WriteHeader(code)
b, err := json.Marshal(&graphql.Response{Errors: errors})
if err != nil {
panic(err)
}
w.Write(b)
}
// SendErrorf wraps SendError to add formatted messages
func SendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
SendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)})
}

View File

@@ -0,0 +1,17 @@
package transport
import "net/http"
func writeHeaders(w http.ResponseWriter, headers map[string][]string) {
if len(headers) == 0 {
headers = map[string][]string{
"Content-Type": {"application/json"},
}
}
for key, values := range headers {
for _, value := range values {
w.Header().Add(key, value)
}
}
}

View File

@@ -0,0 +1,222 @@
package transport
import (
"encoding/json"
"io"
"mime"
"net/http"
"os"
"github.com/99designs/gqlgen/graphql"
)
// MultipartForm the Multipart request spec https://github.com/jaydenseric/graphql-multipart-request-spec
type MultipartForm struct {
// MaxUploadSize sets the maximum number of bytes used to parse a request body
// as multipart/form-data.
MaxUploadSize int64
// MaxMemory defines the maximum number of bytes used to parse a request body
// as multipart/form-data in memory, with the remainder stored on disk in
// temporary files.
MaxMemory int64
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = MultipartForm{}
func (f MultipartForm) Supports(r *http.Request) bool {
if r.Header.Get("Upgrade") != "" {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == "POST" && mediaType == "multipart/form-data"
}
func (f MultipartForm) maxUploadSize() int64 {
if f.MaxUploadSize == 0 {
return 32 << 20
}
return f.MaxUploadSize
}
func (f MultipartForm) maxMemory() int64 {
if f.MaxMemory == 0 {
return 32 << 20
}
return f.MaxMemory
}
func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
writeHeaders(w, f.ResponseHeaders)
start := graphql.Now()
var err error
if r.ContentLength > f.maxUploadSize() {
writeJsonError(w, "failed to parse multipart form, request body too large")
return
}
r.Body = http.MaxBytesReader(w, r.Body, f.maxUploadSize())
defer r.Body.Close()
mr, err := r.MultipartReader()
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "failed to parse multipart form")
return
}
part, err := mr.NextPart()
if err != nil || part.FormName() != "operations" {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "first part must be operations")
return
}
var params graphql.RawParams
if err = jsonDecode(part, &params); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "operations form field could not be decoded")
return
}
part, err = mr.NextPart()
if err != nil || part.FormName() != "map" {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "second part must be map")
return
}
uploadsMap := map[string][]string{}
if err = json.NewDecoder(part).Decode(&uploadsMap); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "map form field could not be decoded")
return
}
for {
part, err = mr.NextPart()
if err == io.EOF {
break
} else if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to parse part")
return
}
key := part.FormName()
filename := part.FileName()
contentType := part.Header.Get("Content-Type")
paths := uploadsMap[key]
if len(paths) == 0 {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "invalid empty operations paths list for key %s", key)
return
}
delete(uploadsMap, key)
var upload graphql.Upload
if r.ContentLength < f.maxMemory() {
fileBytes, err := io.ReadAll(part)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to read file for key %s", key)
return
}
for _, path := range paths {
upload = graphql.Upload{
File: &bytesReader{s: &fileBytes, i: 0},
Size: int64(len(fileBytes)),
Filename: filename,
ContentType: contentType,
}
if err := params.AddUpload(upload, key, path); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonGraphqlError(w, err)
return
}
}
} else {
tmpFile, err := os.CreateTemp(os.TempDir(), "gqlgen-")
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to create temp file for key %s", key)
return
}
tmpName := tmpFile.Name()
defer func() {
_ = os.Remove(tmpName)
}()
fileSize, err := io.Copy(tmpFile, part)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
if err := tmpFile.Close(); err != nil {
writeJsonErrorf(w, "failed to copy to temp file and close temp file for key %s", key)
return
}
writeJsonErrorf(w, "failed to copy to temp file for key %s", key)
return
}
if err := tmpFile.Close(); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to close temp file for key %s", key)
return
}
for _, path := range paths {
pathTmpFile, err := os.Open(tmpName)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to open temp file for key %s", key)
return
}
defer pathTmpFile.Close()
upload = graphql.Upload{
File: pathTmpFile,
Size: fileSize,
Filename: filename,
ContentType: contentType,
}
if err := params.AddUpload(upload, key, path); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonGraphqlError(w, err)
return
}
}
}
}
for key := range uploadsMap {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to get key %s from form", key)
return
}
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
rc, gerr := exec.CreateOperationContext(r.Context(), &params)
if gerr != nil {
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), gerr)
w.WriteHeader(statusFor(gerr))
writeJson(w, resp)
return
}
responses, ctx := exec.DispatchOperation(r.Context(), rc)
writeJson(w, responses(ctx))
}

View File

@@ -0,0 +1,119 @@
package transport
import (
"io"
"mime"
"net/http"
"net/url"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
// FORM implements the application/x-www-form-urlencoded side of the default HTTP transport
type UrlEncodedForm struct {
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = UrlEncodedForm{}
func (h UrlEncodedForm) Supports(r *http.Request) bool {
if r.Header.Get("Upgrade") != "" {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == "POST" && mediaType == "application/x-www-form-urlencoded"
}
func (h UrlEncodedForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
writeHeaders(w, h.ResponseHeaders)
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
bodyString, err := getRequestBody(r)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf("could not get form body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
params, err = h.parseBody(bodyString)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
gqlErr := gqlerror.Errorf("could not cleanup body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
rc, OpErr := exec.CreateOperationContext(ctx, params)
if OpErr != nil {
w.WriteHeader(statusFor(OpErr))
resp := exec.DispatchError(graphql.WithOperationContext(ctx, rc), OpErr)
writeJson(w, resp)
return
}
var responses graphql.ResponseHandler
responses, ctx = exec.DispatchOperation(ctx, rc)
writeJson(w, responses(ctx))
}
func (h UrlEncodedForm) parseBody(bodyString string) (*graphql.RawParams, error) {
switch {
case strings.Contains(bodyString, "\"query\":"):
// body is json
return h.parseJson(bodyString)
case strings.HasPrefix(bodyString, "query=%7B"):
// body is urlencoded
return h.parseEncoded(bodyString)
default:
// body is plain text
params := &graphql.RawParams{}
params.Query = strings.TrimPrefix(bodyString, "query=")
return params, nil
}
}
func (h UrlEncodedForm) parseEncoded(bodyString string) (*graphql.RawParams, error) {
params := &graphql.RawParams{}
query, err := url.QueryUnescape(bodyString)
if err != nil {
return nil, err
}
params.Query = strings.TrimPrefix(query, "query=")
return params, nil
}
func (h UrlEncodedForm) parseJson(bodyString string) (*graphql.RawParams, error) {
params := &graphql.RawParams{}
bodyReader := io.NopCloser(strings.NewReader(bodyString))
err := jsonDecode(bodyReader, &params)
if err != nil {
return nil, err
}
return params, nil
}

View File

@@ -0,0 +1,99 @@
package transport
import (
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// GET implements the GET side of the default HTTP transport
// defined in https://github.com/APIs-guru/graphql-over-http#get
type GET struct {
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = GET{}
func (h GET) Supports(r *http.Request) bool {
if r.Header.Get("Upgrade") != "" {
return false
}
return r.Method == "GET"
}
func (h GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
query, err := url.ParseQuery(r.URL.RawQuery)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonError(w, err.Error())
return
}
writeHeaders(w, h.ResponseHeaders)
raw := &graphql.RawParams{
Query: query.Get("query"),
OperationName: query.Get("operationName"),
Headers: r.Header,
}
raw.ReadTime.Start = graphql.Now()
if variables := query.Get("variables"); variables != "" {
if err := jsonDecode(strings.NewReader(variables), &raw.Variables); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonError(w, "variables could not be decoded")
return
}
}
if extensions := query.Get("extensions"); extensions != "" {
if err := jsonDecode(strings.NewReader(extensions), &raw.Extensions); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonError(w, "extensions could not be decoded")
return
}
}
raw.ReadTime.End = graphql.Now()
rc, gqlError := exec.CreateOperationContext(r.Context(), raw)
if gqlError != nil {
w.WriteHeader(statusFor(gqlError))
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), gqlError)
writeJson(w, resp)
return
}
op := rc.Doc.Operations.ForName(rc.OperationName)
if op.Operation != ast.Query {
w.WriteHeader(http.StatusNotAcceptable)
writeJsonError(w, "GET requests only allow query operations")
return
}
responses, ctx := exec.DispatchOperation(r.Context(), rc)
writeJson(w, responses(ctx))
}
func jsonDecode(r io.Reader, val interface{}) error {
dec := json.NewDecoder(r)
dec.UseNumber()
return dec.Decode(val)
}
func statusFor(errs gqlerror.List) int {
switch errcode.GetErrorKind(errs) {
case errcode.KindProtocol:
return http.StatusUnprocessableEntity
default:
return http.StatusOK
}
}

View File

@@ -0,0 +1,98 @@
package transport
import (
"mime"
"net/http"
"net/url"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
// GRAPHQL implements the application/graphql side of the HTTP transport
// see: https://graphql.org/learn/serving-over-http/#post-request
// If the "application/graphql" Content-Type header is present, treat
// the HTTP POST body contents as the GraphQL query string.
type GRAPHQL struct {
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = GRAPHQL{}
func (h GRAPHQL) Supports(r *http.Request) bool {
if r.Header.Get("Upgrade") != "" {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == "POST" && mediaType == "application/graphql"
}
func (h GRAPHQL) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
writeHeaders(w, h.ResponseHeaders)
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
bodyString, err := getRequestBody(r)
if err != nil {
gqlErr := gqlerror.Errorf("could not get request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
params.Query, err = cleanupBody(bodyString)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
gqlErr := gqlerror.Errorf("could not cleanup body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
rc, OpErr := exec.CreateOperationContext(ctx, params)
if OpErr != nil {
w.WriteHeader(statusFor(OpErr))
resp := exec.DispatchError(graphql.WithOperationContext(ctx, rc), OpErr)
writeJson(w, resp)
return
}
var responses graphql.ResponseHandler
responses, ctx = exec.DispatchOperation(ctx, rc)
writeJson(w, responses(ctx))
}
// Makes sure we strip "query=" keyword from body and
// that body is not url escaped
func cleanupBody(body string) (out string, err error) {
// Some clients send 'query=' at the start of body payload. Let's remove
// it to get GQL query only.
body = strings.TrimPrefix(body, "query=")
// Body payload can be url encoded or not. We check if %7B - "{" character
// is where query starts. If it is, query is url encoded.
if strings.HasPrefix(body, "%7B") {
body, err = url.QueryUnescape(body)
if err != nil {
return body, err
}
}
return body, err
}

View File

@@ -0,0 +1,92 @@
package transport
import (
"fmt"
"io"
"mime"
"net/http"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
// POST implements the POST side of the default HTTP transport
// defined in https://github.com/APIs-guru/graphql-over-http#post
type POST struct {
// Map of all headers that are added to graphql response. If not
// set, only one header: Content-Type: application/json will be set.
ResponseHeaders map[string][]string
}
var _ graphql.Transport = POST{}
func (h POST) Supports(r *http.Request) bool {
if r.Header.Get("Upgrade") != "" {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == "POST" && mediaType == "application/json"
}
func getRequestBody(r *http.Request) (string, error) {
if r == nil || r.Body == nil {
return "", nil
}
body, err := io.ReadAll(r.Body)
if err != nil {
return "", fmt.Errorf("unable to get Request Body %w", err)
}
return string(body), nil
}
func (h POST) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
writeHeaders(w, h.ResponseHeaders)
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
bodyString, err := getRequestBody(r)
if err != nil {
gqlErr := gqlerror.Errorf("could not get json request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
bodyReader := io.NopCloser(strings.NewReader(bodyString))
if err = jsonDecode(bodyReader, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf(
"json request body could not be decoded: %+v body:%s",
err,
bodyString,
)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
writeJson(w, resp)
return
}
rc, OpErr := exec.CreateOperationContext(ctx, params)
if OpErr != nil {
w.WriteHeader(statusFor(OpErr))
resp := exec.DispatchError(graphql.WithOperationContext(ctx, rc), OpErr)
writeJson(w, resp)
return
}
var responses graphql.ResponseHandler
responses, ctx = exec.DispatchOperation(ctx, rc)
writeJson(w, responses(ctx))
}

View File

@@ -0,0 +1,37 @@
package transport
import (
"net/http"
"strings"
"github.com/99designs/gqlgen/graphql"
)
// Options responds to http OPTIONS and HEAD requests
type Options struct {
// AllowedMethods is a list of allowed HTTP methods.
AllowedMethods []string
}
var _ graphql.Transport = Options{}
func (o Options) Supports(r *http.Request) bool {
return r.Method == "HEAD" || r.Method == "OPTIONS"
}
func (o Options) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
switch r.Method {
case http.MethodOptions:
w.Header().Set("Allow", o.allowedMethods())
w.WriteHeader(http.StatusOK)
case http.MethodHead:
w.WriteHeader(http.StatusMethodNotAllowed)
}
}
func (o Options) allowedMethods() string {
if len(o.AllowedMethods) == 0 {
return "OPTIONS, GET, POST"
}
return strings.Join(o.AllowedMethods, ", ")
}

View File

@@ -0,0 +1,45 @@
package transport
import (
"errors"
"io"
)
type bytesReader struct {
s *[]byte
i int64 // current reading index
}
func (r *bytesReader) Read(b []byte) (n int, err error) {
if r.s == nil {
return 0, errors.New("byte slice pointer is nil")
}
if r.i >= int64(len(*r.s)) {
return 0, io.EOF
}
n = copy(b, (*r.s)[r.i:])
r.i += int64(n)
return
}
func (r *bytesReader) Seek(offset int64, whence int) (int64, error) {
if r.s == nil {
return 0, errors.New("byte slice pointer is nil")
}
var abs int64
switch whence {
case io.SeekStart:
abs = offset
case io.SeekCurrent:
abs = r.i + offset
case io.SeekEnd:
abs = int64(len(*r.s)) + offset
default:
return 0, errors.New("invalid whence")
}
if abs < 0 {
return 0, errors.New("negative position")
}
r.i = abs
return abs, nil
}

View File

@@ -0,0 +1,107 @@
package transport
import (
"encoding/json"
"fmt"
"io"
"log"
"mime"
"net/http"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
type SSE struct{}
var _ graphql.Transport = SSE{}
func (t SSE) Supports(r *http.Request) bool {
if !strings.Contains(r.Header.Get("Accept"), "text/event-stream") {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == http.MethodPost && mediaType == "application/json"
}
func (t SSE) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
flusher, ok := w.(http.Flusher)
if !ok {
SendErrorf(w, http.StatusInternalServerError, "streaming unsupported")
return
}
defer flusher.Flush()
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Content-Type", "application/json")
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
bodyString, err := getRequestBody(r)
if err != nil {
gqlErr := gqlerror.Errorf("could not get json request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("could not get json request body: %+v", err.Error())
writeJson(w, resp)
return
}
bodyReader := io.NopCloser(strings.NewReader(bodyString))
if err = jsonDecode(bodyReader, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf(
"json request body could not be decoded: %+v body:%s",
err,
bodyString,
)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("decoding error: %+v body:%s", err.Error(), bodyString)
writeJson(w, resp)
return
}
rc, opErr := exec.CreateOperationContext(ctx, params)
ctx = graphql.WithOperationContext(ctx, rc)
w.Header().Set("Content-Type", "text/event-stream")
fmt.Fprint(w, ":\n\n")
flusher.Flush()
if opErr != nil {
resp := exec.DispatchError(ctx, opErr)
writeJsonWithSSE(w, resp)
} else {
responses, ctx := exec.DispatchOperation(ctx, rc)
for {
response := responses(ctx)
if response == nil {
break
}
writeJsonWithSSE(w, response)
flusher.Flush()
}
}
fmt.Fprint(w, "event: complete\n\n")
}
func writeJsonWithSSE(w io.Writer, response *graphql.Response) {
b, err := json.Marshal(response)
if err != nil {
panic(err)
}
fmt.Fprintf(w, "event: next\ndata: %s\n\n", b)
}

View File

@@ -0,0 +1,30 @@
package transport
import (
"encoding/json"
"fmt"
"io"
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
)
func writeJson(w io.Writer, response *graphql.Response) {
b, err := json.Marshal(response)
if err != nil {
panic(err)
}
w.Write(b)
}
func writeJsonError(w io.Writer, msg string) {
writeJson(w, &graphql.Response{Errors: gqlerror.List{{Message: msg}}})
}
func writeJsonErrorf(w io.Writer, format string, args ...interface{}) {
writeJson(w, &graphql.Response{Errors: gqlerror.List{{Message: fmt.Sprintf(format, args...)}}})
}
func writeJsonGraphqlError(w io.Writer, err ...*gqlerror.Error) {
writeJson(w, &graphql.Response{Errors: err})
}

View File

@@ -0,0 +1,444 @@
package transport
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"sync"
"time"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/errcode"
"github.com/gorilla/websocket"
"github.com/vektah/gqlparser/v2/gqlerror"
)
type (
Websocket struct {
Upgrader websocket.Upgrader
InitFunc WebsocketInitFunc
InitTimeout time.Duration
ErrorFunc WebsocketErrorFunc
CloseFunc WebsocketCloseFunc
KeepAlivePingInterval time.Duration
PingPongInterval time.Duration
didInjectSubprotocols bool
}
wsConnection struct {
Websocket
ctx context.Context
conn *websocket.Conn
me messageExchanger
active map[string]context.CancelFunc
mu sync.Mutex
keepAliveTicker *time.Ticker
pingPongTicker *time.Ticker
exec graphql.GraphExecutor
initPayload InitPayload
}
WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
WebsocketErrorFunc func(ctx context.Context, err error)
// Callback called when websocket is closed.
WebsocketCloseFunc func(ctx context.Context, closeCode int)
)
var errReadTimeout = errors.New("read timeout")
type WebsocketError struct {
Err error
// IsReadError flags whether the error occurred on read or write to the websocket
IsReadError bool
}
func (e WebsocketError) Error() string {
if e.IsReadError {
return fmt.Sprintf("websocket read: %v", e.Err)
}
return fmt.Sprintf("websocket write: %v", e.Err)
}
var (
_ graphql.Transport = Websocket{}
_ error = WebsocketError{}
)
func (t Websocket) Supports(r *http.Request) bool {
return r.Header.Get("Upgrade") != ""
}
func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
t.injectGraphQLWSSubprotocols()
ws, err := t.Upgrader.Upgrade(w, r, http.Header{})
if err != nil {
log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
SendErrorf(w, http.StatusBadRequest, "unable to upgrade")
return
}
var me messageExchanger
switch ws.Subprotocol() {
default:
msg := websocket.FormatCloseMessage(websocket.CloseProtocolError, fmt.Sprintf("unsupported negotiated subprotocol %s", ws.Subprotocol()))
ws.WriteMessage(websocket.CloseMessage, msg)
return
case graphqlwsSubprotocol, "":
// clients are required to send a subprotocol, to be backward compatible with the previous implementation we select
// "graphql-ws" by default
me = graphqlwsMessageExchanger{c: ws}
case graphqltransportwsSubprotocol:
me = graphqltransportwsMessageExchanger{c: ws}
}
conn := wsConnection{
active: map[string]context.CancelFunc{},
conn: ws,
ctx: r.Context(),
exec: exec,
me: me,
Websocket: t,
}
if !conn.init() {
return
}
conn.run()
}
func (c *wsConnection) handlePossibleError(err error, isReadError bool) {
if c.ErrorFunc != nil && err != nil {
c.ErrorFunc(c.ctx, WebsocketError{
Err: err,
IsReadError: isReadError,
})
}
}
func (c *wsConnection) nextMessageWithTimeout(timeout time.Duration) (message, error) {
messages, errs := make(chan message, 1), make(chan error, 1)
go func() {
if m, err := c.me.NextMessage(); err != nil {
errs <- err
} else {
messages <- m
}
}()
select {
case m := <-messages:
return m, nil
case err := <-errs:
return message{}, err
case <-time.After(timeout):
return message{}, errReadTimeout
}
}
func (c *wsConnection) init() bool {
var m message
var err error
if c.InitTimeout != 0 {
m, err = c.nextMessageWithTimeout(c.InitTimeout)
} else {
m, err = c.me.NextMessage()
}
if err != nil {
if err == errReadTimeout {
c.close(websocket.CloseProtocolError, "connection initialisation timeout")
return false
}
if err == errInvalidMsg {
c.sendConnectionError("invalid json")
}
c.close(websocket.CloseProtocolError, "decoding error")
return false
}
switch m.t {
case initMessageType:
if len(m.payload) > 0 {
c.initPayload = make(InitPayload)
err := json.Unmarshal(m.payload, &c.initPayload)
if err != nil {
return false
}
}
if c.InitFunc != nil {
ctx, err := c.InitFunc(c.ctx, c.initPayload)
if err != nil {
c.sendConnectionError(err.Error())
c.close(websocket.CloseNormalClosure, "terminated")
return false
}
c.ctx = ctx
}
c.write(&message{t: connectionAckMessageType})
c.write(&message{t: keepAliveMessageType})
case connectionCloseMessageType:
c.close(websocket.CloseNormalClosure, "terminated")
return false
default:
c.sendConnectionError("unexpected message %s", m.t)
c.close(websocket.CloseProtocolError, "unexpected message")
return false
}
return true
}
func (c *wsConnection) write(msg *message) {
c.mu.Lock()
c.handlePossibleError(c.me.Send(msg), false)
c.mu.Unlock()
}
func (c *wsConnection) run() {
// We create a cancellation that will shutdown the keep-alive when we leave
// this function.
ctx, cancel := context.WithCancel(c.ctx)
defer func() {
cancel()
c.close(websocket.CloseAbnormalClosure, "unexpected closure")
}()
// If we're running in graphql-ws mode, create a timer that will trigger a
// keep alive message every interval
if (c.conn.Subprotocol() == "" || c.conn.Subprotocol() == graphqlwsSubprotocol) && c.KeepAlivePingInterval != 0 {
c.mu.Lock()
c.keepAliveTicker = time.NewTicker(c.KeepAlivePingInterval)
c.mu.Unlock()
go c.keepAlive(ctx)
}
// If we're running in graphql-transport-ws mode, create a timer that will
// trigger a ping message every interval
if c.conn.Subprotocol() == graphqltransportwsSubprotocol && c.PingPongInterval != 0 {
c.mu.Lock()
c.pingPongTicker = time.NewTicker(c.PingPongInterval)
c.mu.Unlock()
// Note: when the connection is closed by this deadline, the client
// will receive an "invalid close code"
c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
go c.ping(ctx)
}
// Close the connection when the context is cancelled.
// Will optionally send a "close reason" that is retrieved from the context.
go c.closeOnCancel(ctx)
for {
start := graphql.Now()
m, err := c.me.NextMessage()
if err != nil {
// If the connection got closed by us, don't report the error
if !errors.Is(err, net.ErrClosed) {
c.handlePossibleError(err, true)
}
return
}
switch m.t {
case startMessageType:
c.subscribe(start, &m)
case stopMessageType:
c.mu.Lock()
closer := c.active[m.id]
c.mu.Unlock()
if closer != nil {
closer()
}
case connectionCloseMessageType:
c.close(websocket.CloseNormalClosure, "terminated")
return
case pingMessageType:
c.write(&message{t: pongMessageType, payload: m.payload})
case pongMessageType:
c.conn.SetReadDeadline(time.Now().UTC().Add(2 * c.PingPongInterval))
default:
c.sendConnectionError("unexpected message %s", m.t)
c.close(websocket.CloseProtocolError, "unexpected message")
return
}
}
}
func (c *wsConnection) keepAlive(ctx context.Context) {
for {
select {
case <-ctx.Done():
c.keepAliveTicker.Stop()
return
case <-c.keepAliveTicker.C:
c.write(&message{t: keepAliveMessageType})
}
}
}
func (c *wsConnection) ping(ctx context.Context) {
for {
select {
case <-ctx.Done():
c.pingPongTicker.Stop()
return
case <-c.pingPongTicker.C:
c.write(&message{t: pingMessageType, payload: json.RawMessage{}})
}
}
}
func (c *wsConnection) closeOnCancel(ctx context.Context) {
<-ctx.Done()
if r := closeReasonForContext(ctx); r != "" {
c.sendConnectionError(r)
}
c.close(websocket.CloseNormalClosure, "terminated")
}
func (c *wsConnection) subscribe(start time.Time, msg *message) {
ctx := graphql.StartOperationTrace(c.ctx)
var params *graphql.RawParams
if err := jsonDecode(bytes.NewReader(msg.payload), &params); err != nil {
c.sendError(msg.id, &gqlerror.Error{Message: "invalid json"})
c.complete(msg.id)
return
}
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
rc, err := c.exec.CreateOperationContext(ctx, params)
if err != nil {
resp := c.exec.DispatchError(graphql.WithOperationContext(ctx, rc), err)
switch errcode.GetErrorKind(err) {
case errcode.KindProtocol:
c.sendError(msg.id, resp.Errors...)
default:
c.sendResponse(msg.id, &graphql.Response{Errors: err})
}
c.complete(msg.id)
return
}
ctx = graphql.WithOperationContext(ctx, rc)
if c.initPayload != nil {
ctx = withInitPayload(ctx, c.initPayload)
}
ctx, cancel := context.WithCancel(ctx)
c.mu.Lock()
c.active[msg.id] = cancel
c.mu.Unlock()
go func() {
ctx = withSubscriptionErrorContext(ctx)
defer func() {
if r := recover(); r != nil {
err := rc.Recover(ctx, r)
var gqlerr *gqlerror.Error
if !errors.As(err, &gqlerr) {
gqlerr = &gqlerror.Error{}
if err != nil {
gqlerr.Message = err.Error()
}
}
c.sendError(msg.id, gqlerr)
}
if errs := getSubscriptionError(ctx); len(errs) != 0 {
c.sendError(msg.id, errs...)
} else {
c.complete(msg.id)
}
c.mu.Lock()
delete(c.active, msg.id)
c.mu.Unlock()
cancel()
}()
responses, ctx := c.exec.DispatchOperation(ctx, rc)
for {
response := responses(ctx)
if response == nil {
break
}
c.sendResponse(msg.id, response)
}
// complete and context cancel comes from the defer
}()
}
func (c *wsConnection) sendResponse(id string, response *graphql.Response) {
b, err := json.Marshal(response)
if err != nil {
panic(err)
}
c.write(&message{
payload: b,
id: id,
t: dataMessageType,
})
}
func (c *wsConnection) complete(id string) {
c.write(&message{id: id, t: completeMessageType})
}
func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
errs := make([]error, len(errors))
for i, err := range errors {
errs[i] = err
}
b, err := json.Marshal(errs)
if err != nil {
panic(err)
}
c.write(&message{t: errorMessageType, id: id, payload: b})
}
func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
if err != nil {
panic(err)
}
c.write(&message{t: connectionErrorMessageType, payload: b})
}
func (c *wsConnection) close(closeCode int, message string) {
c.mu.Lock()
_ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
for _, closer := range c.active {
closer()
}
c.mu.Unlock()
_ = c.conn.Close()
if c.CloseFunc != nil {
c.CloseFunc(c.ctx, closeCode)
}
}

View File

@@ -0,0 +1,22 @@
package transport
import (
"context"
)
// A private key for context that only this package can access. This is important
// to prevent collisions between different context uses
var closeReasonCtxKey = &wsCloseReasonContextKey{"close-reason"}
type wsCloseReasonContextKey struct {
name string
}
func AppendCloseReason(ctx context.Context, reason string) context.Context {
return context.WithValue(ctx, closeReasonCtxKey, reason)
}
func closeReasonForContext(ctx context.Context) string {
reason, _ := ctx.Value(closeReasonCtxKey).(string)
return reason
}

View File

@@ -0,0 +1,149 @@
package transport
import (
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
)
// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
const (
graphqltransportwsSubprotocol = "graphql-transport-ws"
graphqltransportwsConnectionInitMsg = graphqltransportwsMessageType("connection_init")
graphqltransportwsConnectionAckMsg = graphqltransportwsMessageType("connection_ack")
graphqltransportwsSubscribeMsg = graphqltransportwsMessageType("subscribe")
graphqltransportwsNextMsg = graphqltransportwsMessageType("next")
graphqltransportwsErrorMsg = graphqltransportwsMessageType("error")
graphqltransportwsCompleteMsg = graphqltransportwsMessageType("complete")
graphqltransportwsPingMsg = graphqltransportwsMessageType("ping")
graphqltransportwsPongMsg = graphqltransportwsMessageType("pong")
)
var allGraphqltransportwsMessageTypes = []graphqltransportwsMessageType{
graphqltransportwsConnectionInitMsg,
graphqltransportwsConnectionAckMsg,
graphqltransportwsSubscribeMsg,
graphqltransportwsNextMsg,
graphqltransportwsErrorMsg,
graphqltransportwsCompleteMsg,
graphqltransportwsPingMsg,
graphqltransportwsPongMsg,
}
type (
graphqltransportwsMessageExchanger struct {
c *websocket.Conn
}
graphqltransportwsMessage struct {
Payload json.RawMessage `json:"payload,omitempty"`
ID string `json:"id,omitempty"`
Type graphqltransportwsMessageType `json:"type"`
noOp bool
}
graphqltransportwsMessageType string
)
func (me graphqltransportwsMessageExchanger) NextMessage() (message, error) {
_, r, err := me.c.NextReader()
if err != nil {
return message{}, handleNextReaderError(err)
}
var graphqltransportwsMessage graphqltransportwsMessage
if err := jsonDecode(r, &graphqltransportwsMessage); err != nil {
return message{}, errInvalidMsg
}
return graphqltransportwsMessage.toMessage()
}
func (me graphqltransportwsMessageExchanger) Send(m *message) error {
msg := &graphqltransportwsMessage{}
if err := msg.fromMessage(m); err != nil {
return err
}
if msg.noOp {
return nil
}
return me.c.WriteJSON(msg)
}
func (t *graphqltransportwsMessageType) UnmarshalText(text []byte) (err error) {
var found bool
for _, candidate := range allGraphqltransportwsMessageTypes {
if string(candidate) == string(text) {
*t = candidate
found = true
break
}
}
if !found {
err = fmt.Errorf("invalid message type %s", string(text))
}
return err
}
func (t graphqltransportwsMessageType) MarshalText() ([]byte, error) {
return []byte(string(t)), nil
}
func (m graphqltransportwsMessage) toMessage() (message, error) {
var t messageType
var err error
switch m.Type {
default:
err = fmt.Errorf("invalid client->server message type %s", m.Type)
case graphqltransportwsConnectionInitMsg:
t = initMessageType
case graphqltransportwsSubscribeMsg:
t = startMessageType
case graphqltransportwsCompleteMsg:
t = stopMessageType
case graphqltransportwsPingMsg:
t = pingMessageType
case graphqltransportwsPongMsg:
t = pongMessageType
}
return message{
payload: m.Payload,
id: m.ID,
t: t,
}, err
}
func (m *graphqltransportwsMessage) fromMessage(msg *message) (err error) {
m.ID = msg.id
m.Payload = msg.payload
switch msg.t {
default:
err = fmt.Errorf("invalid server->client message type %s", msg.t)
case connectionAckMessageType:
m.Type = graphqltransportwsConnectionAckMsg
case keepAliveMessageType:
m.noOp = true
case connectionErrorMessageType:
m.noOp = true
case dataMessageType:
m.Type = graphqltransportwsNextMsg
case completeMessageType:
m.Type = graphqltransportwsCompleteMsg
case errorMessageType:
m.Type = graphqltransportwsErrorMsg
case pingMessageType:
m.Type = graphqltransportwsPingMsg
case pongMessageType:
m.Type = graphqltransportwsPongMsg
}
return err
}

View File

@@ -0,0 +1,171 @@
package transport
import (
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
)
// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md
const (
graphqlwsSubprotocol = "graphql-ws"
graphqlwsConnectionInitMsg = graphqlwsMessageType("connection_init")
graphqlwsConnectionTerminateMsg = graphqlwsMessageType("connection_terminate")
graphqlwsStartMsg = graphqlwsMessageType("start")
graphqlwsStopMsg = graphqlwsMessageType("stop")
graphqlwsConnectionAckMsg = graphqlwsMessageType("connection_ack")
graphqlwsConnectionErrorMsg = graphqlwsMessageType("connection_error")
graphqlwsDataMsg = graphqlwsMessageType("data")
graphqlwsErrorMsg = graphqlwsMessageType("error")
graphqlwsCompleteMsg = graphqlwsMessageType("complete")
graphqlwsConnectionKeepAliveMsg = graphqlwsMessageType("ka")
)
var allGraphqlwsMessageTypes = []graphqlwsMessageType{
graphqlwsConnectionInitMsg,
graphqlwsConnectionTerminateMsg,
graphqlwsStartMsg,
graphqlwsStopMsg,
graphqlwsConnectionAckMsg,
graphqlwsConnectionErrorMsg,
graphqlwsDataMsg,
graphqlwsErrorMsg,
graphqlwsCompleteMsg,
graphqlwsConnectionKeepAliveMsg,
}
type (
graphqlwsMessageExchanger struct {
c *websocket.Conn
}
graphqlwsMessage struct {
Payload json.RawMessage `json:"payload,omitempty"`
ID string `json:"id,omitempty"`
Type graphqlwsMessageType `json:"type"`
noOp bool
}
graphqlwsMessageType string
)
func (me graphqlwsMessageExchanger) NextMessage() (message, error) {
_, r, err := me.c.NextReader()
if err != nil {
return message{}, handleNextReaderError(err)
}
var graphqlwsMessage graphqlwsMessage
if err := jsonDecode(r, &graphqlwsMessage); err != nil {
return message{}, errInvalidMsg
}
return graphqlwsMessage.toMessage()
}
func (me graphqlwsMessageExchanger) Send(m *message) error {
msg := &graphqlwsMessage{}
if err := msg.fromMessage(m); err != nil {
return err
}
if msg.noOp {
return nil
}
return me.c.WriteJSON(msg)
}
func (t *graphqlwsMessageType) UnmarshalText(text []byte) (err error) {
var found bool
for _, candidate := range allGraphqlwsMessageTypes {
if string(candidate) == string(text) {
*t = candidate
found = true
break
}
}
if !found {
err = fmt.Errorf("invalid message type %s", string(text))
}
return err
}
func (t graphqlwsMessageType) MarshalText() ([]byte, error) {
return []byte(string(t)), nil
}
func (m graphqlwsMessage) toMessage() (message, error) {
var t messageType
var err error
switch m.Type {
default:
err = fmt.Errorf("invalid client->server message type %s", m.Type)
case graphqlwsConnectionInitMsg:
t = initMessageType
case graphqlwsConnectionTerminateMsg:
t = connectionCloseMessageType
case graphqlwsStartMsg:
t = startMessageType
case graphqlwsStopMsg:
t = stopMessageType
case graphqlwsConnectionAckMsg:
t = connectionAckMessageType
case graphqlwsConnectionErrorMsg:
t = connectionErrorMessageType
case graphqlwsDataMsg:
t = dataMessageType
case graphqlwsErrorMsg:
t = errorMessageType
case graphqlwsCompleteMsg:
t = completeMessageType
case graphqlwsConnectionKeepAliveMsg:
t = keepAliveMessageType
}
return message{
payload: m.Payload,
id: m.ID,
t: t,
}, err
}
func (m *graphqlwsMessage) fromMessage(msg *message) (err error) {
m.ID = msg.id
m.Payload = msg.payload
switch msg.t {
default:
err = fmt.Errorf("invalid server->client message type %s", msg.t)
case initMessageType:
m.Type = graphqlwsConnectionInitMsg
case connectionAckMessageType:
m.Type = graphqlwsConnectionAckMsg
case keepAliveMessageType:
m.Type = graphqlwsConnectionKeepAliveMsg
case connectionErrorMessageType:
m.Type = graphqlwsConnectionErrorMsg
case connectionCloseMessageType:
m.Type = graphqlwsConnectionTerminateMsg
case startMessageType:
m.Type = graphqlwsStartMsg
case stopMessageType:
m.Type = graphqlwsStopMsg
case dataMessageType:
m.Type = graphqlwsDataMsg
case completeMessageType:
m.Type = graphqlwsCompleteMsg
case errorMessageType:
m.Type = graphqlwsErrorMsg
case pingMessageType:
m.noOp = true
case pongMessageType:
m.noOp = true
}
return err
}

View File

@@ -0,0 +1,57 @@
package transport
import "context"
type key string
const (
initpayload key = "ws_initpayload_context"
)
// InitPayload is a structure that is parsed from the websocket init message payload. TO use
// request headers for non-websocket, instead wrap the graphql handler in a middleware.
type InitPayload map[string]interface{}
// GetString safely gets a string value from the payload. It returns an empty string if the
// payload is nil or the value isn't set.
func (p InitPayload) GetString(key string) string {
if p == nil {
return ""
}
if value, ok := p[key]; ok {
res, _ := value.(string)
return res
}
return ""
}
// Authorization is a short hand for getting the Authorization header from the
// payload.
func (p InitPayload) Authorization() string {
if value := p.GetString("Authorization"); value != "" {
return value
}
if value := p.GetString("authorization"); value != "" {
return value
}
return ""
}
func withInitPayload(ctx context.Context, payload InitPayload) context.Context {
return context.WithValue(ctx, initpayload, payload)
}
// GetInitPayload gets a map of the data sent with the connection_init message, which is used by
// graphql clients as a stand-in for HTTP headers.
func GetInitPayload(ctx context.Context) InitPayload {
payload, ok := ctx.Value(initpayload).(InitPayload)
if !ok {
return nil
}
return payload
}

View File

@@ -0,0 +1,69 @@
package transport
import (
"context"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// A private key for context that only this package can access. This is important
// to prevent collisions between different context uses
var wsSubscriptionErrorCtxKey = &wsSubscriptionErrorContextKey{"subscription-error"}
type wsSubscriptionErrorContextKey struct {
name string
}
type subscriptionError struct {
errs []*gqlerror.Error
}
// AddSubscriptionError is used to let websocket return an error message after subscription resolver returns a channel.
// for example:
//
// func (r *subscriptionResolver) Method(ctx context.Context) (<-chan *model.Message, error) {
// ch := make(chan *model.Message)
// go func() {
// defer func() {
// close(ch)
// }
// // some kind of block processing (e.g.: gRPC client streaming)
// stream, err := gRPCClientStreamRequest(ctx)
// if err != nil {
// transport.AddSubscriptionError(ctx, err)
// return // must return and close channel so websocket can send error back
// }
// for {
// m, err := stream.Recv()
// if err == io.EOF {
// return
// }
// if err != nil {
// transport.AddSubscriptionError(ctx, err)
// return // must return and close channel so websocket can send error back
// }
// ch <- m
// }
// }()
//
// return ch, nil
// }
//
// see https://github.com/99designs/gqlgen/pull/2506 for more details
func AddSubscriptionError(ctx context.Context, err *gqlerror.Error) {
subscriptionErrStruct := getSubscriptionErrorStruct(ctx)
subscriptionErrStruct.errs = append(subscriptionErrStruct.errs, err)
}
func withSubscriptionErrorContext(ctx context.Context) context.Context {
return context.WithValue(ctx, wsSubscriptionErrorCtxKey, &subscriptionError{})
}
func getSubscriptionErrorStruct(ctx context.Context) *subscriptionError {
v, _ := ctx.Value(wsSubscriptionErrorCtxKey).(*subscriptionError)
return v
}
func getSubscriptionError(ctx context.Context) []*gqlerror.Error {
return getSubscriptionErrorStruct(ctx).errs
}

View File

@@ -0,0 +1,116 @@
package transport
import (
"encoding/json"
"errors"
"github.com/gorilla/websocket"
)
const (
initMessageType messageType = iota
connectionAckMessageType
keepAliveMessageType
connectionErrorMessageType
connectionCloseMessageType
startMessageType
stopMessageType
dataMessageType
completeMessageType
errorMessageType
pingMessageType
pongMessageType
)
var (
supportedSubprotocols = []string{
graphqlwsSubprotocol,
graphqltransportwsSubprotocol,
}
errWsConnClosed = errors.New("websocket connection closed")
errInvalidMsg = errors.New("invalid message received")
)
type (
messageType int
message struct {
payload json.RawMessage
id string
t messageType
}
messageExchanger interface {
NextMessage() (message, error)
Send(m *message) error
}
)
func (t messageType) String() string {
var text string
switch t {
default:
text = "unknown"
case initMessageType:
text = "init"
case connectionAckMessageType:
text = "connection ack"
case keepAliveMessageType:
text = "keep alive"
case connectionErrorMessageType:
text = "connection error"
case connectionCloseMessageType:
text = "connection close"
case startMessageType:
text = "start"
case stopMessageType:
text = "stop subscription"
case dataMessageType:
text = "data"
case completeMessageType:
text = "complete"
case errorMessageType:
text = "error"
case pingMessageType:
text = "ping"
case pongMessageType:
text = "pong"
}
return text
}
func contains(list []string, elem string) bool {
for _, e := range list {
if e == elem {
return true
}
}
return false
}
func (t *Websocket) injectGraphQLWSSubprotocols() {
// the list of subprotocols is specified by the consumer of the Websocket struct,
// in order to preserve backward compatibility, we inject the graphql specific subprotocols
// at runtime
if !t.didInjectSubprotocols {
defer func() {
t.didInjectSubprotocols = true
}()
for _, subprotocol := range supportedSubprotocols {
if !contains(t.Upgrader.Subprotocols, subprotocol) {
t.Upgrader.Subprotocols = append(t.Upgrader.Subprotocols, subprotocol)
}
}
}
}
func handleNextReaderError(err error) error {
// TODO: should we consider all closure scenarios here for the ws connection?
// for now we only list the error codes from the previous implementation
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
return errWsConnClosed
}
return err
}

View File

@@ -0,0 +1,58 @@
package graphql
import (
"encoding/json"
"fmt"
"io"
"strconv"
)
func MarshalID(s string) Marshaler {
return MarshalString(s)
}
func UnmarshalID(v interface{}) (string, error) {
switch v := v.(type) {
case string:
return v, nil
case json.Number:
return string(v), nil
case int:
return strconv.Itoa(v), nil
case int64:
return strconv.FormatInt(v, 10), nil
case float64:
return fmt.Sprintf("%f", v), nil
case bool:
if v {
return "true", nil
} else {
return "false", nil
}
case nil:
return "null", nil
default:
return "", fmt.Errorf("%T is not a string", v)
}
}
func MarshalIntID(i int) Marshaler {
return WriterFunc(func(w io.Writer) {
writeQuotedString(w, strconv.Itoa(i))
})
}
func UnmarshalIntID(v interface{}) (int, error) {
switch v := v.(type) {
case string:
return strconv.Atoi(v)
case int:
return v, nil
case int64:
return int(v), nil
case json.Number:
return strconv.Atoi(string(v))
default:
return 0, fmt.Errorf("%T is not an int", v)
}
}

View File

@@ -0,0 +1,55 @@
package graphql
import (
"context"
"errors"
"reflect"
)
const unmarshalInputCtx key = "unmarshal_input_context"
// BuildUnmarshalerMap returns a map of unmarshal functions of the ExecutableContext
// to use with the WithUnmarshalerMap function.
func BuildUnmarshalerMap(unmarshaler ...interface{}) map[reflect.Type]reflect.Value {
maps := make(map[reflect.Type]reflect.Value)
for _, v := range unmarshaler {
ft := reflect.TypeOf(v)
if ft.Kind() == reflect.Func {
maps[ft.Out(0)] = reflect.ValueOf(v)
}
}
return maps
}
// WithUnmarshalerMap returns a new context with a map from input types to their unmarshaler functions.
func WithUnmarshalerMap(ctx context.Context, maps map[reflect.Type]reflect.Value) context.Context {
return context.WithValue(ctx, unmarshalInputCtx, maps)
}
// UnmarshalInputFromContext allows unmarshaling input object from a context.
func UnmarshalInputFromContext(ctx context.Context, raw, v interface{}) error {
m, ok := ctx.Value(unmarshalInputCtx).(map[reflect.Type]reflect.Value)
if m == nil || !ok {
return errors.New("graphql: the input context is empty")
}
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return errors.New("graphql: input must be a non-nil pointer")
}
if fn, ok := m[rv.Elem().Type()]; ok {
res := fn.Call([]reflect.Value{
reflect.ValueOf(ctx),
reflect.ValueOf(raw),
})
if err := res[1].Interface(); err != nil {
return err.(error)
}
rv.Elem().Set(res[0])
return nil
}
return errors.New("graphql: no unmarshal function found")
}

Some files were not shown because too many files have changed in this diff Show More