view src/luan/modules/Rpc.luan @ 1506:d80395468b4e

ssl security in code
author Franklin Schmidt <fschmidt@gmail.com>
date Fri, 15 May 2020 18:29:47 -0600
parents 5b8f76e26ab7
children 0ba144491a42
line wrap: on
line source
require "java"
local Socket = require "java:java.net.Socket"
local ServerSocket = require "java:java.net.ServerSocket"
local IoUtils = require "java:goodjava.io.IoUtils"
local RpcClient = require "java:goodjava.rpc.RpcClient"
local RpcServer = require "java:goodjava.rpc.RpcServer"
local RpcCall = require "java:goodjava.rpc.RpcCall"
local RpcResult = require "java:goodjava.rpc.RpcResult"
local RpcException = require "java:goodjava.rpc.RpcException"
local JavaRpc = require "java:goodjava.rpc.Rpc"
local LuanJava = require "java:luan.Luan"
local JavaUtils = require "java:luan.modules.Utils"
local IoLuan = require "java:luan.modules.IoLuan"
local ByteArrayInputStream = require "java:java.io.ByteArrayInputStream"
local Luan = require "luan:Luan.luan"
local error = Luan.error
local set_metatable = Luan.set_metatable or error()
local try = Luan.try or error()
local ipairs = Luan.ipairs or error()
local type = Luan.type or error()
local Boot = require "luan:Boot.luan"
local Io = require "luan:Io.luan"
local Thread = require "luan:Thread.luan"
local Table = require "luan:Table.luan"
local java_to_table_deep = Table.java_to_table_deep or error()
local unpack = Table.unpack or error()
local ThreadLuan = require "java:luan.modules.ThreadLuan"
local Logging = require "luan:logging/Logging.luan"
local logger = Logging.logger "Rpc"


-- Module table; the file returns it at the end, so every field set on it is public.
local Rpc = {}

-- Default TCP port: used by Rpc.remote, and by Rpc.serve when no port is passed.
Rpc.port = 9101

-- Cipher suites enabled on the SSL sockets created by Rpc.remote and Rpc.serve.
-- Setting this field to nil makes both functions use plain (non-SSL) sockets instead.
-- NOTE(review): going by their standard names, every entry is an "anon" suite
-- (no certificate authentication of the peer), and several rely on RC4, DES,
-- 3DES or EXPORT-grade keys.  Presumably a deliberate trade-off to get
-- encryption without certificate management -- confirm before exposing this
-- port to untrusted networks, and confirm the running JVM still supports them.
Rpc.cipher_suites = {
	"TLS_DH_anon_WITH_AES_128_GCM_SHA256"
	"TLS_DH_anon_WITH_AES_128_CBC_SHA256"
	"TLS_ECDH_anon_WITH_AES_128_CBC_SHA"
	"TLS_DH_anon_WITH_AES_128_CBC_SHA"
	"TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA"
	"SSL_DH_anon_WITH_3DES_EDE_CBC_SHA"
	"TLS_ECDH_anon_WITH_RC4_128_SHA"
	"SSL_DH_anon_WITH_RC4_128_MD5"
	"SSL_DH_anon_WITH_DES_CBC_SHA"
	"SSL_DH_anon_EXPORT_WITH_DES40_CBC_SHA"
	"SSL_DH_anon_EXPORT_WITH_RC4_40_MD5"
}

-- Passes every element of `list` through LuanJava.toJava, storing the result
-- back into `list` (the table is modified in place), then returns the
-- elements as multiple values.
local function java_args(list)
	for index,value in ipairs(list) do
		local converted = LuanJava.toJava(value)
		list[index] = converted
	end
	return unpack(list)
end

-- Converts the Java value list `list` with java_to_table_deep and returns its
-- elements as multiple values.
-- When `binary_in` is not nil, the last element of the converted list is the
-- index of a placeholder string; it is removed from the list and the
-- placeholder is replaced according to its text:
--   "binary" -> everything read from `binary_in` (JavaUtils.readAll)
--   "input"  -> a LuanIn object wrapping `binary_in`
-- Any other placeholder value is raised as an error.
local function luan_args(list,binary_in)
	local values = java_to_table_deep(list)
	if binary_in == nil then
		return unpack(values)
	end
	-- pop the trailing index that points at the placeholder
	local last = #values
	local slot = values[last]
	values[last] = nil
	local marker = values[slot]
	if marker == "binary" then
		values[slot] = JavaUtils.readAll(binary_in)
	elseif marker == "input" then
		values[slot] = Boot.new_LuanIn( IoLuan.LuanInput.new(binary_in) )
	else
		error(marker)
	end
	return unpack(values)
end

