Lua：用一个通用函数处理 n 维数组

下面的最小可运行示例（MWE）会初始化、递增并打印二维和三维数值数组。把它扩展到四维等更高维数组很容易，但更优雅的做法是写一个通用函数：要么把维数（最多 7 维）作为参数传入，要么根据输入参数的个数自动确定维数。怎样才能优雅地实现这一点？

MWE:

-- Inclusive bounds of each dimension, and the number of indices it spans.
-- The count is high - low + 1 because both ends are included. The earlier
-- "high - low - 1" undercounted by two, so the flat index formulas used by
-- the *_two / *_three functions (j*i_range + i, k*j_range*i_range + ...)
-- mapped different (i, j[, k]) tuples onto the same slot.
local i_low, i_high = 2, 4
local i_range = i_high - i_low + 1

local j_low, j_high = 3, 7
local j_range = j_high - j_low + 1

local k_low, k_high = 1, 3
local k_range = k_high - k_low + 1

local myArray_two = {}
local myArray_three = {}

--- Store `value` in every slot of a flat 2-D array.
-- Rows run over i_low..i_high and columns over j_low..j_high (inclusive).
-- The slot of (row, col) is col * i_range + row. i_range is taken from the
-- enclosing scope, not derived from the bounds passed in here.
function initArray_two(t, i_low, i_high, j_low, j_high, value)
    for row = i_low, i_high do
        for col = j_low, j_high do
            t[col * i_range + row] = value
        end
    end
end

--- Store `value` in every slot of a flat 3-D array.
-- Each dimension is iterated over its inclusive [low, high] bounds.
-- The slot of (a, b, c) is c*j_range*i_range + b*i_range + a. i_range and
-- j_range are taken from the enclosing scope.
function initArray_three(t, i_low, i_high, j_low, j_high, k_low, k_high, value)
    -- flat slot number of one (i, j, k) cell
    local function slot(a, b, c)
        return c*j_range*i_range + b*i_range + a
    end

    for a = i_low, i_high do
        for b = j_low, j_high do
            for c = k_low, k_high do
                t[slot(a, b, c)] = value
            end
        end
    end
end

--- Add `value` to cell (i, j) of a flat 2-D array filled by initArray_two.
-- The bounds are inclusive on both ends, matching the loops that fill the
-- array. The previous `i > i_low` / `j > j_low` checks rejected the first row
-- and column, so the MWE's own call with i == i_low failed the assertion.
-- Errors if i or j lies outside [low, high].
function incrValue_two(t, i, j, value)
    assert(i >= i_low and i <= i_high, "incrValue_two: i out of range")
    assert(j >= j_low and j <= j_high, "incrValue_two: j out of range")

    local idx = j * i_range + i
    t[idx] = t[idx] + value
end

--- Add `value` to cell (i, j, k) of a flat 3-D array filled by initArray_three.
-- The bounds are inclusive on both ends, matching the loops that fill the
-- array. The previous strict `>` lower-bound checks rejected i_low, j_low and
-- k_low, so the MWE's own calls (i == i_low, k == k_low) failed the assertion.
-- Errors if any index lies outside its [low, high] range.
function incrValue_three(t, i, j, k, value)
    assert(i >= i_low and i <= i_high, "incrValue_three: i out of range")
    assert(j >= j_low and j <= j_high, "incrValue_three: j out of range")
    assert(k >= k_low and k <= k_high, "incrValue_three: k out of range")

    local idx = k*j_range*i_range + j*i_range + i
    t[idx] = t[idx] + value
end

--- Print a flat 2-D array one cell per line as "i<TAB>j<TAB>value".
-- The title is printed first. A print("\n") separator follows each row.
-- The slot of (row, col) is col * i_range + row, with i_range taken from the
-- enclosing scope.
function printArray_two(t, title, i_low, i_high, j_low, j_high)
    print(title .. "\n")

    for row = i_low, i_high do
        for col = j_low, j_high do
            local cell = t[col * i_range + row]
            print(row .. "\t" .. col .. "\t" .. cell)
        end
        -- separator after every row
        print("\n")
    end
end

