//args: -Esqlclosecheck
package testdata

import (
	"context"
	"database/sql"
	"log"
	"strings"

	_ "github.com/go-sql-driver/mysql"
	"github.com/jmoiron/sqlx"
)

// Package-level values shared by every case in this file: ctx is handed to the
// *Context calls, db backs the rows* and stmt* cases, dbx backs the sqlx*
// cases, and age and userID are the bind values for the queries. ctx, db and
// dbx are declared here without being assigned a value.
var (
	ctx    context.Context
	db     *sql.DB
	dbx    *sqlx.DB
	age    = 27
	userID = 43
)

// rowsCorrectDeferBlock is the case where the result set is closed from a
// deferred func literal. It selects user names by age, collects them, and logs
// the joined list. No expectation comment is attached to any line.
func rowsCorrectDeferBlock() {

	rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age)
	if err != nil {
		log.Fatal(err)
	}

	// Close inside a deferred closure so the error returned by Close can be
	// inspected and logged instead of being discarded.
	defer func() {
		err := rows.Close()
		if err != nil {
			log.Print("problem closing rows")
		}
	}()

	names := make([]string, 0)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			log.Fatal(err)
		}
		names = append(names, name)
	}

	// Check for errors from iterating over rows.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	log.Printf("%s are %d years old", strings.Join(names, ", "), age)
}

// rowsCorrectDefer is the case where the result set is closed with a plain
// `defer rows.Close()` placed right after the query's error check. It selects
// user names by age, collects them, and logs the joined list. No expectation
// comment is attached to any line.
func rowsCorrectDefer() {
	rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age)
	if err != nil {
		log.Fatal(err)
	}
	// The error returned by Close is discarded in this form.
	defer rows.Close()

	names := make([]string, 0)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			log.Fatal(err)
		}
		names = append(names, name)
	}

	// Check for errors from iterating over rows.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	log.Printf("%s are %d years old", strings.Join(names, ", "), age)
}

// rowsMissingClose is the case where the result set is iterated but never
// closed. The expectation comment on the QueryContext line marks where the
// analyzer's report is anchored; it must stay on that line.
func rowsMissingClose() {
	rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age) // ERROR "Rows/Stmt was not closed"
	if err != nil {
		log.Fatal(err)
	}
	// The deferred Close is left commented out, so rows is never closed in
	// this function.
	// defer rows.Close()

	names := make([]string, 0)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			log.Fatal(err)
		}
		names = append(names, name)
	}

	// Check for errors from iterating over rows.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	log.Printf("%s are %d years old", strings.Join(names, ", "), age)
}

// rowsNonDeferClose is the case where the result set is closed, but by an
// ordinary call at the end of the function rather than a deferred one. The
// expectation comment on the Close line marks where the analyzer's report is
// anchored; it must stay on that line.
func rowsNonDeferClose() {
	rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age)
	if err != nil {
		log.Fatal(err)
	}

	names := make([]string, 0)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			log.Fatal(err)
		}
		names = append(names, name)
	}

	// Check for errors from iterating over rows.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	log.Printf("%s are %d years old", strings.Join(names, ", "), age)

	// Reached only when none of the log.Fatal paths above fired.
	rows.Close() // ERROR "Close should use defer"
}

// rowsPassedAndClosed opens a result set and passes it to rowsClosedPassed
// instead of closing it locally. No expectation comment is attached.
func rowsPassedAndClosed() {
	rows, err := db.QueryContext(ctx, "SELECT name FROM users")
	if err != nil {
		log.Fatal(err)
	}

	// The callee is the one that calls rows.Close.
	rowsClosedPassed(rows)
}

// rowsClosedPassed closes the result set it is given. The call is not deferred
// and the error returned by Close is discarded; no expectation comment is
// attached.
func rowsClosedPassed(rows *sql.Rows) {
	rows.Close()
}

// rowsPassedAndNotClosed opens a result set and passes it to
// rowsDontClosedPassed, which has an empty body, so nothing in this file closes
// it. No expectation comment is attached.
//
// NOTE(review): the rows parameter is overwritten before it is read. Because
// err is the only new name in the := below, the statement assigns to the
// parameter rather than declaring a new variable, so the caller's argument is
// never used.
func rowsPassedAndNotClosed(rows *sql.Rows) {
	rows, err := db.QueryContext(ctx, "SELECT name FROM users")
	if err != nil {
		log.Fatal(err)
	}

	rowsDontClosedPassed(rows)
}