-- Looks for a single streamable argument in `args`: either a value whose type
-- is "binary", or a table whose `java` field is an IoLuan.LuanFile.
-- The argument found is replaced in place by the placeholder string "binary"
-- or "input", and its index is appended to the end of `args`.
-- Returns the input stream for the argument and its length; both are nil when
-- no such argument exists.  Raises an error if more than one is found.
local function encode_binary(args)
	local stream, length, position
	for idx,arg in ipairs(args) do
		local marker
		if type(arg) == "binary" then
			marker = "binary"
		elseif type(arg) == "table" and arg.java ~= nil and arg.java.instanceof(IoLuan.LuanFile) then
			marker = "input"
		end
		if marker ~= nil then
			stream==nil or error "can't have multiple binary args"
			position = idx
			if marker == "binary" then
				stream = ByteArrayInputStream.new(arg)
				length = #arg
			else
				stream = arg.java.inputStream()
				length = arg.length()
			end
			args[idx] = marker
		end
	end
	-- when nothing was found `position` is nil and this appends nothing
	args[#args+1] = position
	return stream, length
end

-- Builds the client-side call function for `socket`.
-- The returned function takes a remote function name followed by its
-- arguments, writes the call through an RpcClient, and returns the values of
-- the reply.  The name "close" is special: after the call is written the
-- client is closed and nothing is read or returned.
local function rpc_caller(socket)
	local client = RpcClient.new(socket)

	-- Handler for failures while reading the reply.  An RpcException cause
	-- whose message is "luan" is re-raised as a Luan error built from its
	-- first value; anything else is rethrown unchanged.
	local function rethrow(e)
		local cause = e.java.getCause()
		local is_luan_failure = cause ~= nil and cause.instanceof(RpcException) and cause.getMessage() == "luan"
		if not is_luan_failure then
			e.throw()
		else
			error(cause.values.get(0))
		end
	end

	return function(fn_name,...)
		local params = {...}
		local stream, stream_len = encode_binary(params)
		local request
		if stream ~= nil then
			request = RpcCall.new(stream,stream_len,fn_name,java_args(params))
		else
			request = RpcCall.new(fn_name,java_args(params))
		end
		client.write(request)
		if fn_name == "close" then
			client.close()
			return
		end
		return try {
			function()
				local reply = client.read()
				return luan_args(reply.returnValues,reply["in"])
			end
			catch = rethrow
		}
	end_function
end_function

-- Default table of callable functions: Rpc.serve uses it when no `fns`
-- argument is given, and incoming command names are looked up as its keys.
Rpc.functions = {}

-- Builds the server-side handler for one connection.
-- Returns a table with two functions:
--   is_closed()  reports whether the underlying RpcServer is closed
--   respond()    reads one call and answers it
-- `fns` maps command names to the functions that implement them.
local function rpc_responder(socket,fns)
	local server = RpcServer.new(socket)

	local function is_closed()
		return server.isClosed()
	end_function

	local function respond()
		local request = server.read()
		if request == nil then
			return
		end
		local command = request.cmd
		-- "close" is handled here and never dispatched to `fns`
		if command == "close" then
			server.close()
			return
		end_if
		local handler = fns[command]
		if handler == nil then
			server.write(JavaRpc.COMMAND_NOT_FOUND)
			return
		end_if
		-- A failing handler is logged and reported to the caller as an
		-- RpcException tagged "luan"; nil then signals "already answered".
		local results = try {
			function()
				return {handler(luan_args(request.args,request["in"]))}
			end
			catch = function(e)
				logger.warn(e)
				server.write(RpcException.new("luan",e.get_message()))
				return nil
			end
		}
		if results == nil then
			return
		end
		local stream, stream_len = encode_binary(results)
		local reply
		if stream ~= nil then
			reply = RpcResult.new(stream,stream_len,java_args(results))
		else
			reply = RpcResult.new(java_args(results))
		end
		server.write(reply)
	end

	return {
		is_closed = is_closed
		respond = respond
	}
end_function

-- Connects to `domain` on Rpc.port and returns a proxy table.
-- Indexing the proxy with any key yields a function that performs the remote
-- call of that name with the given arguments.  An SSL socket restricted to
-- Rpc.cipher_suites is used unless Rpc.cipher_suites is nil, in which case a
-- plain socket is used.
-- If the proxy is garbage collected while the socket is still open, a
-- "not closed" error (created here, at connect time) is logged.
function Rpc.remote(domain)
	local socket
	if Rpc.cipher_suites ~= nil then
		socket = IoUtils.getSSLSocketFactory().createSocket(domain,Rpc.port)
		socket.setEnabledCipherSuites(Rpc.cipher_suites)
	else
		socket = Socket.new(domain,Rpc.port)
	end
	local call = rpc_caller(socket)
	local err = Luan.new_error("not closed")
	local mt = {
		__index = function(_,key)
			return function(...)
				return call(key,...)
			end
		end
		__gc = function(_)
			socket.isClosed() or logger.error(err)
		end
	}
	local proxy = {}
	set_metatable(proxy,mt)
	return proxy
end_function

-- Runs the RPC server; this function loops forever and does not return.
--   port  port to listen on; defaults to Rpc.port
--   fns   table mapping command names to functions; defaults to Rpc.functions
-- An SSL server socket restricted to Rpc.cipher_suites is used unless
-- Rpc.cipher_suites is nil, in which case a plain server socket is used.
-- Every accepted connection is served on its own forked thread until its
-- responder reports closed.  Errors while serving a connection are logged as
-- warnings; errors while accepting or forking are logged as errors and the
-- accept loop continues.
function Rpc.serve(port,fns)
	port = port or Rpc.port
	fns = fns or Rpc.functions
	local socket_server
	if Rpc.cipher_suites == nil then
		socket_server = ServerSocket.new(port)
	else
		socket_server = IoUtils.getSSLServerSocketFactory().createServerSocket(port)
		socket_server.setEnabledCipherSuites(Rpc.cipher_suites)
	end
	while true do
		try {
			function()
				local socket = socket_server.accept()
				local function server()
					local responder = nil
					try {
						function()
							responder = rpc_responder(socket,fns)
							while not responder.is_closed() do
								responder.respond()
							end
						end
						catch = function(e)
							logger.warn(e)
						end
						finally = function()
							-- The loop above can be left through an error with the
							-- connection still open; close it here so the socket is
							-- not leaked until garbage collection.
							socket.isClosed() or socket.close()
							responder and responder.after_close and responder.after_close()
						end
					}
				end
				Thread.fork(server)
			end
			catch = function(e)
				logger.error(e)
			end
		}
	end_while
end_function

-- Module result: the table carrying port, cipher_suites, functions, remote and serve.
return Rpc