--- Print a flat 3-D array one cell per line as "i<TAB>j<TAB>k<TAB>value\n".
-- The title is printed first. The slot of (x, y, z) is
-- z*j_range*i_range + y*i_range + x, with i_range and j_range taken from the
-- enclosing scope.
function printArray_three(t, title, i_low, i_high, j_low, j_high, k_low, k_high)
    print(title .. "\n")

    for x = i_low, i_high do
        for y = j_low, j_high do
            for z = k_low, k_high do
                local slot = z*j_range*i_range + y*i_range + x
                local line = x .. "\t" .. y .. "\t" .. z .. "\t" .. t[slot] .. "\n"
                print(line)
            end
        end
    end
end

initArray_two(myArray_two, i_low, i_high, j_low, j_high, 0)
initArray_three(myArray_three, i_low, i_high, j_low, j_high, k_low, k_high, 1)

-- {i, j, amount} increments applied to the 2-D array, in this order.
for _, step in ipairs({ {2, 3, 11}, {2, 3, 13}, {4, 7, 5} }) do
    incrValue_two(myArray_two, step[1], step[2], step[3])
end
printArray_two(myArray_two, "一个二维数组", i_low, i_high, j_low, j_high)

-- {i, j, k, amount} increments applied to the 3-D array, in this order.
for _, step in ipairs({ {2, 3, 1, 9}, {2, 3, 1, 17} }) do
    incrValue_three(myArray_three, step[1], step[2], step[3], step[4])
end
printArray_three(myArray_three, "一个三维数组", i_low, i_high, j_low, j_high, k_low, k_high)
点赞
用户7396148
用户7396148

你可以在函数定义中使用可变参数 `...` 来泛化函数。它会捕获所有重复出现的参数——在这里就是各维的范围。

--- Build (or fill) an n-dimensional nested table.
-- Each range is a {low, high} pair of inclusive bounds. The first range
-- indexes `t` itself, the next indexes each sub-table, and so on. Every cell
-- of the innermost level is set to `default_value`.
-- @param t              table to fill (a fresh table is created when nil)
-- @param default_value  value stored in every innermost cell
-- @param range          {low, high} bounds of the outermost dimension
-- @param ...            {low, high} bounds of the remaining dimensions
-- @return t
function initArray(t, default_value, range, ...)
  t = t or {}
  -- Forward the remaining ranges as varargs instead of packing them into a
  -- table and calling table.unpack at every level. This is simpler and also
  -- runs on Lua 5.1/LuaJIT, which have no table.unpack.
  local next_range = ...
  if next_range then   -- more dimensions follow: recurse into fresh sub-tables
    for i = range[1], range[2] do
      t[i] = initArray({}, default_value, ...)
    end
  else                 -- innermost dimension: store the default value
    for i = range[1], range[2] do
      t[i] = default_value
    end
  end
  return t
end

我还修改了你的调用方式，把每一维的低、高边界组成一对（一个表）传入：

-- 2-D array indexed [i_low..i_high][j_low..j_high], every cell set to 0.
initArray(
  myArray_two,
  0,
  { i_low, i_high },
  { j_low, j_high }
)

以下是完整代码：

-- Inclusive {low, high} bounds of the three dimensions.
local i_low = 2
local i_high = 4
local j_low = 3
local j_high = 7
local k_low = 1
local k_high = 3

-- Target tables, filled in by initArray.
local myArray_two, myArray_three = {}, {}

--- Build (or fill) an n-dimensional nested table.
-- Each range is a {low, high} pair of inclusive bounds. The first range
-- indexes `t` itself, the next indexes each sub-table, and so on. Every cell
-- of the innermost level is set to `default_value`.
-- @param t              table to fill (a fresh table is created when nil)
-- @param default_value  value stored in every innermost cell
-- @param range          {low, high} bounds of the outermost dimension
-- @param ...            {low, high} bounds of the remaining dimensions
-- @return t
function initArray(t, default_value, range, ...)
  t = t or {}
  -- Forward the remaining ranges as varargs instead of packing them into a
  -- table and calling table.unpack at every level. This is simpler and also
  -- runs on Lua 5.1/LuaJIT, which have no table.unpack.
  local next_range = ...
  if next_range then   -- more dimensions follow: recurse into fresh sub-tables
    for i = range[1], range[2] do
      t[i] = initArray({}, default_value, ...)
    end
  else                 -- innermost dimension: store the default value
    for i = range[1], range[2] do
      t[i] = default_value
    end
  end
  return t
