-- Copyright (C) 2012 Yichun Zhang (agentzh)
-- Copyright (C) 2014 Chang Feng
-- This file is modified version from https://github.com/openresty/lua-resty-mysql
-- The license is under the BSD license.
-- Modified by Cloud Wu (remove bit32 for lua 5.3)
-- protocol detail: https://mariadb.com/kb/en/clientserver-protocol/

local socketchannel = require "skynet.socketchannel"
local crypt = require "skynet.crypt"

-- Cached library functions, shared by the whole file.
local sub, strgsub, strformat = string.sub, string.gsub, string.format
local strbyte, strchar, strrep = string.byte, string.char, string.rep
local strunpack, strpack = string.unpack, string.pack
local sha1 = crypt.sha1
local setmetatable, error, tonumber = setmetatable, error, tonumber
local tointeger = math.tointeger

local _M = { _VERSION = "0.14" }

-- Charset name -> default collation id, generated from:
--   SELECT CHARACTER_SET_NAME, ID
--   FROM information_schema.collations
--   WHERE IS_DEFAULT = 'Yes' ORDER BY id;
local CHARSET_MAP = {
    _default = 0,
    big5 = 1, dec8 = 3, cp850 = 4, hp8 = 6, koi8r = 7,
    latin1 = 8, latin2 = 9, swe7 = 10, ascii = 11, ujis = 12,
    sjis = 13, hebrew = 16, tis620 = 18, euckr = 19, koi8u = 22,
    gb2312 = 24, greek = 25, cp1250 = 26, gbk = 28, latin5 = 30,
    armscii8 = 32, utf8 = 33, ucs2 = 35, cp866 = 36, keybcs2 = 37,
    macce = 38, macroman = 39, cp852 = 40, latin7 = 41, utf8mb4 = 45,
    cp1251 = 51, utf16 = 54, utf16le = 56, cp1256 = 57, cp1257 = 59,
    utf32 = 60, binary = 63, geostd8 = 92, cp932 = 95, eucjpms = 97,
    gb18030 = 248,
}

-- Single-byte command codes written at the start of command packets.
local COM_QUERY = strchar(0x03)
local COM_PING = strchar(0x0e)
local COM_STMT_PREPARE = strchar(0x16)
local COM_STMT_EXECUTE = strchar(0x17)
local COM_STMT_CLOSE = strchar(0x19)
local COM_STMT_RESET = strchar(0x1a)

local CURSOR_TYPE_NO_CURSOR = 0x00
-- Status-flag bit tested to decide whether another result set follows.
local SERVER_MORE_RESULTS_EXISTS = 8

local mt = { __index = _M }

-- Text-protocol value converters keyed by MySQL column type id; numeric
-- column types are converted with tonumber.
local converters = {}
for _, type_id in ipairs { 0x01, 0x02, 0x03, 0x04, 0x05 } do
    -- tiny, short, long, float, double
    converters[type_id] = tonumber
