package mysql

import (
	"database/sql"
	"database/sql/driver"
	"strconv"
	"strings"
	"sync"
	"time"

	"corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/configmdl"
	"corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/constantmdl"
	"corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/errormdl"
	"corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/loggermdl"
	"corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/sjsonhelpermdl"
	"corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/statemdl"
	_ "github.com/go-sql-driver/mysql"
	"github.com/gocraft/dbr/v2"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// Package-level state shared by Init, InitUsingJSON and GetMYSQLConnection.
//
// sqlConnections maps a host name to its connection pool (pooling is provided
// by database/sql). connectionError records the error that aborted
// initialization, if any. sqlOnce is shared by Init and InitUsingJSON, so only
// the first of the two to run actually initializes anything. defaultHost is the
// host GetMYSQLConnection tries when called with an empty name.
var (
	sqlConnections  map[string]*dbr.Connection
	connectionError error
	sqlOnce         sync.Once
	config          tomlConfig
	defaultHost     string
)

// MySQLConnection describes one MySQL host.
//
// HostName is the key InitUsingJSON registers the pool under. Username,
// Password, Protocol, Server, Port, Database and Parameters are assembled into
// a go-sql-driver DSN (see InitConnection); a Port <= 0 means 3306. Parameters
// are appended verbatim as ?key=value&... (they are not URL-escaped).
// MaxIdleConns, MaxOpenConns and ConnMaxLifetime fall back to constantmdl
// defaults when zero. IsDefault and IsDisabled are honored by InitUsingJSON only.
type MySQLConnection struct {
	HostName        string        `json:"hostName" bson:"hostName"`
	Server          string        `json:"server" bson:"server"`
	Port            int           `json:"port" bson:"port"`
	Username        string        `json:"username" bson:"username"`
	Password        string        `json:"password" bson:"password"`
	Protocol        string        `json:"protocol" bson:"protocol"`
	Database        string        `json:"database" bson:"database"`
	Parameters      []param       `json:"params" bson:"params"`
	MaxIdleConns    int           `json:"maxIdleConns" bson:"maxIdleConns"`
	MaxOpenConns    int           `json:"maxOpenConns" bson:"maxOpenConns"`
	ConnMaxLifetime time.Duration `json:"connMaxLifetime" bson:"connMaxLifetime"`
	IsDefault       bool          `json:"isDefault" bson:"isDefault"`
	IsDisabled      bool          `json:"isDisabled" bson:"isDisabled"`
}

// InitUsingJSON opens, pings and registers a connection pool for every entry
// in configs that is not IsDisabled, keyed by its HostName. An entry with
// IsDefault set becomes the default host (the last such entry wins).
//
// Initialization runs at most once per process (the sync.Once is shared with
// Init). It stops at the first failing entry; pools registered before that
// entry stay registered. The returned error is whatever the first
// initialization run recorded, so later calls keep returning it.
func InitUsingJSON(configs []MySQLConnection) error {
	sqlOnce.Do(func() {
		sqlConnections = make(map[string]*dbr.Connection)
		for _, connectionDetails := range configs {
			if connectionDetails.IsDisabled {
				continue
			}
			connection, err := InitConnection(connectionDetails)
			if errormdl.CheckErr1(err) != nil {
				loggermdl.LogError("Init dbr.Open Err : ", err)
				connectionError = err
				return
			}
			pingError := connection.Ping()
			if errormdl.CheckErr1(pingError) != nil {
				loggermdl.LogError(pingError)
				// The pool is not registered anywhere, so release it instead of
				// leaking its connections. The close error is secondary to the
				// ping error already being reported.
				_ = connection.Close()
				connectionError = pingError
				return
			}
			sqlConnections[connectionDetails.HostName] = connection
			if connectionDetails.IsDefault {
				defaultHost = connectionDetails.HostName
			}
		}
	})
	return connectionError
}

// InitConnection opens (but does not ping) a dbr connection pool for
// connectionDetails and applies its pool limits.
//
// The DSN has the form
//
//	username:password@protocol(server:port)/database?key=value&key=value
//
// where a Port <= 0 is replaced by MySQL's default 3306 and the query string is
// omitted when there are no Parameters.
func InitConnection(connectionDetails MySQLConnection) (*dbr.Connection, error) {
	port := connectionDetails.Port
	if port <= 0 {
		port = 3306 // MySQL default port
	}

	var conStr strings.Builder
	conStr.WriteString(connectionDetails.Username)
	conStr.WriteString(":")
	conStr.WriteString(connectionDetails.Password)
	conStr.WriteString("@")
	conStr.WriteString(connectionDetails.Protocol)
	conStr.WriteString("(")
	conStr.WriteString(connectionDetails.Server)
	conStr.WriteString(":")
	conStr.WriteString(strconv.Itoa(port))
	conStr.WriteString(")/")
	conStr.WriteString(connectionDetails.Database)
	conStr.WriteString(buildParamsString(connectionDetails.Parameters))

	connection, err := dbr.Open("mysql", conStr.String(), nil)
	if errormdl.CheckErr1(err) != nil {
		loggermdl.LogError("Init dbr.Open Err : ", err)
		return nil, err
	}
	applyPoolSettings(connection, connectionDetails)
	return connection, nil
}

// buildParamsString renders params as a DSN query string: "" when params is
// empty, otherwise "?k1=v1&k2=v2...". Keys and values are written verbatim.
func buildParamsString(params []param) string {
	if len(params) == 0 {
		return ""
	}
	var b strings.Builder
	b.WriteString("?")
	for i, p := range params {
		if i > 0 {
			b.WriteString("&")
		}
		b.WriteString(p.ParamKey)
		b.WriteString("=")
		b.WriteString(p.ParamValue)
	}
	return b.String()
}

// applyPoolSettings sets idle/open connection limits and the connection max
// lifetime on connection, substituting the constantmdl defaults for any
// setting that is zero in details. details is a copy, so the caller's value is
// not modified.
func applyPoolSettings(connection *dbr.Connection, details MySQLConnection) {
	if details.MaxIdleConns == 0 {
		details.MaxIdleConns = constantmdl.MAX_IDLE_CONNECTIONS
	}
	if details.MaxOpenConns == 0 {
		details.MaxOpenConns = constantmdl.MAX_OPEN_CONNECTIONS
	}
	if details.ConnMaxLifetime == 0 {
		details.ConnMaxLifetime = constantmdl.CONNECTION_MAX_LIFETIME
	}
	connection.SetMaxIdleConns(details.MaxIdleConns)
	connection.SetMaxOpenConns(details.MaxOpenConns)
	connection.SetConnMaxLifetime(details.ConnMaxLifetime)
}

// param is a single DSN query parameter (key=value).
type param struct {
	ParamKey   string `json:"paramkey" bson:"paramkey"`
	ParamValue string `json:"paramvalue" bson:"paramvalue"`
}

// tomlConfig is the shape Init decodes the TOML file into; each MysqlHosts
// entry is keyed by the name its pool is registered under.
type tomlConfig struct {
	MysqlHosts map[string]MySQLConnection
}

// Init initializes MySQL connection pools from the TOML file at tomlFilepath
// and records defaultHostName as the default host.
//
// One pool is opened per MysqlHosts entry and registered under that entry's
// key. Unlike InitUsingJSON, pools are not pinged and IsDisabled/IsDefault are
// not consulted. Initialization runs at most once per process (the sync.Once
// is shared with InitUsingJSON); every call returns the registered pools and
// the error recorded by the first initialization run, if any.
func Init(tomlFilepath string, defaultHostName string) (map[string]*dbr.Connection, error) {
	sqlOnce.Do(func() {
		sqlConnections = make(map[string]*dbr.Connection)
		_, err := configmdl.InitConfig(tomlFilepath, &config)
		if errormdl.CheckErr(err) != nil {
			loggermdl.LogError("Init InitConfig Err : ", err)
			connectionError = err
			return
		}
		for connectionName, connectionDetails := range config.MysqlHosts {
			// NOTE(review): unlike InitConnection, this DSN does not append
			// Port, so Server presumably carries "host:port" in TOML configs.
			// Kept as-is to avoid breaking existing configuration files.
			dsn := connectionDetails.Username + ":" + connectionDetails.Password +
				"@" + connectionDetails.Protocol + "(" + connectionDetails.Server + ")/" +
				connectionDetails.Database + buildParamsString(connectionDetails.Parameters)
			connection, err := dbr.Open("mysql", dsn, nil)
			if errormdl.CheckErr1(err) != nil {
				loggermdl.LogError("Init dbr.Open Err : ", err)
				connectionError = err
				return
			}
			applyPoolSettings(connection, connectionDetails)
			sqlConnections[connectionName] = connection
		}
		defaultHost = defaultHostName
	})
	return sqlConnections, errormdl.CheckErr2(connectionError)
}

// GetMYSQLConnection returns the pool registered under connectionName.
//
// An empty connectionName first resolves to the default host. It returns a
// "MYSQL_INIT_NOT_DONE" error if neither Init nor InitUsingJSON has created
// the registry, and a "Connection not found" error for unknown names. Each
// successful lookup is counted via statemdl.MySQLHits.
func GetMYSQLConnection(connectionName string) (*dbr.Connection, error) {
	if errormdl.CheckBool(sqlConnections == nil) {
		initErr := errormdl.Wrap("MYSQL_INIT_NOT_DONE")
		loggermdl.LogError("GetMYSQLConnection Err : ", initErr)
		return nil, initErr
	}
	if connectionName == "" {
		if instance, keyExist := sqlConnections[defaultHost]; keyExist {
			statemdl.MySQLHits()
			return instance, nil
		}
	}
	if session, keyExist := sqlConnections[connectionName]; keyExist {
		statemdl.MySQLHits()
		return session, nil
	}
	notFoundErr := errormdl.Wrap("Connection not found for host: " + connectionName)
	loggermdl.LogError("GetMYSQLConnection Err : ", notFoundErr)
	return nil, notFoundErr
}

// MySQLDAO runs queries against the connection pool registered for hostName.
type MySQLDAO struct {
	hostName string
}

// GetMySQLDAO returns a MySQLDAO bound to the default host as it is at call
// time (call it after Init/InitUsingJSON).
func GetMySQLDAO() *MySQLDAO {
	return &MySQLDAO{
		hostName: defaultHost,
	}
}

// GetMySQLDAOWithHost returns a MySQLDAO bound to the given host name.
func GetMySQLDAOWithHost(host string) *MySQLDAO {
	return &MySQLDAO{
		hostName: host,
	}
}

// ExecQuery executes a statement that returns no rows on md's host and
// returns the result's LastInsertId as a decimal string.
func (md *MySQLDAO) ExecQuery(query string, args ...interface{}) (string, error) {
	connection, connErr := GetMYSQLConnection(md.hostName)
	if errormdl.CheckErr(connErr) != nil {
		loggermdl.LogError("ExecQuery GetMYSQLConnection Err : ", connErr)
		return "", errormdl.CheckErr(connErr)
	}
	// driver.ErrBadConn is tolerated and the Exec is still attempted —
	// presumably relying on database/sql to discard the bad pooled
	// connection and use another one.
	pingError := connection.Ping()
	if errormdl.CheckErr(pingError) != nil && pingError != driver.ErrBadConn {
		loggermdl.LogError(pingError)
		return "", errormdl.CheckErr(pingError)
	}
	result, execError := connection.Exec(query, args...)
	if errormdl.CheckErr(execError) != nil {
		loggermdl.LogError(execError)
		return "", errormdl.CheckErr(execError)
	}
	return resultToInsertID(result)
}

// SelectQuery runs query on md's host and returns all rows as a JSON array of
// objects keyed by column name.
func (md *MySQLDAO) SelectQuery(query string, args ...interface{}) (*gjson.Result, error) {
	connection, connErr := GetMYSQLConnection(md.hostName)
	if errormdl.CheckErr(connErr) != nil {
		loggermdl.LogError("SelectQuery GetMYSQLConnection Err : ", connErr)
		return nil, errormdl.CheckErr(connErr)
	}
	// See ExecQuery for why driver.ErrBadConn is tolerated.
	pingError := connection.Ping()
	if errormdl.CheckErr(pingError) != nil && pingError != driver.ErrBadConn {
		loggermdl.LogError(pingError)
		return nil, errormdl.CheckErr(pingError)
	}
	rows, queryError := connection.Query(query, args...)
	if errormdl.CheckErr(queryError) != nil {
		loggermdl.LogError(queryError)
		return nil, errormdl.CheckErr(queryError)
	}
	defer rows.Close()
	return rowsToJSON(rows)
}

// ExecTxQuery executes a statement that returns no rows inside tx and returns
// the result's LastInsertId as a decimal string. Committing or rolling back
// tx is the caller's responsibility.
func ExecTxQuery(query string, tx *sql.Tx, args ...interface{}) (string, error) {
	result, execError := tx.Exec(query, args...)
	if errormdl.CheckErr(execError) != nil {
		loggermdl.LogError(execError)
		return "", errormdl.CheckErr(execError)
	}
	return resultToInsertID(result)
}

// SelectTxQuery runs query inside tx and returns all rows as a JSON array of
// objects keyed by column name. Committing or rolling back tx is the caller's
// responsibility.
func SelectTxQuery(query string, tx *sql.Tx, args ...interface{}) (*gjson.Result, error) {
	rows, queryError := tx.Query(query, args...)
	if errormdl.CheckErr(queryError) != nil {
		loggermdl.LogError(queryError)
		return nil, errormdl.CheckErr(queryError)
	}
	defer rows.Close()
	return rowsToJSON(rows)
}

// resultToInsertID returns result's LastInsertId as a decimal string.
// RowsAffected is consulted only to surface any driver error; the count itself
// is discarded.
func resultToInsertID(result sql.Result) (string, error) {
	_, affectError := result.RowsAffected()
	if errormdl.CheckErr(affectError) != nil {
		loggermdl.LogError(affectError)
		return "", errormdl.CheckErr(affectError)
	}
	ID, err := result.LastInsertId()
	if errormdl.CheckErr(err) != nil {
		loggermdl.LogError(err)
		return "", errormdl.CheckErr(err)
	}
	// FormatInt keeps the full int64 range; int(ID) would truncate on
	// 32-bit platforms.
	return strconv.FormatInt(ID, 10), nil
}

// rowsToJSON drains rows into a JSON array holding one object per row, keyed
// by column name, and returns it parsed as a gjson.Result. The caller owns
// rows and must close it. Scan and iteration errors are returned instead of
// yielding a partial result.
func rowsToJSON(rows *sql.Rows) (*gjson.Result, error) {
	columns, err := rows.Columns()
	if errormdl.CheckErr2(err) != nil {
		loggermdl.LogError("GetAllData rows.Columns() Err : ", err)
		return nil, errormdl.CheckErr2(err)
	}
	values := make([]interface{}, len(columns))
	valuePtrs := make([]interface{}, len(columns))
	// The destination pointers never change, so wire them up once.
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	tableData := "[]"
	for rows.Next() {
		if scanErr := rows.Scan(valuePtrs...); scanErr != nil {
			loggermdl.LogError("GetAllData rows.Scan Err : ", scanErr)
			return nil, scanErr
		}
		data, err := sjsonhelpermdl.SetMultiple("", columns, values)
		if errormdl.CheckErr3(err) != nil {
			loggermdl.LogError("GetAllData sjson.Set Err : ", err)
			return nil, errormdl.CheckErr3(err)
		}
		tableData, err = sjson.Set(tableData, "-1", gjson.Parse(data).Value())
		if errormdl.CheckErr3(err) != nil {
			loggermdl.LogError("GetAllData sjson.Set Err : ", err)
			return nil, errormdl.CheckErr3(err)
		}
	}
	// Next returns false both at the end of the result set and on an
	// iteration error; only Err tells them apart.
	if iterErr := rows.Err(); iterErr != nil {
		loggermdl.LogError("GetAllData rows.Err Err : ", iterErr)
		return nil, iterErr
	}
	resultSet := gjson.Parse(tableData)
	return &resultSet, nil
}