// rowsDontClosedPassed accepts a result set and does nothing with it: the
// parameter is unnamed and the body is empty.
func rowsDontClosedPassed(*sql.Rows) {

}

// rowsReturn runs the age query and returns the open result set; the function
// itself never closes it, so closing is left to the caller.
//
// On query failure it returns (nil, err), the same way stmtReturn and
// sqlxReturnRows do. Previously the failure path called log.Fatal, which exits
// the process, so the declared error result could never be non-nil.
func rowsReturn() (*sql.Rows, error) {
	rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age)
	if err != nil {
		return nil, err
	}
	return rows, nil
}

// rowsReturnShort returns the results of QueryContext directly, without binding
// them to local variables first. The result set is not closed here; no
// expectation comment is attached.
func rowsReturnShort() (*sql.Rows, error) {
	return db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age)
}

// stmtCorrectDeferBlock is the case where a prepared statement is closed from a
// deferred func literal. It prepares a lookup by id, runs it for userID, and
// logs the username or exits via log.Fatal/log.Fatalf. No expectation comment
// is attached to any line.
func stmtCorrectDeferBlock() {
	// In normal use, create one Stmt when your process starts.
	stmt, err := db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?")
	if err != nil {
		log.Fatal(err)
	}
	// Close inside a deferred closure so the error returned by Close can be
	// inspected and logged instead of being discarded.
	defer func() {
		err := stmt.Close()
		if err != nil {
			log.Print("problem closing stmt")
		}
	}()

	// Then reuse it each time you need to issue the query.
	var username string
	err = stmt.QueryRowContext(ctx, userID).Scan(&username)
	// sql.ErrNoRows is checked first so it gets its own message before the
	// generic non-nil branch.
	switch {
	case err == sql.ErrNoRows:
		log.Fatalf("no user with id %d", userID)
	case err != nil:
		log.Fatal(err)
	default:
		log.Printf("username is %s\n", username)
	}
}

// stmtCorrectDefer is the case where a prepared statement is closed with a
// plain `defer stmt.Close()` placed right after the prepare error check. It
// runs the lookup for userID and logs the username or exits via
// log.Fatal/log.Fatalf. No expectation comment is attached to any line.
func stmtCorrectDefer() {
	// In normal use, create one Stmt when your process starts.
	stmt, err := db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?")
	if err != nil {
		log.Fatal(err)
	}
	// The error returned by Close is discarded in this form.
	defer stmt.Close()

	// Then reuse it each time you need to issue the query.
	var username string
	err = stmt.QueryRowContext(ctx, userID).Scan(&username)
	// sql.ErrNoRows is checked first so it gets its own message before the
	// generic non-nil branch.
	switch {
	case err == sql.ErrNoRows:
		log.Fatalf("no user with id %d", userID)
	case err != nil:
		log.Fatal(err)
	default:
		log.Printf("username is %s\n", username)
	}
}

// stmtMissingClose is the case where a prepared statement is used but never
// closed. The expectation comment on the PrepareContext line marks where the
// analyzer's report is anchored; it must stay on that line.
func stmtMissingClose() {
	// In normal use, create one Stmt when your process starts.
	stmt, err := db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?") // ERROR "Rows/Stmt was not closed"
	if err != nil {
		log.Fatal(err)
	}
	// The deferred Close is left commented out, so stmt is never closed in
	// this function.
	// defer stmt.Close()

	// Then reuse it each time you need to issue the query.
	var username string
	err = stmt.QueryRowContext(ctx, userID).Scan(&username)
	// sql.ErrNoRows is checked first so it gets its own message before the
	// generic non-nil branch.
	switch {
	case err == sql.ErrNoRows:
		log.Fatalf("no user with id %d", userID)
	case err != nil:
		log.Fatal(err)
	default:
		log.Printf("username is %s\n", username)
	}
}