end
converters[0x08] = tonumber -- long
long converters[0x09] = tonumber -- int24 converters[0x0d] = tonumber -- year converters[0xf6] = tonumber -- newdecimal local function _get_byte1(data, i) return strbyte(data, i), i + 1 end local function _get_int1(data, i, is_signed) if not is_signed then return strunpack("= 0 and first <= 250 then return first, pos + 1 end if first == 251 then return nil, pos + 1 end if first == 252 then pos = pos + 1 return _get_byte2(data, pos) end if first == 253 then pos = pos + 1 return _get_byte3(data, pos) end if first == 254 then pos = pos + 1 return _get_byte8(data, pos) end return false, pos + 1 end local function _set_length_coded_bin(n) if n < 251 then return strchar(n) end if n < (1 << 16) then return strpack(" 0 then local f, ts, vs local types_buf = "" local values_buf = "" --生成NULL位图 local null_count = (arg_num + 7) // 8 local null_map = "" local field_index = 1 for i = 1, null_count do local byte = 0 for j = 0, 7 do if field_index < arg_num then if args[field_index] == nil then byte = byte | (1 << j) else byte = byte | (0 << j) end end field_index = field_index + 1 end null_map = null_map .. strchar(byte) end for i = 1, arg_num do local v = args[i] f = store_types[type(v)] if not f then error("invalid parameter type", type(v)) end ts, vs = f(v) types_buf = types_buf .. ts values_buf = values_buf .. vs end cmd_packet = cmd_packet .. null_map .. strchar(0x01) .. types_buf .. 
values_buf
    end
    return _compose_packet(self, cmd_packet)
end

--- Read one result of a text-protocol (COM_QUERY) response.
-- The first packet selects the shape of the result:
--   ERR  -> nil, msg, errno, sqlstate
--   OK   -> the table returned by _parse_ok_packet
--   DATA -> column definitions, an EOF packet, then rows until the final EOF;
--           the rows (built by _parse_row_data_packet) are returned as a list.
-- When the OK/EOF status flags contain SERVER_MORE_RESULTS_EXISTS, the string
-- "again" is returned as a second value so the caller reads the next result.
-- Packet read failures are returned as nil, err.
-- @param self connection object (self.compact is forwarded to the row parser)
-- @param sock socket handle forwarded to the packet readers
local function read_result(self, sock)
    local packet, typ, err = _recv_packet(self, sock)
    if not packet then
        return nil, err
    end

    if typ == "ERR" then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end

    if typ == "OK" then
        local res = _parse_ok_packet(packet)
        if res and res.server_status & SERVER_MORE_RESULTS_EXISTS ~= 0 then
            return res, "again"
        end
        return res
    end

    if typ ~= "DATA" then
        -- tostring(): an unclassified (possibly nil) typ must yield an error
        -- result rather than raise inside the concatenation
        return nil, "packet type " .. tostring(typ) .. " not supported"
    end

    -- typ == "DATA": the result-set header carries the column count
    local field_count = _parse_result_set_header_packet(packet)
    local cols = {}
    for i = 1, field_count do
        local col, col_err, errno, sqlstate = _recv_field_packet(self, sock)
        if not col then
            return nil, col_err, errno, sqlstate
        end
        cols[i] = col
    end

    -- the column definitions must be followed by an EOF packet
    packet, typ, err = _recv_packet(self, sock)
    if not packet then
        return nil, err
    end
    if typ == "ERR" then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end
    if typ ~= "EOF" then
        return nil, "unexpected packet type " .. tostring(typ) .. " while eof packet is " .. "expected"
    end

    local compact = self.compact
    local rows = {}
    local nrows = 0
    while true do
        packet, typ, err = _recv_packet(self, sock)
        if not packet then
            return nil, err
        end
        if typ == "EOF" then
            local _, status_flags = _parse_eof_packet(packet)
            if status_flags & SERVER_MORE_RESULTS_EXISTS ~= 0 then
                return rows, "again"
            end
            break
        end
        if typ == "ERR" then
            -- an ERR packet ends the row stream with an error
            local errno, msg, sqlstate = _parse_err_packet(packet)
            return nil, msg, errno, sqlstate
        end
        -- NOTE(review): rows are intentionally not required to be typ "DATA"
        -- (that check was left disabled upstream), presumably because a row
        -- can start with a byte _recv_packet classifies as another type.
        nrows = nrows + 1
        rows[nrows] = _parse_row_data_packet(packet, cols, compact)
    end
    return rows
end

--- Build the socketchannel response callback used by _M.query.
-- The callback always returns true (the response was consumed) followed by:
--   * the single result from read_result when no further results follow;
--   * a table { res1, res2, ..., multiresultset = true } when the server
--     announced more results;
--   * on failure, a table with badresult = true plus err, errno, sqlstate
--     (set on the multi-result table if earlier results were already read).
local function _query_resp(self)
    return function(sock)
        local res, err, errno, sqlstate = read_result(self, sock)
        if not res then
            return true, {
                badresult = true,
                err = err,
                errno = errno,
                sqlstate = sqlstate,
            }
        end
        if err ~= "again" then
            return true, res
        end

        local multiresultset = { res, multiresultset = true }
        local i = 2
        while err == "again" do
            res, err, errno, sqlstate = read_result(self, sock)
            if not res then
                multiresultset.badresult = true
                multiresultset.err = err
                multiresultset.errno = errno
                multiresultset.sqlstate = sqlstate
                return true, multiresultset
            end
            multiresultset[i] = res
            i = i + 1
        end
        return true, multiresultset
    end
end

--- Create a connection object and connect its socket channel once.
-- Fields read from opts: host, port (default 3306), database, user and
-- password (each default ""), charset (a CHARSET_MAP key, default
-- "_default"), max_packet_size (default 1 MB), compact_arrays, on_connect,
-- overload.
-- Raises an error when opts.charset is not a CHARSET_MAP key.
function _M.connect(opts)
    local self = setmetatable({}, mt)

    self._max_packet_size = opts.max_packet_size or 1024 * 1024 -- default 1 MB
    self.compact = opts.compact_arrays

    local database = opts.database or ""
    local user = opts.user or ""
    local password = opts.password or ""

    local charset_name = opts.charset or "_default"
    local charset = CHARSET_MAP[charset_name]
    if not charset then
        -- fail here with a clear message instead of handing a nil charset id
        -- to the login handshake
        error("unknown mysql charset: " .. tostring(charset_name), 2)
    end

    local channel = socketchannel.channel {
        host = opts.host,
        port = opts.port or 3306,
        auth = _mysql_login(self, user, password, charset, database, opts.on_connect),
        overload = opts.overload,
    }
    self.sockchannel = channel
    -- try connect first only once
    channel:connect(true)
    return self
end

--- Close the socket channel and drop the metatable, so methods looked up
-- through _M are no longer reachable on this object.
function _M.disconnect(self)
    self.sockchannel:close()
    setmetatable(self, nil)
end

--- Send a text query (COM_QUERY) and return what the response callback
-- built by _query_resp returns through sockchannel:request.
-- The callback is created lazily and cached on self.query_resp.
function _M.query(self, query)
    local querypacket = _compose_query(self, query)
    local sockchannel = self.sockchannel
    if not self.query_resp then
        self.query_resp = _query_resp(self)
    end
    return sockchannel:request(querypacket, self.query_resp)
end

local function read_prepare_result(self, sock)
    local resp = {}
    local packet, typ, err = _recv_packet(self, sock)
    if not packet then
resp.badresult = true resp.errno = 300101 resp.err = err return false, resp end if typ == "ERR" then local errno, msg, sqlstate = _parse_err_packet(packet) resp.badresult = true resp.errno = errno resp.err = msg resp.sqlstate = sqlstate return true, resp end --第一节只能是OK if typ ~= "OK" then resp.badresult = true resp.errno = 300201 resp.err = "first typ must be OK,now" .. typ return false, resp end resp.prepare_id, resp.field_count, resp.param_count, resp.warning_count = strunpack(" 0 then local param = _recv_field_packet(self, sock) while param do table.insert(resp.params, param) param = _recv_field_packet(self, sock) end end if resp.field_count > 0 then local field = _recv_field_packet(self, sock) while field do table.insert(resp.fields, field) field = _recv_field_packet(self, sock) end end return true, resp end local function _prepare_resp(self, sql) return function(sock) return read_prepare_result(self, sock, sql) end end -- 注册预处理语句 function _M.prepare(self, sql) local querypacket = _compose_stmt_prepare(self, sql) local sockchannel = self.sockchannel if not self.prepare_resp then self.prepare_resp = _prepare_resp(self) end return sockchannel:request(querypacket, self.prepare_resp) end local function _get_datetime(data, pos) local len, year, month, day, hour, minute, second local value len, pos = _from_length_coded_bin(data, pos) if len == 7 then year, month, day, hour, minute, second, pos = string.unpack(" 2 then if byte & (1 << j) == 0 then null_fields[field_index - 2] = false else null_fields[field_index - 2] = true end end field_index = field_index + 1 end end local row = {} local parser for i = 1, ncols do local col = cols[i] local typ = col.type local name = col.name if not null_fields[i] then parser = _binary_parser[typ] if not parser then error("_parse_row_data_binary()error,unsupported field type " .. 
typ)
            end
            value, pos = parser(data, pos, col.is_signed)
            if compact then
                row[i] = value
            else
                row[name] = value
            end
        end
    end
    return row
end

--- Read one result of a COM_STMT_EXECUTE (binary protocol) response.
-- Returns one of:
--   * nil, msg, errno, sqlstate when the server sends an ERR packet;
--   * the table from _parse_ok_packet for statements without a result set;
--   * {} when no column definitions were received;
--   * a list of rows decoded by _parse_row_data_binary;
--   * nil, err when reading a packet fails.
-- A second value "again" means SERVER_MORE_RESULTS_EXISTS was set in the
-- status flags and another result follows.
local function read_execute_result(self, sock)
    local packet, typ, err = _recv_packet(self, sock)
    if not packet then
        return nil, err
    end

    if typ == "ERR" then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end

    if typ == "OK" then
        local res = _parse_ok_packet(packet)
        if res and res.server_status & SERVER_MORE_RESULTS_EXISTS ~= 0 then
            return res, "again"
        end
        return res
    end

    if typ ~= "DATA" then
        -- tostring(): an unclassified (possibly nil) typ must yield an error
        -- result rather than raise inside the concatenation
        return nil, "packet type " .. tostring(typ) .. " not supported"
    end

    -- typ == "DATA": the header's column count is not used; column
    -- definitions are read until the terminating EOF packet instead
    _parse_result_set_header_packet(packet)
    local cols = {}
    while true do
        packet, typ, err = _recv_packet(self, sock)
        if not packet then
            return nil, err
        end
        if typ == "EOF" then
            break
        end
        if typ == "ERR" then
            local errno, msg, sqlstate = _parse_err_packet(packet)
            return nil, msg, errno, sqlstate
        end
        local col = _parse_field_packet(packet)
        if not col then
            break
        end
        cols[#cols + 1] = col
    end

    -- no result set columns were returned
    if #cols < 1 then
        return {}
    end

    local compact = self.compact
    local rows = {}
    while true do
        packet, typ, err = _recv_packet(self, sock)
        if not packet then
            return nil, err
        end
        if typ == "EOF" then
            local _, status_flags = _parse_eof_packet(packet)
            if status_flags & SERVER_MORE_RESULTS_EXISTS ~= 0 then
                return rows, "again"
            end
            break
        end
        if typ == "ERR" then
            -- an ERR packet ends the row stream with an error
            local errno, msg, sqlstate = _parse_err_packet(packet)
            return nil, msg, errno, sqlstate
        end
        local row = _parse_row_data_binary(packet, cols, compact)
        -- test the parsed row (this previously tested the last column
        -- definition, which could drop rows and leave the stream unread)
        if not row then
            break
        end
        rows[#rows + 1] = row
    end
    return rows
end

--- Build the socketchannel response callback used by _M.execute.
-- The callback always returns true (the response was consumed) followed by:
--   * the single result from read_execute_result when no more results follow;
--   * a table { res1, res2, ... } flagged with both `mulitresultset` (the
--     historical spelling, kept for existing callers) and `multiresultset`
--     when the server announced more results;
--   * on failure, a table with badresult = true plus err, errno, sqlstate
--     (set on the multi-result table if earlier results were already read).
local function _execute_resp(self)
    return function(sock)
        local res, err, errno, sqlstate = read_execute_result(self, sock)
        if not res then
            return true, {
                badresult = true,
                err = err,
                errno = errno,
                sqlstate = sqlstate,
            }
        end
        if err ~= "again" then
            return true, res
        end

        local results = { res, mulitresultset = true, multiresultset = true }
        local i = 2
        while err == "again" do
            res, err, errno, sqlstate = read_execute_result(self, sock)
            if not res then
                results.badresult = true
                results.err = err
                results.errno = errno
                results.sqlstate = sqlstate
                return true, results
            end
            results[i] = res
            i = i + 1
        end
        return true, results
    end
end

--[[
Execute a prepared statement; the varargs are the statement parameters.
On failure the returned table has the fields errno, badresult, sqlstate, err.
If the packet cannot be composed, the table carries errno 30902 and the
composer's error message.
NOTE(review): parameters are collected with {...}, so a nil argument leaves a
hole in that table; how _compose_stmt_execute counts parameters is not visible
here -- confirm NULL parameters are handled.
]]
function _M.execute(self, stmt, ...)
    local querypacket, er = _compose_stmt_execute(self, stmt, CURSOR_TYPE_NO_CURSOR, {...})
    if not querypacket then
        return {
            badresult = true,
            errno = 30902,
            err = er,
        }
    end
    local sockchannel = self.sockchannel
    if not self.execute_resp then
        self.execute_resp = _execute_resp(self)
    end
    return sockchannel:request(querypacket, self.execute_resp)
end

local function _compose_stmt_reset(self, stmt)
    self.packet_no = -1
    local cmd_packet = strpack("c1