Lua 中逐元素比较

我正在尝试使用标准的 < 运算符在 Lua 中进行逐元素比较。例如,以下是我想做的事情:

a = {5, 7, 10}
b = {6, 4, 15}
c = a < b -- 应返回 {true, false, true}

我已经有了实现加法的代码(以及减法、乘法等)。我的问题是,Lua 强制将比较的结果转换为布尔值。我不想要布尔值,我想要比较结果作为表返回。

目前这是我的代码,加法可以工作,但小于运算不工作:

m = {}
m['__add'] = function (a, b)
    -- 将两个表相加
    -- 可以正常工作
    c = {}
    for i = 1, #a do
        c[i] = a[i] + b[i]
    end
    return c
end
m['__lt'] = function (a, b)
    -- 应在每个元素上执行小于运算符
    -- 不工作,Lua 强制转换结果为布尔值
    c = {}
    for i = 1, #a do
        c[i] = a[i] < b[i]
    end
    return c
end

a = {5, 7, 10}
b = {6, 4, 15}

setmetatable(a, m)

c = a + b -- 期望结果为 {11, 11, 25}
print(c[1], c[2], c[3]) -- 工作得很好!

c = a < b -- 期望结果为 {true, false, true}
print(c[1], c[2], c[3]) -- 错误,Lua 将 c 转换为布尔值

Lua 编程手册中提到,__lt 元方法调用的结果总是被转换为布尔值。我的问题是,我该如何解决这个问题?我听说 Lua 对 DSL 很好,并且我真的需要在这里可以工作的语法。我认为使用 MetaLua 应该是可能的,但我不太确定该从哪里开始。

一位同事建议我只需使用 __shl 元方法和<<。我尝试过它,它可以工作,但我真的想使用小于运算符<,而不是使用错误符号的方法。

谢谢!

点赞
用户107090
用户107090

Lua 中的比较操作返回的是一个布尔值。

除非改动 Lua 的核心代码,否则你无法对此进行修改。

2016-04-16 02:00:24
用户3735873
用户3735873

正如其他人已经提到的,这个问题没有简单而直接的解决方案。但是借助于一个类似于 Python 的通用 zip() 函数,就可以简化问题,如下所示:

--------------------------------------------------------------------------------
-- Python-like zip() iterator
--------------------------------------------------------------------------------

function zip(...)
  local arrays, ans = {...}, {}
  local index = 0
  return
    function()
      index = index + 1
      for i,t in ipairs(arrays) do
        if type(t) == 'function' then ans[i] = t() else ans[i] = t[index] end
        if ans[i] == nil then return end
      end
      return table.unpack(ans)
    end
end

--------------------------------------------------------------------------------

a = {5, 7, 10}
b = {6, 4, 15}
c = {}

for a,b in zip(a,b) do
  c[#c+1] = a < b -- 应该返回 {true, false, true}
end

-- 显示答案
for _,v in ipairs(c) do print(v) end
2016-04-16 13:25:12
用户1847592
用户1847592

你能不能忍受下面这个有点啰嗦的 v() 记法:

v(a < b) 而不是 a < b

local vec_mt = {}

local operations = {
   copy     = function (a, b) return a     end,
   lt       = function (a, b) return a < b end,
   add      = function (a, b) return a + b end,
   tostring = tostring,
}

local function create_vector_instance(operand1, operation, operand2)
   local func, vec = operations[operation], {}
   for k, elem1 in ipairs(operand1) do
      local elem2 = operand2 and operand2[k]
      vec[k] = func(elem1, elem2)
   end
   return setmetatable(vec, vec_mt)
end

local saved_result

function v(...)  -- 用于创建 "vector" 类的构造函数
   local result = ...
   local tp = type(result)
   if tp == 'boolean' and saved_result then
      result, saved_result = saved_result
   elseif tp ~= 'table' then
      result = create_vector_instance({...}, 'copy')
   end
   return result
end

function vec_mt.__add(v1, v2)
   return create_vector_instance(v1, 'add', v2)
end

function vec_mt.__lt(v1, v2)
   saved_result = create_vector_instance(v1, 'lt', v2)
end

function vec_mt.__tostring(vec)
   return
      'Vector ('
      ..table.concat(create_vector_instance(vec, 'tostring'), ', ')
      ..')'
end

用法:

a = v(5, 7, 10); print(a)
b = v(6, 4, 15); print(b)

c =   a + b ; print(c)  -- 结果是 v(11, 11, 25)
c = v(a + b); print(c)  -- 结果是 v(11, 11, 25)
c = v(a < b); print(c)  -- 结果是 v(true, false, true)
2016-04-16 19:38:07
用户6217536
用户6217536

你只有两个选择可以让你的语法正确:

选择 1:修改 Lua 核心代码。

这可能会非常困难,并且在将来会很难维护。最大的问题是,Lua 在非常低的级别上假设比较运算符 <>==~= 返回一个布尔值。

Lua 生成的字节码实际上在任何比较操作上都会跳转。例如,类似 c = 4 < 5 的语句被编译成更像是 if (4 < 5) then c = true else c = false end 这样的字节码。

你可以通过 luac -l file.lua 来查看字节码的样子。如果你将 c=4<5 的字节码和 c=4+5 的字节码进行比较,你会明白我在说什么。加法的代码更短、更简单。Lua 假设你会使用比较运算符进行分支,而不是赋值。

选择 2:解析你的代码,修改它,然后运行。

这是我认为你应该做的。虽然很难,但大部分工作已经为你完成了(使用类似 LuaMinify 这样的东西)。

首先,编写一个你可以用来比较任何东西的函数。这里的想法是,如果是一个表格就执行你的特殊比较,否则对其他类型使用 <

my_less = function(a, b)
   if (type(a) == 'table') then
     c = {}
     for i = 1, #a do
       c[i] = a[i] < b[i]
     end
     return c
    else
      return a < b
    end
end

现在我们只需要将每个小于运算符 a<b 替换为 my_less(a,b)

让我们使用 LuaMinify 中的解析器。我们将使用以下代码进行调用:

local parse = require('ParseLua').ParseLua
local ident = require('FormatIdentity')

local code = "c=a*b<c+d"
local ret, ast = parse(code)
local _, f = ident(ast)
print(f)

这样做的唯一作用就是将代码解析为语法树,然后再次输出。我们将更改 FormatIdentity.lua 以使其进行替换。将第 138 行附近的部分替换为以下代码:

    elseif expr.AstType == 'BinopExpr' then --line 138
        if (expr.Op == '<') then
            tok_it = tok_it + 1
            out:appendStr('my_less(')
            formatExpr(expr.Lhs)
            out:appendStr(',')
            formatExpr(expr.Rhs)
            out:appendStr(')')
        else
            formatExpr(expr.Lhs)
            appendStr( expr.Op )
            formatExpr(expr.Rhs)
        end

就是这样了。它将把诸如 c=a*b<c+d 这样的代码替换为 my_less(a*b,c+d)。只需在运行时将所有代码都通过即可。

2016-04-18 01:07:09