486 lines
9.6 KiB
Lua
486 lines
9.6 KiB
Lua
local driver = require "skynet.socketdriver"
|
|
local skynet = require "skynet"
|
|
local skynet_core = require "skynet.core"
|
|
local assert = assert
|
|
|
|
-- Buffered-size threshold: when the size reported by driver.push exceeds
-- this, the data handler pauses the socket.
local BUFFER_LIMIT = 128 * 1024

-- Public API table returned by this module.
local socket = {}

-- Every registered socket object, keyed by socket id. When the pool table
-- itself is collected, each id still registered is closed through the driver.
local socket_pool = setmetatable({}, {
	__gc = function(pool)
		for id in pairs(pool) do
			driver.close(id)
			pool[id] = nil -- clearing an existing key during pairs() is allowed
		end
	end,
})

-- Handlers for incoming socket messages, indexed by message type number.
local socket_message = {}
|
|
|
|
-- Resume the coroutine parked on socket object `s`, if there is one.
-- The slot is cleared before resuming so a later suspend() can claim it.
local function wakeup(s)
	local waiting = s.co
	if not waiting then
		return
	end
	s.co = nil
	skynet.wakeup(waiting)
end
|
|
|
|
-- Stop data delivery for socket object `s` via driver.pause and mark it
-- paused. Does nothing when `s` is already paused. `size`, when given, is
-- included in the log line.
local function pause_socket(s, size)
	if s.pause then
		return
	end
	local msg
	if size then
		msg = string.format("Pause socket (%d) size : %d" , s.id, size)
	else
		msg = string.format("Pause socket (%d)" , s.id)
	end
	skynet.error(msg)
	driver.pause(s.id)
	s.pause = true
	-- Subsequent socket messages may already be in the mqueue; let them run.
	skynet.yield()
end
|
|
|
|
-- Park the running coroutine on socket object `s` until wakeup(s) resumes it.
-- If `s` was paused, it is restarted with driver.start before waiting, and
-- the pause flag is cleared once the coroutine resumes.
local function suspend(s)
	assert(not s.co)
	local me = coroutine.running()
	s.co = me
	local was_paused = s.pause
	if was_paused then
		skynet.error(string.format("Resume socket (%d)", s.id))
		driver.start(s.id)
	end
	skynet.wait(me)
	if was_paused then
		s.pause = nil
	end
	-- Wake the closing coroutine every time a suspend returns, because
	-- socket.close() waits for the last buffer operation before the buffer
	-- is cleared.
	local closer = s.closing
	if closer then
		skynet.wakeup(closer)
	end
end
|
|
|
|
-- See skynet_socket.h for the message type macros mirrored below.

-- Complete the pending read on `s`: clear the request, pause the socket if
-- `buffered` exceeds BUFFER_LIMIT, then resume the waiting coroutine.
local function release_reader(s, buffered)
	s.read_required = nil
	if buffered > BUFFER_LIMIT then
		pause_socket(s, buffered)
	end
	wakeup(s)
end

-- SKYNET_SOCKET_TYPE_DATA = 1
-- Append incoming data to the socket's buffer and wake the reader once its
-- request (s.read_required) is satisfied:
--   number  -> at least that many bytes buffered (0 = any data)
--   string  -> a line ending with that separator is available
--   other   -> no sized/line read pending; enforce s.buffer_limit and pause
--              when the buffer grows past BUFFER_LIMIT
socket_message[1] = function(id, size, data)
	local s = socket_pool[id]
	if not s then
		skynet.error("socket: drop package from " .. id)
		driver.drop(data, size)
		return
	end

	local buffered = driver.push(s.buffer, s.pool, data, size)
	local required = s.read_required
	local kind = type(required)

	if kind == "number" then
		if buffered >= required then
			release_reader(s, buffered)
		end
		return
	end

	local limit = s.buffer_limit
	if limit and buffered > limit then
		skynet.error(string.format("socket buffer overflow: fd=%d size=%d", id , buffered))
		driver.close(id)
		return
	end

	if kind == "string" then
		if driver.readline(s.buffer, nil, required) then
			release_reader(s, buffered)
		end
	elseif buffered > BUFFER_LIMIT and not s.pause then
		pause_socket(s, buffered)
	end
end
|
|
|
|
-- SKYNET_SOCKET_TYPE_CONNECT = 2
-- Mark the socket connected and wake its waiter. A resume may also post a
-- connect message, so only the first one has an effect. `addr` is unused.
socket_message[2] = function(id, _ , addr)
	local s = socket_pool[id]
	if s and not s.connected then
		s.connected = true
		wakeup(s)
	end