end

--- Add `value` to one cell of a nested n-D table built by initArray.
-- @param t      the nested table
-- @param value  amount added to the addressed cell
-- @param ...    one index per dimension, outermost first (at least one)
-- Raises an error, reported at the caller, when:
--   * the number of indices does not match the nesting depth, or
--   * an index falls outside the initialised range.
-- The previous loop silently did nothing when given too few indices. Given
-- too many, it updated an inner cell before failing with a bare
-- "arithmetic on nil" error.
function incrValue(t, value, ...)
  local n = select("#", ...)
  assert(n >= 1, "incrValue: at least one index is required")
  local idx = {...}

  -- Descend through every dimension except the last; each step must reach a
  -- sub-table, otherwise the index is out of range or there are too many.
  local el = t
  for d = 1, n - 1 do
    el = el[idx[d]]
    if type(el) ~= 'table' then
      error(string.format(
        "incrValue: index #%d (%s) is out of range or too many indices were given",
        d, tostring(idx[d])), 2)
    end
  end

  local last = idx[n]
  local cell = el[last]
  if cell == nil then
    error(string.format("incrValue: index #%d (%s) is out of range",
      n, tostring(last)), 2)
  elseif type(cell) == 'table' then
    error(string.format(
      "incrValue: only %d indices given but the array has more dimensions", n), 2)
  end
  el[last] = cell + value
end

--- Return the keys of `t` in a deterministic order (numbers ascending).
-- Keys of different types are grouped by type name. Keys that are neither
-- numbers nor strings are ordered by tostring().
local function printArray_sorted_keys(t)
  local keys = {}
  for k in pairs(t) do
    keys[#keys + 1] = k
  end
  table.sort(keys, function(a, b)
    local ta, tb = type(a), type(b)
    if ta ~= tb then
      return ta < tb
    end
    if ta == "number" or ta == "string" then
      return a < b
    end
    return tostring(a) < tostring(b)
  end)
  return keys
end

--- Print every leaf below `t` as "<prefix>\t<key path>\t<value>", in sorted key order.
local function printArray_leaves(prefix, t)
  for _, k in ipairs(printArray_sorted_keys(t)) do
    local v = t[k]
    if type(v) == 'table' then
      printArray_leaves(prefix .. "\t" .. k, v)
    else
      print(prefix .. "\t" .. k .. "\t" .. v)
    end
  end
end

--- Print a nested n-D table: the title, then one line per cell.
-- Each line is the tab-separated index path followed by the value. A
-- print("\n") separator follows each top-level entry. Keys are visited in
-- sorted order at every depth; plain `pairs` order is unspecified, so
-- numeric indices now always come out in row-major order.
-- @param t      the table to print
-- @param title  heading printed first
-- Any further arguments are ignored.
function printArray(t, title)
  print(title .. "\n")

  for _, k in ipairs(printArray_sorted_keys(t)) do
    local v = t[k]
    if type(v) == 'table' then
      printArray_leaves(k, v)
    else
      print(k .. "\t" .. v)
    end
    print("\n")
  end
end

--- Recursively print every leaf under `t`.
-- Each leaf is printed as "<s>\t<key path>\t<value>". Keys are visited in
-- `pairs` order.
-- @param s  prefix already accumulated for this level (string or number)
-- @param t  sub-table to walk
function recurPrintArray(s, t)
  for key, val in pairs(t) do
    local path = s .. "\t" .. key
    if type(val) ~= 'table' then
      print(path .. "\t" .. val)
    else
      recurPrintArray(path, val)
    end
  end
end

initArray(myArray_two, 0, {i_low, i_high}, {j_low, j_high})
initArray(myArray_three, 1, {i_low, i_high}, {j_low, j_high}, {k_low, k_high})

incrValue(myArray_two, 11, 2, 3)
incrValue(myArray_two, 13, 2, 3)
incrValue(myArray_two, 5, 4, 7)
-- printArray takes only the table and a title; the bounds are implied by the
-- table itself, so none are passed (they used to be passed and silently ignored).
printArray(myArray_two, "A 2-D Array")

incrValue(myArray_three, 9, 2, 3, 1)
incrValue(myArray_three, 17, 2, 3, 1)
printArray(myArray_three, "A 3-D Array")
2019-12-30 21:17:38