// stmtNonDeferClose is the case where a prepared statement is closed, but by an
// ordinary call at the end of the function rather than a deferred one. The
// expectation comment on the Close line marks where the analyzer's report is
// anchored; it must stay on that line.
func stmtNonDeferClose() {
	// In normal use, create one Stmt when your process starts.
	stmt, err := db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?")
	if err != nil {
		log.Fatal(err)
	}

	// Then reuse it each time you need to issue the query.
	var username string
	err = stmt.QueryRowContext(ctx, userID).Scan(&username)
	// sql.ErrNoRows is checked first so it gets its own message before the
	// generic non-nil branch.
	switch {
	case err == sql.ErrNoRows:
		log.Fatalf("no user with id %d", userID)
	case err != nil:
		log.Fatal(err)
	default:
		log.Printf("username is %s\n", username)
	}

	// Reached only through the default branch above; the other two branches
	// exit via log.Fatalf/log.Fatal.
	stmt.Close() // ERROR "Close should use defer"
}

// stmtReturn prepares the lookup-by-id statement and returns it; the function
// itself never closes it. It returns (nil, err) when PrepareContext fails. No
// expectation comment is attached.
func stmtReturn() (*sql.Stmt, error) {
	stmt, err := db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?")
	if err != nil {
		return nil, err
	}

	return stmt, nil
}

// stmtReturnShort returns the results of PrepareContext directly, without
// binding them to local variables first. The statement is not closed here; no
// expectation comment is attached.
func stmtReturnShort() (*sql.Stmt, error) {
	return db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?")
}

// sqlxCorrectDefer is the sqlx counterpart of rowsCorrectDefer: the result set
// comes from dbx.Queryx and is closed with a plain `defer rows.Close()`. It
// collects user names by age and logs the joined list. No expectation comment
// is attached to any line.
func sqlxCorrectDefer() {
	rows, err := dbx.Queryx("SELECT name FROM users WHERE age=?", age)
	if err != nil {
		log.Fatal(err)
	}

	// The error returned by Close is discarded in this form.
	defer rows.Close()

	names := make([]string, 0)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			log.Fatal(err)
		}
		names = append(names, name)
	}

	// Check for errors from iterating over rows.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	log.Printf("%s are %d years old", strings.Join(names, ", "), age)
}

// sqlxNonDeferClose is the sqlx counterpart of rowsNonDeferClose: the result
// set from dbx.Queryx is closed by an ordinary call at the end of the function
// rather than a deferred one. The expectation comment on the Close line marks
// where the analyzer's report is anchored; it must stay on that line.
func sqlxNonDeferClose() {
	rows, err := dbx.Queryx("SELECT name FROM users WHERE age=?", age)
	if err != nil {
		log.Fatal(err)
	}

	names := make([]string, 0)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			log.Fatal(err)
		}
		names = append(names, name)
	}

	// Check for errors from iterating over rows.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	log.Printf("%s are %d years old", strings.Join(names, ", "), age)

	// Reached only when none of the log.Fatal paths above fired.
	rows.Close() // ERROR "Close should use defer"
}

// sqlxMissingClose is the sqlx counterpart of rowsMissingClose: the result set
// from dbx.Queryx is iterated but never closed. The expectation comment on the
// Queryx line marks where the analyzer's report is anchored; it must stay on
// that line.
func sqlxMissingClose() {
	rows, err := dbx.Queryx("SELECT name FROM users WHERE age=?", age) // ERROR "Rows/Stmt was not closed"
	if err != nil {
		log.Fatal(err)
	}

	// The deferred Close is left commented out, so rows is never closed in
	// this function.
	// defer rows.Close()

	names := make([]string, 0)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			log.Fatal(err)
		}
		names = append(names, name)
	}

	// Check for errors from iterating over rows.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	log.Printf("%s are %d years old", strings.Join(names, ", "), age)
}

// sqlxReturnRows runs the age query through dbx.Queryx and returns the open
// result set; the function itself never closes it. It returns (nil, err) when
// the query fails. No expectation comment is attached.
func sqlxReturnRows() (*sqlx.Rows, error) {
	rows, err := dbx.Queryx("SELECT name FROM users WHERE age=?", age)
	if err != nil {
		return nil, err
	}

	return rows, nil
}