end
|
|
|
|
-- SKYNET_SOCKET_TYPE_CLOSE = 3
-- Mark a known socket disconnected and wake any coroutine waiting on it.
socket_message[3] = function(id)
	local s = socket_pool[id]
	if s then
		s.connected = false
		wakeup(s)
	end
end
|
|
|
|
-- SKYNET_SOCKET_TYPE_ACCEPT = 4
-- Hand a newly accepted connection to the listening socket's callback;
-- if the listener is no longer registered, close the new connection.
socket_message[4] = function(id, newid, addr)
	local listener = socket_pool[id]
	if listener then
		listener.callback(newid, addr)
	else
		driver.close(newid)
	end
end
|
|
|
|
-- SKYNET_SOCKET_TYPE_ERROR = 5
-- Sockets that carry a callback only log the error. Otherwise a connected
-- socket logs it, and a socket still connecting stores it in s.connecting,
-- where connect() returns it. Afterwards the socket is marked disconnected,
-- shut down, and any waiting coroutine is resumed.
socket_message[5] = function(id, _, err)
	local s = socket_pool[id]
	if not s then
		skynet.error("socket: error on unknown", id, err)
		return
	end

	if s.callback then
		skynet.error("socket: accpet error:", err)
		return
	end

	if s.connected then
		skynet.error("socket: error on", id, err)
	elseif s.connecting then
		s.connecting = err
	end

	s.connected = false
	driver.shutdown(id)
	wakeup(s)
end
|
|
|
|
-- SKYNET_SOCKET_TYPE_UDP = 6
-- Deliver a UDP datagram as a Lua string to the socket's callback, freeing
-- the raw message. Datagrams for unknown or callback-less sockets are dropped.
socket_message[6] = function(id, size, data, address)
	local s = socket_pool[id]
	if s ~= nil and s.callback ~= nil then
		local payload = skynet.tostring(data, size)
		skynet_core.trash(data, size)
		s.callback(payload, address)
		return
	end
	skynet.error("socket: drop udp package from " .. id)
	driver.drop(data, size)
end
|
|
|
|
-- Fallback for socket_message[7]: log the pending-send warning for a
-- registered socket. `size` is printed with a "K bytes" unit.
local function default_warning(id, size)
	if socket_pool[id] then
		skynet.error(string.format("WARNING: %d K bytes need to send out (fd = %d)", size, id))
	end
end
|
|
|
|
-- SKYNET_SOCKET_TYPE_WARNING = 7
-- Route a send-buffer warning to the socket's on_warning hook (set via
-- socket.warning), or to default_warning when no hook is installed.
socket_message[7] = function(id, size)
	local s = socket_pool[id]
	if not s then
		return
	end
	local on_warning = s.on_warning or default_warning
	on_warning(id, size)
end
|
|
|
|
-- Register the socket protocol. driver.unpack decodes each message; the
-- first decoded value selects the socket_message handler, which receives
-- the remaining values.
skynet.register_protocol({
	name = "socket",
	id = skynet.PTYPE_SOCKET, -- PTYPE_SOCKET = 6
	unpack = driver.unpack,
	dispatch = function(_, _, msg_type, ...)
		socket_message[msg_type](...)
	end,
})
|
|
|
|
-- Register a socket object for `id` and block until the connect handler or
-- the error/close handler wakes it.
-- `func`, if given, becomes s.callback, and no read buffer is allocated.
-- Returns `id` when connected. Otherwise the socket is unregistered and the
-- function returns nil plus s.connecting: the error stored by the error
-- handler, or `true` if none was stored.
local function connect(id, func)
	local buffer
	if func == nil then
		buffer = driver.buffer()
	end

	local s = {
		id = id,
		buffer = buffer,
		pool = buffer and {},
		connected = false,
		connecting = true,
		read_required = false,
		co = false,
		callback = func,
		protocol = "TCP",
	}
	assert(not socket_pool[id], "socket is not closed")
	socket_pool[id] = s

	suspend(s)

	local err = s.connecting
	s.connecting = nil
	if not s.connected then
		socket_pool[id] = nil
		return nil, err
	end
	return id
end
|
|
|
|
-- Start a connection to addr:port via driver.connect and wait for the
-- result (see connect()): returns id, or nil plus an error.
function socket.open(addr, port)
	local new_id = driver.connect(addr, port)
	return connect(new_id)
end
|
|
|
|
-- Wrap an existing OS file descriptor through driver.bind, then wait for
-- it the same way socket.open does.
function socket.bind(os_fd)
	local bound_id = driver.bind(os_fd)
	return connect(bound_id)
end
|
|
|
|
-- Bind file descriptor 0 (standard input) as a socket.
function socket.stdin()
	local stdin_fd = 0
	return socket.bind(stdin_fd)
