Hey,
So I’m having some problems with the Zama Concrete ML linear regression model. I tried to build a minimal failing example here:
from sklearn.linear_model import LinearRegression as SklearnLinearRegression
from concrete.ml.sklearn import LinearRegression as ConcreteLinearRegression
import numpy as np
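def generate_dataset(n, t):
    # Training inputs are 1..n; targets are uniform random values in [0, 100).
    x_train = np.array([np.float64(i + 1) for i in range(n)]).reshape(n, 1)
    y_train = np.random.default_rng().uniform(0.0, 100.0, n)
    # Test inputs continue at n+1..n+t, so the reshape must use t, not n.
    x_test = np.array([np.float64(i + n + 1) for i in range(t)]).reshape(t, 1)
    return x_train, y_train, x_test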
x_train, y_train, x_test = generate_dataset(2, 2)
print("X train: ", " ".join([str(value[0]) for value in x_train]))
print("Y train: ", " ".join([str(value) for value in y_train]))
print("X test: ", " ".join([str(value[0]) for value in x_test]))
plaintext_model = SklearnLinearRegression()
plaintext_model.fit(x_train, y_train)
concrete_model = ConcreteLinearRegression(n_bits=16)
concrete_model.fit(x_train, y_train)
y_pred_plaintext = plaintext_model.predict(x_test)
additional_values = np.array([0.,1.,2.,3.,4.,5.,6.,10.,12.]).reshape(-1,1)
input_range = np.concatenate((x_test, additional_values), axis=0)
concrete_model.compile(input_range, verbose=True)
y_pred_concrete = concrete_model.predict(x_test, fhe="execute")
print(
    "Plaintext coef and intercept: ",
    plaintext_model.coef_[0],
    plaintext_model.intercept_,
)
print(
    "Concrete coef and intercept: ",
    concrete_model.coef_[0],
    concrete_model.intercept_,
)
print("Plaintext predictions: ", " ".join([str(value) for value in y_pred_plaintext]))
print("Concrete predictions: ", " ".join([str(value[0]) for value in y_pred_concrete]))
So here we generate just two (X, Y) points for the training dataset, at X=1.0 and 2.0, so that a linear regression fits them exactly. We then train two linear regression models, a regular sklearn one and the one from Concrete ML, and run inference on X=3.0 and 4.0.
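To make the expectation concrete, here is a small sketch of what a perfect fit through two points should predict. The helper name expected_line and the Y values are made up for illustration; they are not taken from my run.

```python
import numpy as np

def expected_line(x_train, y_train):
    """Slope and intercept of the line through exactly two points."""
    (x1, x2), (y1, y2) = x_train, y_train
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    return slope, intercept

# Illustrative targets only (hypothetical, not from the output in this post).
slope, intercept = expected_line([1.0, 2.0], [10.0, 30.0])

# Extrapolate to the two test inputs used in the example.
expected = slope * np.array([3.0, 4.0]) + intercept
print(slope, intercept, expected)
```

Any model that reports the same coef and intercept should give these extrapolated values at X=3.0 and 4.0.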
The model from sklearn works as expected. The one from Concrete ML just repeats the last known Y for both test inputs, and I can’t figure out why.
Initially I thought the issue was not enough points in the calibration set (input_range) when I compile the circuit, but adding more (via additional_values) doesn’t fix it. Different n_bits values have no effect either. Moreover, the issue persists not only with fhe="execute" but also with fhe="simulate" and even fhe="disable".
I’m out of ideas now, banging my head.
Any help would be greatly appreciated!
Here’s the output of the program above:
X train: 1.0 2.0
Y train: 12.691858014601243 60.26780285759041
X test: 3.0 4.0
Computation Graph
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
%0 = q_X # EncryptedTensor<int16, shape=(1, 1)> ∈ [-32768, 32767]
%1 = [[1]] # ClearTensor<uint1, shape=(1, 1)> ∈ [1, 1]
%2 = matmul(%0, %1) # EncryptedTensor<int16, shape=(1, 1)> ∈ [-32768, 32767]
%3 = sum(%0, axis=1, keepdims=True) # EncryptedTensor<int16, shape=(1, 1)> ∈ [-32768, 32767]
%4 = 0 # ClearScalar<uint1> ∈ [0, 0]
%5 = multiply(%4, %3) # EncryptedTensor<uint1, shape=(1, 1)> ∈ [0, 0]
%6 = subtract(%2, %5) # EncryptedTensor<int16, shape=(1, 1)> ∈ [-32768, 32767]
%7 = [[-146355]] # ClearTensor<int19, shape=(1, 1)> ∈ [-146355, -146355]
%8 = add(%6, %7) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-179123, -113588]
return %8
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Bit-Width Constraints
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
%0:
%0 >= 16
%1:
%1 >= 1
%2:
%2 >= 16
%0 == %1
%1 == %2
%3:
%3 >= 16
%0 == %3
%4:
%4 >= 1
%5:
%5 >= 1
%4 == %3
%3 == %5
%6:
%6 >= 16
%2 == %5
%5 == %6
%7:
%7 >= 19
%8:
%8 >= 19
%6 == %7
%7 == %8
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Bit-Width Assignments
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
%0 = 19
%1 = 19
%2 = 19
%3 = 19
%4 = 19
%5 = 19
%6 = 19
%7 = 19
%8 = 19
max = 19
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Bit-Width Assigned Computation Graph
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
%0 = q_X # EncryptedTensor<int19, shape=(1, 1)> ∈ [-32768, 32767]
%1 = [[1]] # ClearTensor<uint20, shape=(1, 1)> ∈ [1, 1]
%2 = matmul(%0, %1) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-32768, 32767]
%3 = sum(%0, axis=1, keepdims=True) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-32768, 32767]
%4 = 0 # ClearScalar<uint20> ∈ [0, 0]
%5 = multiply(%4, %3) # EncryptedTensor<uint19, shape=(1, 1)> ∈ [0, 0]
%6 = subtract(%2, %5) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-32768, 32767]
%7 = [[-146355]] # ClearTensor<int20, shape=(1, 1)> ∈ [-146355, -146355]
%8 = add(%6, %7) # EncryptedTensor<int19, shape=(1, 1)> ∈ [-179123, -113588]
return %8
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Optimizer
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
### Optimizer display
--- Circuit
19 bits integers
0 manp (maxi log2 norm2)
--- User config
9.094947e-13 error per pbs call
1.000000e+00 error per circuit call
-- Solution correctness
For each pbs call: 1/2147483647, p_error (4.272044e-13)
For the full circuit: 1/2147483647 global_p_error(4.272044e-13)
--- Complexity for the full circuit
1.000000e+00 Millions Operations
-- Circuit Solution
CircuitSolution {
circuit_keys: CircuitKeys {
secret_keys: [
SecretLweKey {
identifier: 0,
polynomial_size: 1,
glwe_dimension: 1009,
description: "big representation",
},
],
keyswitch_keys: [],
bootstrap_keys: [],
conversion_keyswitch_keys: [],
circuit_bootstrap_keys: [],
private_functional_packing_keys: [],
},
instructions_keys: [],
crt_decomposition: [],
complexity: 1009.0,
p_error: 4.2720437586522667e-13,
global_p_error: 4.2720437586522667e-13,
is_feasible: true,
error_msg: "",
}###
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Statistics
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
size_of_secret_keys: 8072
size_of_bootstrap_keys: 0
size_of_keyswitch_keys: 0
size_of_inputs: 8080
size_of_outputs: 8080
p_error: 4.2720437586522667e-13
global_p_error: 4.2720437586522667e-13
complexity: 1009.0
programmable_bootstrap_count: 0
key_switch_count: 0
packing_key_switch_count: 0
clear_addition_count: 1
clear_addition_count_per_parameter: {
LweSecretKeyParam(dimension=1009): 1
}
encrypted_addition_count: 2
encrypted_addition_count_per_parameter: {
LweSecretKeyParam(dimension=1009): 2
}
clear_multiplication_count: 0
encrypted_negation_count: 1
encrypted_negation_count_per_parameter: {
LweSecretKeyParam(dimension=1009): 1
}
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Plaintext coef and intercept: 47.57594484298915 -34.88408682838791
Concrete coef and intercept: 47.57594484298915 -34.88408682838791
Plaintext predictions: 107.84374770057956 155.41969254356871
Concrete predictions: 60.26794520447507 60.26794520447507
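As a sanity check on the numbers above, here is a small sketch that recomputes the predictions from the reported coef and intercept and measures how close the Concrete ML output is to the last training target. All values are copied from the output above; the variable names are my own.

```python
import numpy as np

# Values copied from the program output above.
coef, intercept = 47.57594484298915, -34.88408682838791
y_train_last = 60.26780285759041
concrete_pred = 60.26794520447507

# What the reported line gives at the test inputs X=3.0 and X=4.0.
expected = coef * np.array([3.0, 4.0]) + intercept

# Gap between the Concrete ML prediction and the last training target.
gap_to_last_y = abs(concrete_pred - y_train_last)
print(expected)
print(gap_to_last_y)
```

The recomputed values agree with the sklearn predictions, while the Concrete ML prediction differs from the Y value at X=2.0 only in the fourth decimal place, even though both models report identical coefficients.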