Python-to-Lua translation bug: where is my mistake?

After asking (and deleting) a few questions, I still can't find the error in my code. The expected output is:

[0.6759289682539686, 0.6759289682539686, 0.6759289682539686, 0.6759289682539686, 0.6759289682539686, 0.31500873015873027, 0.12156230158730162, 0.5246873015873018, 0.5989928571428574, 0.060103968253968264]

But my Lua code only produces:

3.2328
3.2328
3.2328
3.2328
3.2328
3.2328
3.2328
3.2328
3.2328
3.2328

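I really don't know what mistake I made. Maybe you can spot it. I've tried a few changes, but they never make any difference. I also don't understand why every line shows the same number.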

Working Python code:

from math import prod
from fractions import Fraction
def bitstrings(n) :
    """Return all possible bitstrings of length n"""
    if n == 0 :
        yield []
        return
    else :
        for b in [0,1] :
            for x in bitstrings(n-1) :
                yield [b] + x

def prob_selected(weights, num_selected = 5) :

    # P(n generated, including e)*P(e of n selected | n generated including e)
    # i.e. Sum_n (n generated, including e) * #num_selections / #generated
    # num_selected = how many will be drawn out of the hat (at most)

    n = len(weights)
    final_probability = [0] * n

    for bits in bitstrings(n) :
        num_generated = sum(bits)
        prob_generated = prod([w if b else (1-w) for (w,b) in zip(weights, bits)])

        for i in range(n) :
            if bits[i] :
                final_probability[i] += prob_generated * min(num_selected, num_generated) / num_generated
    return final_probability

print(prob_selected([1, 1, 1, 1, 1,
                     0.5, 0.2, 0.8, 0.9, 0.1]))

My Lua code:

-- Python's len()
function tablelength(T)
  local count = 0
  for _ in pairs(T) do count = count + 1 end
  return count
end

-- Python's sum()
table.reduce = function (list, fn)
    local acc
    for k, v in ipairs(list) do
        if 1 == k then
            acc = v
        else
            acc = fn(acc, v)
        end
    end
    return acc
end

globalArr = {}
function generateBitstrings (n, arr, i)
    if i == n then
        table.insert(globalArr, {table.unpack(arr)})
        return
    end

    arr[i] = 0
    generateBitstrings(n, arr, i + 1)

    arr[i] = 1
    generateBitstrings(n, arr, i + 1)
end

function prob_selected (weights, num_selected)
    local n = tablelength(weights)
    final_probability = {}

    for i=1, n do
        final_probability[i] = 0
    end

    globalArr = {}
    generateBitstrings(n + 1, {}, 1)
    for ibots, bits in ipairs(globalArr) do
        num_generated = table.reduce(
            bits,
            function(a, b)
                return a + b
            end
        )

        prob_generated = 1
        bitsLength = tablelength(bits)
        for i=1,bitsLength do
            if bits[i] then
                prob_generated = prob_generated * weights[i]
            else
                prob_generated = prob_generated * 1 - weights[i]
            end
        end

        for i=1,n do
            if bits[i] == 1 then
                final_probability[i] = final_probability[i] + (prob_generated * math.min(num_selected, num_generated) / num_generated)
            end
        end
    end
    return final_probability
end

for i, value in ipairs(prob_selected({1, 1, 1, 1, 1,0.5, 0.2, 0.8, 0.9, 0.1}, 5)) do
    print(value)
end
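In the Lua version, generateBitstrings fills arr[1] through arr[n - 1] and records a copy once i reaches its n argument, which is why prob_selected calls it as generateBitstrings(n + 1, {}, 1). As a sanity check, here is a hypothetical Python mirror of that recursion (generate_bitstrings is an illustrative name, and a dict stands in for the 1-based Lua table). It suggests the generator is not the culprit: it yields the same 2^n bitstrings, in the same order, as the Python bitstrings generator.

```python
def bitstrings(n):
    """Python generator from the question: all bitstrings of length n."""
    if n == 0:
        yield []
        return
    for b in [0, 1]:
        for x in bitstrings(n - 1):
            yield [b] + x


def generate_bitstrings(stop):
    """Hypothetical mirror of the Lua generateBitstrings(n, arr, i).

    A dict imitates the 1-based Lua table; indices 1 .. stop-1 are filled
    before a string is recorded, so stop = n + 1 yields strings of length n.
    """
    results = []
    arr = {}

    def recurse(i):
        if i == stop:
            # Like {table.unpack(arr)}: copy the current bits into a new list.
            results.append([arr[k] for k in range(1, stop)])
            return
        arr[i] = 0
        recurse(i + 1)
        arr[i] = 1
        recurse(i + 1)

    recurse(1)
    return results


print(generate_bitstrings(4) == list(bitstrings(3)))  # True
```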
user6632736
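  1. As @Egor Skriptunoff said, if bits[i] does not work in Lua the way you expect: "0" is not false; only false and nil are. You need if bits[i] == 1.
  2. You forgot the parentheses in prob_generated = prob_generated * (1 - weights[i]).
  3. I added a few local declarations, although that is not essential.
  4. I made globalArr local and pass it to generateBitstrings as a parameter; you may want to rename it.
  5. I made the test output more compact.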
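Item 1 also explains why every line was identical. Below is a Python re-enactment of the question's original Lua loop; lua_truthy is a hypothetical helper modelling Lua's rule that only nil and false are falsy (Python's None stands in for nil). Because every bit, including 0, counts as true, the weights[i] branch is always taken, so prob_generated is the same constant 1·1·1·1·1·0.5·0.2·0.8·0.9·0.1 = 0.0072 for all 1024 bitstrings. Each entry is then 0.0072 times the sum of min(5, k)/k over the 512 bitstrings that contain item i, which is 449 for every i, giving 0.0072 × 449 = 3.2328. It also means the else branch with the missing parentheses (item 2) never ran, so that bug was masked.

```python
from itertools import product


def lua_truthy(value):
    """Hypothetical model of Lua truthiness: only nil (None) and false are falsy."""
    return value is not None and value is not False


def buggy_prob_selected(weights, num_selected=5):
    """Python re-enactment of the question's Lua prob_selected, bugs included."""
    n = len(weights)
    final_probability = [0.0] * n
    for bits in product([0, 1], repeat=n):
        num_generated = sum(bits)
        prob_generated = 1.0
        for w, b in zip(weights, bits):
            if lua_truthy(b):   # always taken, even when b == 0
                prob_generated = prob_generated * w
            else:               # never reached, so the precedence bug stays hidden
                prob_generated = prob_generated * 1 - w
        for i in range(n):
            if bits[i] == 1:    # the all-zero string never divides by zero here
                final_probability[i] += (
                    prob_generated * min(num_selected, num_generated) / num_generated
                )
    return final_probability


# Every entry is about 3.2328 (up to floating-point rounding).
print(buggy_prob_selected([1, 1, 1, 1, 1, 0.5, 0.2, 0.8, 0.9, 0.1]))
```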
-- python len()
local function tablelength(T)
  local count = 0
  for _ in pairs(T) do count = count + 1 end
  return count
end

-- python sum()
table.reduce = function (list, fn)
    local acc
    for k, v in ipairs(list) do
        if 1 == k then
            acc = v
        else
            acc = fn(acc, v)
        end
    end
    return acc
end

local function generateBitstrings (global_arr, n, arr, i)
    if i == n then
        table.insert(global_arr, {table.unpack(arr)})
        return
    end

    arr[i] = 0
    generateBitstrings(global_arr, n, arr, i + 1)

    arr[i] = 1
    generateBitstrings(global_arr, n, arr, i + 1)
end

local function prob_selected (weights, num_selected)
    local n = tablelength(weights)
    local final_probability = {}

    for i=1, n do
        final_probability[i] = 0
    end

    local globalArr = {}
    generateBitstrings(globalArr, n + 1, {}, 1)
    for _, bits in ipairs(globalArr) do
        local num_generated = table.reduce(
            bits,
            function(a, b)
                return a + b
            end
        )

        local prob_generated = 1
        local bitsLength = tablelength(bits)
        for i=1,bitsLength do
            if bits[i] == 1 then
                prob_generated = prob_generated * weights[i]
            else
                prob_generated = prob_generated * (1 - weights[i])
            end
        end

        for i=1,n do
            if bits[i] == 1 then
                final_probability[i] = final_probability[i] + (prob_generated * math.min(num_selected, num_generated) / num_generated)
            end
        end
    end
    return final_probability
end

print (table.concat (prob_selected({1, 1, 1, 1, 1,0.5, 0.2, 0.8, 0.9, 0.1}, 5), ', '))
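Item 2 would have surfaced as soon as item 1 was fixed. Lua, like Python, gives * higher precedence than binary -, so prob_generated * 1 - weights[i] is evaluated as (prob_generated * 1) - weights[i]. This Python sketch (prob_generated here is a hypothetical helper, not code from the answer) computes the probability of a single bitstring both ways: with parentheses it gives 0.0072, without them it drifts to -1.5, which cannot be a probability.

```python
def prob_generated(weights, bits, parenthesize=True):
    """Product of w (bit 1) or 1 - w (bit 0) over all items.

    parenthesize=False reproduces the Lua line  prob = prob * 1 - w,
    which parses as (prob * 1) - w in both Lua and Python.
    """
    prob = 1.0
    for w, b in zip(weights, bits):
        if b == 1:
            prob = prob * w
        elif parenthesize:
            prob = prob * (1 - w)
        else:
            prob = prob * 1 - w  # * binds tighter than -: subtracts w instead
    return prob


weights = [1, 1, 1, 1, 1, 0.5, 0.2, 0.8, 0.9, 0.1]
bits = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
print(prob_generated(weights, bits))                      # about 0.0072
print(prob_generated(weights, bits, parenthesize=False))  # about -1.5
```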
2020-10-26 17:46:16