end
|
|
|
|
-- Start event delivery for `id` via driver.start and register it through
-- connect(). `func`, if given, is stored as the socket's callback, which
-- the accept handler calls with (newid, addr).
function socket.start(id, func)
	driver.start(id)
	local result, err = connect(id, func)
	return result, err
end
|
|
|
|
-- Pause data delivery for a registered socket that is not already paused.
function socket.pause(id)
	local s = socket_pool[id]
	if s and not s.pause then
		pause_socket(s)
	end
end
|
|
|
|
-- Shut down a registered socket. The framework then sends
-- SKYNET_SOCKET_TYPE_CLOSE; socket.close(id) is still needed afterwards.
function socket.shutdown(id)
	if socket_pool[id] then
		driver.shutdown(id)
	end
end
|
|
|
|
-- Close a raw socket id that was never registered in the pool; registered
-- sockets must go through socket.close.
function socket.close_fd(id)
	local registered = socket_pool[id]
	assert(registered == nil, "Use socket.close instead")
	driver.close(id)
end
|
|
|
|
-- Close socket `id` and remove it from the pool.
-- If the socket is connected, driver.close is issued first, and then:
--   * if another coroutine is blocked reading (s.co), this coroutine waits
--     in s.closing until that reader returns from suspend(), so the reader
--     can consume the buffer before it is dropped;
--   * otherwise this coroutine suspends itself until woken (presumably by
--     the close message).
-- Asserts that no socket.lock holders or waiters remain.
function socket.close(id)
	local s = socket_pool[id]
	if s == nil then
		return
	end

	if s.connected then
		driver.close(id)
		if not s.co then
			suspend(s)
		else
			assert(not s.closing)
			local me = coroutine.running()
			s.closing = me
			skynet.wait(me)
		end
		s.connected = false
	end

	local locks = s.lock
	assert(locks == nil or next(locks) == nil)
	socket_pool[id] = nil
end
|
|
|
|
-- Read from socket `id`.
-- With `sz` nil: return whatever is buffered (driver.readall), suspending
-- once if nothing is available yet.
-- With `sz` set: return driver.pop(..., sz), presumably exactly `sz` bytes,
-- suspending until the data handler reports enough buffered.
-- When the socket is (or becomes) disconnected without satisfying the
-- read, returns false plus whatever remains buffered.
function socket.read(id, sz)
	local s = socket_pool[id]
	assert(s)
	local buffer, pool = s.buffer, s.pool

	if sz == nil then
		local chunk = driver.readall(buffer, pool)
		if chunk ~= "" then
			return chunk
		end
		if not s.connected then
			return false, chunk
		end
		assert(not s.read_required)
		s.read_required = 0 -- wake on any data
		suspend(s)
		chunk = driver.readall(buffer, pool)
		if chunk == "" then
			return false, chunk
		end
		return chunk
	end

	local data = driver.pop(buffer, pool, sz)
	if data then
		return data
	end
	if not s.connected then
		return false, driver.readall(buffer, pool)
	end

	assert(not s.read_required)
	s.read_required = sz
	suspend(s)
	data = driver.pop(buffer, pool, sz)
	if data then
		return data
	end
	return false, driver.readall(buffer, pool)
end
|
|
|
|
-- Read everything until the socket closes.
-- On an already-disconnected socket, returns the remaining buffered data,
-- or false if none is left. Otherwise waits until the socket is marked
-- disconnected and returns the full buffer.
function socket.readall(id)
	local s = socket_pool[id]
	assert(s)
	if s.connected then
		assert(not s.read_required)
		s.read_required = true
		suspend(s)
		assert(s.connected == false)
		return driver.readall(s.buffer, s.pool)
	end
	local rest = driver.readall(s.buffer, s.pool)
	return rest ~= "" and rest
end
|
|
|
|
-- Read one line terminated by `sep` (default "\n"), suspending until the
-- data handler finds the separator. If the socket is or becomes
-- disconnected first, returns false plus whatever remains buffered.
function socket.readline(id, sep)
	sep = sep or "\n"
	local s = socket_pool[id]
	assert(s)

	local line = driver.readline(s.buffer, s.pool, sep)
	if line then
		return line
	end

	if s.connected then
		assert(not s.read_required)
		s.read_required = sep
		suspend(s)
		if s.connected then
			return driver.readline(s.buffer, s.pool, sep)
		end
	end
	return false, driver.readall(s.buffer, s.pool)
end
|
|
|
|
-- Wait until any data arrives or the socket disconnects, without consuming
-- data. Returns the socket's connected flag afterwards, or false when the
-- socket is unknown or already disconnected.
function socket.block(id)
	local s = socket_pool[id]
	if s and s.connected then
		assert(not s.read_required)
		s.read_required = 0
		suspend(s)
		return s.connected
	end
	return false
