// +build ignore

package main

import (
	"bytes"
	"go/format"
	"io"
	"log"
	"os"
	"text/template"
)

// main regenerates astcast.go and astcast_test.go in the current directory,
// emitting a To<Type> cast helper (and a matching test) for every go/ast
// node type listed in typeList.
func main() {
	typeList := []string{
		// Expressions:
		"ArrayType",
		"BadExpr",
		"BasicLit",
		"BinaryExpr",
		"CallExpr",
		"ChanType",
		"CompositeLit",
		"Ellipsis",
		"FuncLit",
		"FuncType",
		"Ident",
		"IndexExpr",
		"InterfaceType",
		"KeyValueExpr",
		"MapType",
		"ParenExpr",
		"SelectorExpr",
		"SliceExpr",
		"StarExpr",
		"StructType",
		"TypeAssertExpr",
		"UnaryExpr",

		// Statements:
		"AssignStmt",
		"BadStmt",
		"BlockStmt",
		"BranchStmt",
		"CaseClause",
		"CommClause",
		"DeclStmt",
		"DeferStmt",
		"EmptyStmt",
		"ExprStmt",
		"ForStmt",
		"GoStmt",
		"IfStmt",
		"IncDecStmt",
		"LabeledStmt",
		"RangeStmt",
		"ReturnStmt",
		"SelectStmt",
		"SendStmt",
		"SwitchStmt",
		"TypeSwitchStmt",

		// Others:
		"Comment",
		"CommentGroup",
		"FieldList",
		"File",
		"Package",
	}

	writeFile("astcast.go", typeList, writeCode)
	writeFile("astcast_test.go", typeList, writeTests)
}

// writeFile renders typeList with gen and stores the result at path.
// The output is rendered into memory first so that a generation failure
// does not leave a truncated file behind. Any I/O error aborts the program.
func writeFile(path string, typeList []string, gen func(io.Writer, []string)) {
	var buf bytes.Buffer
	gen(&buf, typeList)

	f, err := os.Create(path)
	if err != nil {
		log.Fatal(err)
	}
	if _, err := f.Write(buf.Bytes()); err != nil {
		f.Close() // best effort; the write error is the one reported
		log.Fatal(err)
	}
	// Close can surface a delayed write error, so its result must be checked.
	if err := f.Close(); err != nil {
		log.Fatal(err)
	}
}

// generateCode executes tmplText as a text/template with typeList as its
// data and returns the gofmt-formatted result.
//
// It panics if the template fails to parse or execute, or if the rendered
// output is not valid Go source. This is a build-time generator, so failing
// loudly is preferable to emitting a broken or truncated file.
func generateCode(tmplText string, typeList []string) []byte {
	tmpl := template.Must(template.New("code").Parse(tmplText))
	var code bytes.Buffer
	// Execute may already have written partial output when it fails, and that
	// prefix can still be syntactically valid Go. The error must therefore be
	// checked here rather than relying on format.Source to catch it.
	if err := tmpl.Execute(&code, typeList); err != nil {
		panic(err)
	}
	prettyCode, err := format.Source(code.Bytes())
	if err != nil {
		panic(err)
	}
	return prettyCode
}

// writeCode renders the astcast package source, with one Nil<Type> sentinel
// and one To<Type> cast function per entry in typeList, and writes it to
// output. It terminates the program if the write fails.
func writeCode(output io.Writer, typeList []string) {
	// The header must match `^// Code generated .* DO NOT EDIT\.$` (including
	// the trailing period) for tools to recognize the file as generated.
	code := generateCode(`// Code generated by astcast_generate.go; DO NOT EDIT.

// Package astcast wraps type assertion operations in such a way that you don't have
// to worry about nil pointer results anymore.
package astcast

import (
	"go/ast"
)

// A set of sentinel nil-like values that are returned
// by all "casting" functions in case of failed type assertion.
var (
{{ range . }}
Nil{{.}} = &ast.{{.}}{}
{{- end }}
)

{{ range . }}
// To{{.}} returns x as a non-nil *ast.{{.}}.
// If x has that dynamic type and holds a non-nil pointer, the result is
// identical to a normal type assertion. Otherwise (different type, nil
// node, or typed nil pointer) the returned value is Nil{{.}}.
func To{{.}}(x ast.Node) *ast.{{.}} {
	if x, ok := x.(*ast.{{.}}); ok && x != nil {
		return x
	}
	return Nil{{.}}
}
{{ end }}
`, typeList)
	if _, err := output.Write(code); err != nil {
		log.Fatal(err)
	}
}

// writeTests renders one TestTo<Type> function per entry in typeList and
// writes the resulting test source to output. Each generated test checks
// three cases: a matching node, a nil node, and a node of a different type.
// It terminates the program if the write fails.
func writeTests(output io.Writer, typeList []string) {
	// The header must match `^// Code generated .* DO NOT EDIT\.$` (including
	// the trailing period) for tools to recognize the file as generated.
	code := generateCode(`// Code generated by astcast_generate.go; DO NOT EDIT.

package astcast

import (
	"go/ast"
	"testing"
)

{{ range . }}
func TestTo{{.}}(t *testing.T) {
	// Test successful cast.
	if x := To{{.}}(&ast.{{.}}{}); x == Nil{{.}} || x == nil {
		t.Error("expected successful cast, got nil")
	}
	// Test nil cast.
	if x := To{{.}}(nil); x != Nil{{.}} {
		t.Error("nil node didn't result in a sentinel value return")
	}
	// Test unsuccessful cast. An Ident is used as the mismatching node,
	// except when testing ToIdent itself, where a CallExpr is used instead.
	if x := To{{.}}({{ if eq . "Ident" }}&ast.CallExpr{}{{ else }}&ast.Ident{}{{ end }}); x != Nil{{.}} {
		t.Error("expected unsuccessful cast to return nil sentinel")
	}
}
{{ end }}
`, typeList)
	if _, err := output.Write(code); err != nil {
		log.Fatal(err)
	}
}