end
|
|
|
|
-- Write helpers are exported directly from the driver; each must exist at
-- load time.
socket.write  = assert(driver.send)
socket.lwrite = assert(driver.lsend)
socket.header = assert(driver.header)
|
|
|
|
-- True when `id` is not registered in the socket pool.
function socket.invalid(id)
	local s = socket_pool[id]
	return s == nil
end
|
|
|
|
-- For a registered socket, true when it is neither connected nor still
-- connecting. Returns nothing for an unknown id.
function socket.disconnected(id)
	local s = socket_pool[id]
	if not s then
		return
	end
	return not s.connected and not s.connecting
end
|
|
|
|
-- Listen on host:port and return whatever driver.listen returns.
--
-- When `port` is nil, `host` is parsed as an address string "host:port".
-- The split happens at the LAST colon, so an IPv6 literal keeps its host
-- part ("::1:8888" -> "::1", 8888). A bracketed host ("[::1]:8888") has its
-- brackets removed. Inputs with a single colon parse exactly as before. If
-- no "host:port" form is found, host and port become nil, as before.
-- `backlog` is passed through unchanged.
function socket.listen(host, port, backlog)
	if port == nil then
		-- greedy `.+` backtracks to the last ':' whose suffix has no ':'
		local h, p = string.match(host, "^(.+):([^:]+)$")
		if h then
			h = string.match(h, "^%[(.*)%]$") or h
		end
		host, port = h, tonumber(p)
	end
	return driver.listen(host, port, backlog)
end
|
|
|
|
-- Acquire the per-socket FIFO lock. Slot 1 of s.lock marks the holder.
-- Later callers append their coroutine and sleep until socket.unlock
-- wakes them.
function socket.lock(id)
	local s = socket_pool[id]
	assert(s)
	local queue = s.lock
	if not queue then
		queue = {}
		s.lock = queue
	end
	if #queue ~= 0 then
		local co = coroutine.running()
		queue[#queue + 1] = co
		skynet.wait(co)
		return
	end
	queue[1] = true
end
|
|
|
|
-- Release the per-socket lock: drop the current holder and wake the next
-- queued coroutine, which then becomes the holder in slot 1.
function socket.unlock(id)
	local s = socket_pool[id]
	assert(s)
	local queue = assert(s.lock)
	table.remove(queue, 1)
	local next_co = queue[1]
	if next_co then
		skynet.wakeup(next_co)
	end
end
|
|
|
|
-- abandon is used to forward a socket id to another service: the socket
-- is dropped from this service's pool without being closed.
-- The receiving service must call socket.start(id) afterwards.
function socket.abandon(id)
	local s = socket_pool[id]
	if not s then
		return
	end
	s.connected = false
	wakeup(s)
	socket_pool[id] = nil
end
|
|
|
|
-- Set the buffer size above which the data handler logs an overflow and
-- closes the socket (checked only while no sized read is pending).
function socket.limit(id, limit)
	local s = socket_pool[id]
	assert(s)
	s.buffer_limit = limit
end
|
|
|
|
---------------------- UDP

-- Register a UDP socket object for `id`. It is treated as connected
-- immediately, and datagrams are delivered to `cb`.
local function create_udp_object(id, cb)
	assert(not socket_pool[id], "socket is not closed")
	local obj = {
		id = id,
		connected = true,
		protocol = "UDP",
		callback = cb,
	}
	socket_pool[id] = obj
end
|
|
|
|
-- Open a UDP socket (optionally bound to host/port by the driver),
-- register it with `callback`, and return its id.
function socket.udp(callback, host, port)
	local udp_id = driver.udp(host, port)
	create_udp_object(udp_id, callback)
	return udp_id
end
|
|
|
|
-- Set the default peer of UDP socket `id`. Registers the socket if it is
-- unknown; otherwise it must be a UDP socket, and a non-nil `callback`
-- replaces the current one.
function socket.udp_connect(id, addr, port, callback)
	local obj = socket_pool[id]
	if not obj then
		create_udp_object(id, callback)
	else
		assert(obj.protocol == "UDP")
		if callback then
			obj.callback = callback
		end
	end
	driver.udp_connect(id, addr, port)
end
|
|
|
|
-- UDP and diagnostic helpers exported directly from the driver; each must
-- exist at load time.
socket.sendto      = assert(driver.udp_send)
socket.udp_address = assert(driver.udp_address)
socket.netstat     = assert(driver.info)
|
|
|
|
-- Install `callback(id, size)` as the send-buffer warning hook for a
-- registered socket; it replaces default_warning for that socket.
function socket.warning(id, callback)
	local obj = assert(socket_pool[id])
	obj.on_warning = callback
end
|
|
|
|
-- Export the public API table.
return socket
|